merge: resolve conflicts between main and litellm_oss_staging_03_11_2026

This commit is contained in:
Chesars
2026-03-12 09:38:31 -03:00
parent 8d2432b21b
commit 1be6b31e2f
205 changed files with 16204 additions and 10448 deletions
+11 -4
View File
@@ -1,12 +1,19 @@
name: "LiteLLM CodeQL config"
# Exclude queries that produce result sets > 2 GiB on this codebase,
# causing 49+ minute runs that fail and block CI resources.
# Use security-extended suite instead of security-and-quality to avoid
# result sets > 2 GiB on this codebase that cause fatal OOM failures.
queries:
- uses: security-extended
# These two queries are security queries included in security-extended that
# individually produce result sets > 2 GiB on this codebase, causing fatal
# OOM failures. Exclude them as a safety net until CI confirms they no longer
# OOM; drop these exclusions in a follow-up once verified.
query-filters:
- exclude:
id: py/clear-text-logging-sensitive-data # CWE-312/CleartextLogging.ql — result set > 2 GiB
id: py/clear-text-logging-sensitive-data # CWE-312 — > 2 GiB result set
- exclude:
id: py/polynomial-redos # CWE-730/PolynomialReDoS.ql — result set > 2 GiB
id: py/polynomial-redos # CWE-730 — > 2 GiB result set
paths-ignore:
- tests
+7
View File
@@ -110,6 +110,13 @@ LiteLLM is a unified interface for 100+ LLM providers with two main components:
### Proxy database access
- **Do not write raw SQL** for proxy DB operations. Use Prisma model methods instead of `execute_raw` / `query_raw`.
- Use the generated client: `prisma_client.db.<model>` (e.g. `litellm_tooltable`, `litellm_usertable`) with `.upsert()`, `.find_many()`, `.find_unique()`, `.update()`, `.update_many()` as appropriate. This avoids schema/client drift, keeps code testable with simple mocks, and matches patterns used in spend logs and other proxy code.
- **No N+1 queries.** Never query the DB inside a loop. Batch-fetch with `{"in": ids}` and distribute in-memory.
- **Batch writes.** Use `create_many`/`update_many`/`delete_many` instead of individual calls (these return counts only; `update_many`/`delete_many` no-op silently on missing rows). When multiple separate writes target the same table (e.g. in `batch_()`), order by primary key to avoid deadlocks.
- **Push work to the DB.** Filter, sort, group, and aggregate in SQL, not Python. Verify Prisma generates the expected SQL — e.g. prefer `group_by` over `find_many(distinct=...)` which does client-side processing.
- **Bound large result sets.** Prisma materializes full results in memory. For results over ~10 MB, paginate with `take`/`skip` or `cursor`/`take`, always with an explicit `order`. Prefer cursor-based pagination (`skip` is O(n)). Don't paginate naturally small result sets.
- **Limit fetched columns on wide tables.** Use `select` to fetch only needed fields — returns a partial object, so downstream code must not access unselected fields.
- **Check index coverage.** For new or modified queries, check `schema.prisma` for a supporting index. Prefer extending an existing index (e.g. `@@index([a])``@@index([a, b])`) over adding a new one, unless it's a `@@unique`. Only add indexes for large/frequent queries.
- **Keep schema files in sync.** Apply schema changes to all `schema.prisma` copies (`schema.prisma`, `litellm/proxy/`, `litellm-proxy-extras/`, `litellm-js/spend-logs/` for SpendLogs) with a migration under `litellm-proxy-extras/litellm_proxy_extras/migrations/`.
### Enterprise Features
- Enterprise-specific code in `enterprise/` directory
@@ -61,6 +61,20 @@ Create the name of the service account to use
{{- end }}
{{- end }}
{{/*
Create the service account name used by migration jobs.
When Helm hooks are enabled, pre-install/pre-upgrade hooks run before normal resources.
If this chart is creating the ServiceAccount, it is not yet available for the hook job,
so fall back to "default" (or an explicit override) to avoid a cyclic dependency.
*/}}
{{- define "litellm.migrationServiceAccountName" -}}
{{- if and .Values.migrationJob.hooks.helm.enabled .Values.serviceAccount.create }}
{{- default "default" .Values.migrationJob.serviceAccountName }}
{{- else }}
{{- include "litellm.serviceAccountName" . }}
{{- end }}
{{- end }}
{{/*
Get redis service name
*/}}
@@ -34,7 +34,7 @@ spec:
imagePullSecrets:
{{- toYaml . | nindent 8 }}
{{- end }}
serviceAccountName: {{ include "litellm.serviceAccountName" . }}
serviceAccountName: {{ include "litellm.migrationServiceAccountName" . }}
{{- with .Values.migrationJob.extraInitContainers }}
initContainers:
{{- toYaml . | nindent 8 }}
@@ -124,4 +124,67 @@ tests:
- notContains:
path: spec.template.spec.containers[0].env
content:
name: DATABASE_URL
name: DATABASE_URL
- it: should use default service account for helm hooks when serviceAccount.create is true
template: migrations-job.yaml
set:
migrationJob:
enabled: true
hooks:
helm:
enabled: true
serviceAccount:
create: true
asserts:
- equal:
path: spec.template.spec.serviceAccountName
value: default
- it: should use migrationJob.serviceAccountName override for helm hooks when serviceAccount.create is true
template: migrations-job.yaml
set:
migrationJob:
enabled: true
serviceAccountName: migration-sa
hooks:
helm:
enabled: true
serviceAccount:
create: true
asserts:
- equal:
path: spec.template.spec.serviceAccountName
value: migration-sa
- it: should use chart service account when helm hooks are disabled
template: migrations-job.yaml
set:
migrationJob:
enabled: true
hooks:
helm:
enabled: false
serviceAccount:
create: true
name: my-custom-sa
asserts:
- equal:
path: spec.template.spec.serviceAccountName
value: my-custom-sa
- it: should use pre-existing service account when helm hooks are enabled but serviceAccount.create is false
template: migrations-job.yaml
set:
migrationJob:
enabled: true
hooks:
helm:
enabled: true
serviceAccount:
create: false
name: pre-existing-sa
asserts:
- equal:
path: spec.template.spec.serviceAccountName
value: pre-existing-sa
+4
View File
@@ -309,6 +309,10 @@ migrationJob:
retries: 3 # Number of retries for the Job in case of failure
backoffLimit: 4 # Backoff limit for Job restarts
disableSchemaUpdate: false # Skip schema migrations for specific environments. When True, the job will exit with code 0.
# Optional service account for the migration job.
# Only used when migrationJob.hooks.helm.enabled=true and serviceAccount.create=true.
# In that case, pre-install/pre-upgrade hooks run before normal resources, so this defaults to "default".
serviceAccountName: ""
annotations: {}
ttlSecondsAfterFinished: 120
resources: {}
@@ -0,0 +1,169 @@
---
slug: gemini_embedding_2_multimodal
title: "Gemini Embedding 2 Preview: Multimodal Embeddings on LiteLLM"
date: 2025-03-11T10:00:00
authors:
- name: Sameer Kankute
title: SWE @ LiteLLM (LLM Translation)
url: https://www.linkedin.com/in/sameer-kankute/
image_url: https://pbs.twimg.com/profile_images/2001352686994907136/ONgNuSk5_400x400.jpg
description: "Generate embeddings from text, images, audio, video, and PDFs with gemini-embedding-2-preview on LiteLLM via Gemini API and Vertex AI."
tags: [gemini, embeddings, multimodal, vertex ai]
hide_table_of_contents: false
---
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Gemini Embedding 2 Preview: Multimodal Embeddings
LiteLLM now supports **multimodal embeddings** with `gemini-embedding-2-preview`—generating a single embedding from a mix of text, images, audio, video, and PDF content. Available via both the **Gemini API** (API key) and **Vertex AI** (GCP credentials).
## Supported Input Types
| Modality | Supported Formats |
|----------|-------------------|
| **Text** | Plain text |
| **Image** | PNG, JPEG |
| **Audio** | MP3, WAV |
| **Video** | MP4, MOV |
| **Documents** | PDF |
## Input Formats
LiteLLM accepts three input formats for multimodal content:
1. **Data URIs** Base64-encoded inline: `data:image/png;base64,<encoded_data>`
2. **GCS URLs** Cloud Storage paths (Vertex AI): `gs://bucket/path/to/file.png`
3. **Gemini File References** Pre-uploaded files (Gemini API): `files/abc123`
## Quick Start
<Tabs>
<TabItem value="gemini" label="Gemini API">
```python
from litellm import embedding
import os
os.environ["GEMINI_API_KEY"] = "your-api-key"
# Text + Image (base64)
response = embedding(
model="gemini/gemini-embedding-2-preview",
input=[
"The food was delicious and the waiter...",
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAIAQMAAAD+wSzIAAAABlBMVEX///+/v7+jQ3Y5AAAADklEQVQI12P4AIX8EAgALgAD/aNpbtEAAAAASUVORK5CYII"
],
)
print(response)
```
</TabItem>
<TabItem value="vertex" label="Vertex AI">
```python
import litellm
from litellm import embedding
litellm.vertex_project = "your-project-id"
litellm.vertex_location = "us-central1"
# Text + Image (GCS URL)
response = embedding(
model="vertex_ai/gemini-embedding-2-preview",
input=[
"Describe this image",
"gs://my-bucket/images/photo.png"
],
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="LiteLLM Proxy">
**1. Config (config.yaml)**
```yaml
model_list:
- model_name: gemini-embedding-2-preview
litellm_params:
model: gemini/gemini-embedding-2-preview
api_key: os.environ/GEMINI_API_KEY
- model_name: vertex-gemini-embedding-2-preview
litellm_params:
model: vertex_ai/gemini-embedding-2-preview
vertex_project: os.environ/VERTEXAI_PROJECT
vertex_location: os.environ/VERTEXAI_LOCATION
general_settings:
master_key: sk-1234
```
**2. Start proxy**
```bash
litellm --config config.yaml
```
**3. Call embeddings**
```bash
curl -X POST http://localhost:4000/embeddings \
-H "Authorization: Bearer sk-1234" \
-H "Content-Type: application/json" \
-d '{
"model": "gemini-embedding-2-preview",
"input": [
"The food was delicious and the waiter...",
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAIAQMAAAD+wSzIAAAABlBMVEX///+/v7+jQ3Y5AAAADklEQVQI12P4AIX8EAgALgAD/aNpbtEAAAAASUVORK5CYII"
]
}'
```
</TabItem>
</Tabs>
## Input Format Examples
| Format | Example | Provider |
|--------|---------|----------|
| **Data URI** | `data:image/png;base64,...` | Gemini, Vertex AI |
| **GCS URL** | `gs://bucket/path/image.png` | Vertex AI |
| **File reference** | `files/abc123` | Gemini API only |
### Supported MIME Types for Data URIs
- **Images:** `image/png`, `image/jpeg`
- **Audio:** `audio/mpeg`, `audio/wav`
- **Video:** `video/mp4`, `video/quicktime`
- **Documents:** `application/pdf`
### GCS URL MIME Inference
For Vertex AI, MIME types are inferred from file extensions:
- `.png``image/png`
- `.jpg` / `.jpeg``image/jpeg`
- `.mp3``audio/mpeg`
- `.wav``audio/wav`
- `.mp4``video/mp4`
- `.mov``video/quicktime`
- `.pdf``application/pdf`
## Optional Parameters
| Parameter | Description | Maps to |
|-----------|-------------|---------|
| `dimensions` | Output embedding size | `outputDimensionality` |
```python
response = embedding(
model="gemini/gemini-embedding-2-preview",
input=["text to embed"],
dimensions=768, # Optional: control output vector size
)
```
@@ -138,6 +138,7 @@ The `/v1/messages/count_tokens` endpoint automatically routes to the appropriate
| Provider | Token Counting Method |
|----------|----------------------|
| Anthropic | [Anthropic Token Counting API](https://docs.anthropic.com/en/docs/build-with-claude/token-counting) |
| OpenAI | [OpenAI Responses API `/input_tokens`](https://platform.openai.com/docs/api-reference/responses/input-tokens) — see [Token Counting](./count_tokens.md) |
| Vertex AI (Claude) | Vertex AI Partner Models Token Counter |
| Bedrock (Claude) | AWS Bedrock CountTokens API |
| Gemini | Google AI Studio countTokens API |
+1
View File
@@ -11,6 +11,7 @@ This endpoint supports various guardrail types including:
- **Presidio** - PII detection and masking
- **Bedrock** - AWS Bedrock guardrails for content moderation
- **Lakera** - AI safety guardrails
- **PANW Prisma AIRS** - Threat detection, DLP, and policy enforcement
- **Custom guardrails** - User-defined guardrails
## Configuration
+2 -1
View File
@@ -13,7 +13,7 @@ import TabItem from '@theme/TabItem';
| Fallbacks | ✅ | Works between supported models |
| Loadbalancing | ✅ | Works between supported models |
| Guardrails | ✅ | Applies to output transcribed text (non-streaming only) |
| Supported Providers | `openai`, `azure`, `vertex_ai`, `gemini`, `deepgram`, `groq`, `fireworks_ai`, `ovhcloud` | |
| Supported Providers | `openai`, `azure`, `vertex_ai`, `gemini`, `deepgram`, `groq`, `fireworks_ai`, `ovhcloud`, `mistral` | |
## Quick Start
@@ -126,6 +126,7 @@ transcript = client.audio.transcriptions.create(
- [Fireworks AI](./providers/fireworks_ai.md#audio-transcription)
- [Groq](./providers/groq.md#speech-to-text---whisper)
- [Deepgram](./providers/deepgram.md)
- [Mistral (Voxtral)](./providers/mistral.md#audio-transcription)
- [OVHcloud AI Endpoints](./providers/ovhcloud.md)
---
+22
View File
@@ -51,6 +51,28 @@ Here's what an example response looks like
}
```
## Native Finish Reason
LiteLLM maps all provider-specific `finish_reason` values to OpenAI-compatible values (`stop`, `length`, `tool_calls`, `function_call`, `content_filter`). When the original provider value differs from the mapped value, it is preserved in `provider_specific_fields["native_finish_reason"]`.
This is useful for agent loops that need to distinguish between different stop conditions (e.g., Gemini's `MALFORMED_FUNCTION_CALL` vs a normal `stop`).
```python
response = completion(model="gemini/gemini-2.0-flash", messages=messages)
choice = response.choices[0]
print(choice.finish_reason) # "stop" (OpenAI-compatible)
# Access the original provider value when it differs:
if hasattr(choice, "provider_specific_fields") and choice.provider_specific_fields:
native = choice.provider_specific_fields.get("native_finish_reason")
if native == "MALFORMED_FUNCTION_CALL":
# Handle malformed function call differently from a normal stop
pass
```
When the provider already returns an OpenAI-compatible value (e.g., `stop`), `native_finish_reason` is not set.
## Additional Attributes
You can also access information like latency.
@@ -115,6 +115,11 @@ print(response)
Web fetch is available on the following Anthropic API models:
- `claude-opus-4-6` (Claude Opus 4.6)
- `claude-sonnet-4-6` (Claude Sonnet 4.6)
- `claude-opus-4-5` (Claude Opus 4.5)
- `claude-sonnet-4-5` (Claude Sonnet 4.5)
- `claude-haiku-4-5` (Claude Haiku 4.5)
- `claude-opus-4-1-20250805` (Claude Opus 4.1)
- `claude-opus-4-20250514` (Claude Opus 4)
- `claude-sonnet-4-20250514` (Claude Sonnet 4)
@@ -80,6 +80,36 @@ That's it! The provider is now available.
}
```
## Responses API Support
If your provider also supports the OpenAI Responses API (`/v1/responses`), add `supported_endpoints`:
```json
{
"your_provider": {
"base_url": "https://api.yourprovider.com/v1",
"api_key_env": "YOUR_PROVIDER_API_KEY",
"supported_endpoints": ["/v1/chat/completions", "/v1/responses"]
}
}
```
This enables `litellm.responses()` with zero additional code:
```python
import litellm
response = litellm.responses(
model="your_provider/model-name",
input="Hello, what can you do?",
)
print(response.output)
```
If `supported_endpoints` is omitted, it defaults to `[]`. Chat completions is always enabled for JSON providers regardless of this field.
The provider inherits all request/response handling from OpenAI's Responses API — streaming, tools, and all standard parameters work out of the box.
## Usage
```python
@@ -89,11 +119,17 @@ import os
# Set your API key
os.environ["YOUR_PROVIDER_API_KEY"] = "your-key-here"
# Use the provider
# Chat completions
response = litellm.completion(
model="your_provider/model-name",
messages=[{"role": "user", "content": "Hello"}],
)
# Responses API (if supported_endpoints includes "/v1/responses")
response = litellm.responses(
model="your_provider/model-name",
input="Hello",
)
```
## When to Use Python Instead
@@ -105,7 +141,9 @@ Use a Python config class if you need:
- Provider-specific streaming logic
- Advanced tool calling modifications
For these cases, create a config class in `litellm/llms/your_provider/chat/transformation.py` that inherits from `OpenAIGPTConfig` or `OpenAILikeChatConfig`.
For chat completions, create a config class in `litellm/llms/your_provider/chat/transformation.py` that inherits from `OpenAIGPTConfig` or `OpenAILikeChatConfig`.
For responses API with small overrides, inherit from `OpenAIResponsesAPIConfig` and override only what's needed. See `litellm/llms/perplexity/responses/transformation.py` for a minimal example (~40 lines vs 400+).
## Testing
+189
View File
@@ -0,0 +1,189 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Token Counting
## Overview
LiteLLM provides exact token counting by calling provider-specific token counting APIs. This gives you accurate token counts before sending requests, helping with cost estimation and context window management.
| Feature | Details |
|---------|---------|
| SDK Method | `litellm.acount_tokens()` |
| Proxy Endpoints | `/v1/messages/count_tokens` (Anthropic format), `/v1/responses/input_tokens` (OpenAI format) |
| Fallback | Local tiktoken-based counting for unsupported providers |
## Supported Providers
| Provider | Token Counting API | Format |
|----------|-------------------|--------|
| OpenAI | [Responses API `/input_tokens`](https://platform.openai.com/docs/api-reference/responses/input-tokens) | OpenAI Responses |
| Anthropic | [Messages `/count_tokens`](https://docs.anthropic.com/en/docs/build-with-claude/token-counting) | Anthropic Messages |
| Vertex AI (Claude) | Vertex AI Partner Models Token Counter | Anthropic Messages |
| Bedrock (Claude) | AWS Bedrock CountTokens API | Anthropic Messages |
| Gemini | Google AI Studio countTokens API | Anthropic Messages |
| Vertex AI (Gemini) | Vertex AI countTokens API | Anthropic Messages |
| Other providers | Local tiktoken fallback | N/A |
## SDK Usage
### Basic Usage
```python
import asyncio
import litellm
async def main():
# OpenAI
result = await litellm.acount_tokens(
model="openai/gpt-4o",
messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(f"Token count: {result.total_tokens}")
print(f"Tokenizer: {result.tokenizer_type}") # "openai_api"
# Anthropic
result = await litellm.acount_tokens(
model="anthropic/claude-3-5-sonnet-20241022",
messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(f"Token count: {result.total_tokens}")
print(f"Tokenizer: {result.tokenizer_type}") # "anthropic_api"
asyncio.run(main())
```
### With Tools and System Message
```python
import asyncio
import litellm
async def main():
result = await litellm.acount_tokens(
model="openai/gpt-4o",
messages=[{"role": "user", "content": "What's the weather in Paris?"}],
tools=[{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather for a city",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string"}},
},
},
}],
system="You are a helpful weather assistant.",
)
print(f"Token count (with tools): {result.total_tokens}")
asyncio.run(main())
```
### Response Format
`litellm.acount_tokens()` returns a `TokenCountResponse`:
```python
TokenCountResponse(
total_tokens=15, # Token count
request_model="openai/gpt-4o", # Model requested
model_used="gpt-4o", # Model used for counting
tokenizer_type="openai_api", # "openai_api", "anthropic_api", "local_tokenizer"
original_response={"input_tokens": 15}, # Raw API response
error=False, # True if counting failed
error_message=None, # Error details if failed
)
```
### Fallback Behavior
If a provider doesn't support a token counting API, or if the API key is missing, `acount_tokens()` automatically falls back to local tiktoken-based counting:
```python
# Unsupported provider → automatic fallback
result = await litellm.acount_tokens(
model="together_ai/meta-llama/Llama-3-8b-chat-hf",
messages=[{"role": "user", "content": "Hello"}],
)
print(result.tokenizer_type) # "local_tokenizer"
```
## Proxy Usage
### OpenAI Format — `/v1/responses/input_tokens`
<Tabs>
<TabItem value="curl" label="curl">
```bash
curl -X POST "http://localhost:4000/v1/responses/input_tokens" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "gpt-4o",
"input": "Hello, how are you?"
}'
```
</TabItem>
<TabItem value="python" label="Python (httpx)">
```python
import httpx
response = httpx.post(
"http://localhost:4000/v1/responses/input_tokens",
headers={
"Content-Type": "application/json",
"Authorization": "Bearer sk-1234"
},
json={
"model": "gpt-4o",
"input": "Hello, how are you?"
}
)
print(response.json())
# {"input_tokens": 7}
```
</TabItem>
</Tabs>
**Response:**
```json
{"input_tokens": 7}
```
### Anthropic Format — `/v1/messages/count_tokens`
See [Anthropic Token Counting](./anthropic_count_tokens.md) for full documentation.
```bash
curl -X POST "http://localhost:4000/v1/messages/count_tokens" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "claude-3-5-sonnet-20241022",
"messages": [
{"role": "user", "content": "Hello, how are you?"}
]
}'
```
## Proxy Configuration
```yaml
model_list:
- model_name: gpt-4o
litellm_params:
model: openai/gpt-4o
api_key: os.environ/OPENAI_API_KEY
- model_name: claude-3-5-sonnet
litellm_params:
model: anthropic/claude-3-5-sonnet-20241022
api_key: os.environ/ANTHROPIC_API_KEY
```
@@ -514,6 +514,57 @@ All models listed [here](https://ai.google.dev/gemini-api/docs/models/gemini) ar
| Model Name | Function Call |
| :--- | :--- |
| text-embedding-004 | `embedding(model="gemini/text-embedding-004", input)` |
| gemini-embedding-2-preview | `embedding(model="gemini/gemini-embedding-2-preview", input)` | [Multimodal docs](#gemini-embedding-2-preview-multimodal) |
### Gemini Embedding 2 Preview (Multimodal)
`gemini-embedding-2-preview` supports **multimodal embeddings**—text, images, audio, video, and PDF in a single request. See [blog post](/blog/gemini_embedding_2_multimodal) for details.
**Input formats:**
- **Data URIs:** `data:image/png;base64,<encoded_data>`
- **Gemini file references:** `files/abc123` (pre-uploaded via Gemini Files API)
**Supported MIME types:** `image/png`, `image/jpeg`, `audio/mpeg`, `audio/wav`, `video/mp4`, `video/quicktime`, `application/pdf`
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import embedding
import os
os.environ["GEMINI_API_KEY"] = ""
# Text + Image (base64)
response = embedding(
model="gemini/gemini-embedding-2-preview",
input=[
"The food was delicious and the waiter...",
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAIAQMAAAD+wSzIAAAABlBMVEX///+/v7+jQ3Y5AAAADklEQVQI12P4AIX8EAgALgAD/aNpbtEAAAAASUVORK5CYII"
],
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
```bash
curl -X POST http://localhost:4000/embeddings \
-H "Authorization: Bearer sk-1234" \
-H "Content-Type: application/json" \
-d '{
"model": "gemini-embedding-2-preview",
"input": [
"The food was delicious and the waiter...",
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAIAQMAAAD+wSzIAAAABlBMVEX///+/v7+jQ3Y5AAAADklEQVQI12P4AIX8EAgALgAD/aNpbtEAAAAASUVORK5CYII"
]
}'
```
</TabItem>
</Tabs>
**Optional:** `dimensions` maps to Gemini's `outputDimensionality`.
## Vertex AI Embedding Models
+87 -1
View File
@@ -16,7 +16,7 @@ LiteLLM provides image editing functionality that maps to OpenAI's `/images/edit
| Supported operations | Create image edits | Single and multiple images supported |
| Supported LiteLLM SDK Versions | 1.63.8+ | Gemini support requires 1.79.3+ |
| Supported LiteLLM Proxy Versions | 1.71.1+ | Gemini support requires 1.79.3+ |
| Supported LLM providers | **OpenAI**, **Gemini (Google AI Studio)**, **Vertex AI**, **OpenRouter**, **Stability AI**, **AWS Bedrock (Stability)** | Gemini supports the new `gemini-2.5-flash-image` family. Vertex AI supports both Gemini and Imagen models. OpenRouter routes image edits through chat completions. Stability AI and Bedrock Stability support various image editing operations. |
| Supported LLM providers | **OpenAI**, **Gemini (Google AI Studio)**, **Vertex AI**, **OpenRouter**, **Stability AI**, **AWS Bedrock (Stability)**, **Black Forest Labs** | Gemini supports the new `gemini-2.5-flash-image` family. Vertex AI supports both Gemini and Imagen models. OpenRouter routes image edits through chat completions. Stability AI and Bedrock Stability support various image editing operations. Black Forest Labs supports FLUX Kontext models. |
#### ⚡️See all supported models and providers at [models.litellm.ai](https://models.litellm.ai/)
@@ -199,6 +199,63 @@ for idx, image_obj in enumerate(response.data):
</TabItem>
<TabItem value="bfl" label="Black Forest Labs">
#### Basic Image Edit
```python showLineNumbers title="Black Forest Labs Image Edit"
import os
import litellm
os.environ["BFL_API_KEY"] = "your-api-key"
response = litellm.image_edit(
model="black_forest_labs/flux-kontext-pro",
image=open("original_image.png", "rb"),
prompt="Add a green leaf to the scene",
)
print(response.data[0].url)
```
#### Inpainting with Mask
```python showLineNumbers title="Black Forest Labs Inpainting"
import os
import litellm
os.environ["BFL_API_KEY"] = "your-api-key"
# Use flux-pro-1.0-fill for inpainting
response = litellm.image_edit(
model="black_forest_labs/flux-pro-1.0-fill",
image=open("original_image.png", "rb"),
mask=open("mask_image.png", "rb"),
prompt="Replace with a garden",
)
print(response.data[0].url)
```
#### Outpainting (Expand)
```python showLineNumbers title="Black Forest Labs Outpainting"
import os
import litellm
os.environ["BFL_API_KEY"] = "your-api-key"
# Use flux-pro-1.0-expand to extend image borders
response = litellm.image_edit(
model="black_forest_labs/flux-pro-1.0-expand",
image=open("original_image.png", "rb"),
prompt="Continue the scene with mountains",
top=256,
bottom=256,
)
print(response.data[0].url)
```
</TabItem>
<TabItem value="vertex_ai" label="Vertex AI">
#### Basic Image Edit (Gemini)
@@ -392,6 +449,35 @@ curl -X POST "http://0.0.0.0:4000/v1/images/edits" \
</TabItem>
<TabItem value="bfl" label="Black Forest Labs">
1. Add Black Forest Labs image edit models to your `config.yaml`:
```yaml showLineNumbers title="Black Forest Labs Proxy Configuration"
model_list:
- model_name: bfl-kontext-pro
litellm_params:
model: black_forest_labs/flux-kontext-pro
api_key: os.environ/BFL_API_KEY
model_info:
mode: image_edit
```
2. Start the LiteLLM proxy server:
```bash showLineNumbers title="Start LiteLLM Proxy Server"
litellm --config /path/to/config.yaml
```
3. Make an image edit request:
```bash showLineNumbers title="Black Forest Labs Proxy Image Edit"
curl -X POST "http://0.0.0.0:4000/v1/images/edits" \
-H "Authorization: Bearer <YOUR-LITELLM-KEY>" \
-F "model=bfl-kontext-pro" \
-F "image=@original_image.png" \
-F "prompt=Add a sunset in the background"
```
</TabItem>
<TabItem value="vertex_ai" label="Vertex AI">
1. Add Vertex AI image edit models to your `config.yaml`:
+1 -1
View File
@@ -15,7 +15,7 @@ import TabItem from '@theme/TabItem';
| Fallbacks | ✅ | Works between supported models |
| Loadbalancing | ✅ | Works between supported models |
| Guardrails | ✅ | Applies to input prompts (non-streaming only) |
| Supported Providers | OpenAI, Azure, Google AI Studio, Vertex AI, AWS Bedrock, Recraft, OpenRouter, Xinference, Nscale | |
| Supported Providers | OpenAI, Azure, Google AI Studio, Vertex AI, AWS Bedrock, Black Forest Labs, Recraft, OpenRouter, Xinference, Nscale | |
## Quick Start
+15
View File
@@ -133,6 +133,21 @@ LiteLLM attempts [OAuth 2.0 Authorization Server Discovery](https://datatracker.
<br/>
### AWS SigV4 Authentication
For MCP servers hosted on [AWS Bedrock AgentCore](https://docs.aws.amazon.com/bedrock/latest/userguide/agentcore.html), select **AWS SigV4** as the authentication type. LiteLLM will sign every outgoing MCP request with your AWS credentials using [Signature Version 4](https://docs.aws.amazon.com/general/latest/gr/signature-version-4.html).
<Image
img={require('../img/mcp_aws_sigv4_ui.png')}
style={{width: '80%', display: 'block', margin: '0'}}
/>
Fill in your AWS region, service name (defaults to `bedrock-agentcore`), and optionally your AWS access key and secret. If credentials are omitted, LiteLLM falls back to the boto3 credential chain (IAM roles, environment variables, etc.).
[**See full SigV4 setup guide**](./mcp_aws_sigv4.md)
<br/>
### Static Headers
Sometimes your MCP server needs specific headers on every request. Maybe it's an API key, maybe it's a custom header the server expects. Instead of configuring auth, you can just set them directly.
+39 -2
View File
@@ -1,3 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
import Image from '@theme/IdealImage';
# MCP - AWS SigV4 Auth
Use AWS SigV4 authentication to connect LiteLLM to MCP servers hosted on [AWS Bedrock AgentCore](https://docs.aws.amazon.com/bedrock/latest/userguide/agentcore.html).
@@ -10,6 +14,36 @@ LiteLLM's `aws_sigv4` auth type handles this automatically: every outgoing MCP r
## Quick Start
<Tabs>
<TabItem value="ui" label="LiteLLM UI">
1. Navigate to **MCP Servers** and click **Add New MCP Server**
2. Set the transport to **Streamable HTTP**
3. Select **AWS SigV4** as the authentication type
4. Fill in your AWS credentials:
<Image
img={require('../img/mcp_aws_sigv4_ui.png')}
style={{width: '80%', display: 'block', margin: '0'}}
/>
<br/>
| Field | Required | Description |
|-------|----------|-------------|
| **AWS Region** | Yes | AWS region for SigV4 signing (e.g., `us-east-1`) |
| **AWS Service Name** | No | Defaults to `bedrock-agentcore` |
| **AWS Access Key ID** | No | Falls back to boto3 credential chain if blank |
| **AWS Secret Access Key** | No | Required if Access Key ID is provided |
| **AWS Session Token** | No | Only needed for temporary STS credentials |
Once created, LiteLLM will sign every outgoing MCP request with SigV4. The server's tools appear automatically in the MCP Tools list.
**Editing credentials:** When editing an existing SigV4 server, leave credential fields blank to keep the current values. Only fields you fill in will be updated.
</TabItem>
<TabItem value="config" label="config.yaml">
### 1. Set AWS credentials
```bash
@@ -60,9 +94,12 @@ arn%3Aaws%3Abedrock-agentcore%3Aus-east-1%3A123456789012%3Aruntime%2Fmy-mcp-serv
litellm --config config.yaml
```
### 4. Use the MCP tools
</TabItem>
</Tabs>
Once started, your AgentCore MCP tools are available through LiteLLM like any other MCP server:
## Use the MCP tools
Once configured, your AgentCore MCP tools are available through LiteLLM like any other MCP server:
```bash title="List available tools"
curl http://localhost:4000/mcp-rest/tools/list \
+1
View File
@@ -86,4 +86,5 @@ MCP guardrails work with all LiteLLM-supported guardrail providers:
- **Lakera**: Content moderation
- **Aporia**: Custom guardrails
- **Noma**: Noma Security
- **PANW Prisma AIRS**: Prisma AIRS guardrails
- **Custom**: Your own guardrail implementations
@@ -13,6 +13,7 @@ Here's the full specification with all available fields:
```json
{
"sample_spec": {
"aliases": ["optional list of alternate names for this model, e.g. dated versions like sample_spec-20250101"],
"code_interpreter_cost_per_session": 0.0,
"computer_use_input_cost_per_1k_tokens": 0.0,
"computer_use_output_cost_per_1k_tokens": 0.0,
@@ -121,4 +122,28 @@ Here's the full specification with all available fields:
}
```
That's it! Your PR will be reviewed and merged.
### Using Aliases
Many providers release the same model under multiple names — for example, a `latest` tag and a dated version like `claude-sonnet-4-5-20250929`. Instead of duplicating the entire entry, you can use the `aliases` field:
```json
{
"claude-sonnet-4-5": {
"aliases": ["claude-sonnet-4-5-20250929"],
"input_cost_per_token": 3e-06,
"output_cost_per_token": 1.5e-05,
"litellm_provider": "anthropic",
"max_input_tokens": 200000,
"max_output_tokens": 64000,
"mode": "chat",
"supports_function_calling": true,
"supports_tool_choice": true
}
}
```
At load time, each alias is expanded into a top-level entry sharing the same data as the canonical entry. The example above makes both `claude-sonnet-4-5` and `claude-sonnet-4-5-20250929` resolve with the same pricing and capabilities.
:::info
This is different from [`model_alias_map`](../completion/model_alias.md), which is a runtime SDK/proxy feature for mapping user-facing model names to LiteLLM model identifiers. The `aliases` field here is for the model cost JSON only — it avoids duplicate entries for models that share identical pricing and capabilities.
:::
@@ -0,0 +1,291 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Black Forest Labs Image Generation
Black Forest Labs provides state-of-the-art text-to-image generation using their FLUX models.
## Overview
| Property | Details |
|----------|---------|
| Description | Black Forest Labs FLUX models for high-quality text-to-image generation |
| Provider Route on LiteLLM | `black_forest_labs/` |
| Provider Doc | [Black Forest Labs API ↗](https://docs.bfl.ai/) |
| Supported Operations | [`/images/generations`](#image-generation) |
## Setup
### API Key
```python showLineNumbers
import os
# Set your Black Forest Labs API key
os.environ["BFL_API_KEY"] = "your-api-key-here"
```
Get your API key from [Black Forest Labs](https://blackforestlabs.ai/).
## Supported Models
| Model Name | Description | Price |
|------------|-------------|-------|
| `black_forest_labs/flux-pro-1.1` | Fast & reliable standard generation | $0.04/image |
| `black_forest_labs/flux-pro-1.1-ultra` | Ultra high-resolution (up to 4MP) | $0.06/image |
| `black_forest_labs/flux-dev` | Development/open-source variant | $0.025/image |
| `black_forest_labs/flux-pro` | Original pro model | $0.05/image |
## Image Generation
### Usage - LiteLLM Python SDK
<Tabs>
<TabItem value="basic" label="Basic Usage">
```python showLineNumbers title="Basic Image Generation"
import os
import litellm
# Set your API key
os.environ["BFL_API_KEY"] = "your-api-key-here"
# Generate an image
response = litellm.image_generation(
model="black_forest_labs/flux-pro-1.1",
prompt="A beautiful sunset over the ocean with sailing boats",
)
# BFL returns URLs
print(response.data[0].url)
```
</TabItem>
<TabItem value="async" label="Async Usage">
```python showLineNumbers title="Async Image Generation"
import os
import asyncio
import litellm
# Set your API key
os.environ["BFL_API_KEY"] = "your-api-key-here"
async def generate_image():
response = await litellm.aimage_generation(
model="black_forest_labs/flux-pro-1.1",
prompt="A futuristic city skyline at night",
)
print(response.data[0].url)
# Run the async function
asyncio.run(generate_image())
```
</TabItem>
<TabItem value="size" label="Custom Size">
```python showLineNumbers title="Image Generation with Custom Size"
import os
import litellm
# Set your API key
os.environ["BFL_API_KEY"] = "your-api-key-here"
# Generate with specific dimensions
response = litellm.image_generation(
model="black_forest_labs/flux-pro-1.1",
prompt="A majestic mountain landscape",
size="1792x1024", # Maps to width/height
)
print(response.data[0].url)
```
</TabItem>
<TabItem value="ultra" label="Ultra High-Res">
```python showLineNumbers title="Ultra High Resolution with flux-pro-1.1-ultra"
import os
import litellm
# Set your API key
os.environ["BFL_API_KEY"] = "your-api-key-here"
# Generate ultra high-resolution image
response = litellm.image_generation(
model="black_forest_labs/flux-pro-1.1-ultra",
prompt="Detailed portrait of a fantasy character",
size="2048x2048", # Up to 4MP supported
quality="hd", # Maps to raw=True for natural look
)
print(response.data[0].url)
```
</TabItem>
<TabItem value="advanced" label="Advanced Parameters">
```python showLineNumbers title="Advanced Image Generation with BFL Parameters"
import os
import litellm
# Set your API key
os.environ["BFL_API_KEY"] = "your-api-key-here"
# Generate with BFL-specific parameters
response = litellm.image_generation(
model="black_forest_labs/flux-pro-1.1",
prompt="A cute orange cat sitting on a windowsill",
seed=42, # For reproducible results
output_format="png", # png or jpeg
safety_tolerance=2, # 0-6, higher = more permissive
prompt_upsampling=True, # Enhance prompt for better results
)
print(response.data[0].url)
```
</TabItem>
</Tabs>
### Usage - LiteLLM Proxy Server
#### 1. Configure your config.yaml
```yaml showLineNumbers title="Black Forest Labs Image Generation Configuration"
model_list:
- model_name: flux-pro
litellm_params:
model: black_forest_labs/flux-pro-1.1
api_key: os.environ/BFL_API_KEY
model_info:
mode: image_generation
- model_name: flux-ultra
litellm_params:
model: black_forest_labs/flux-pro-1.1-ultra
api_key: os.environ/BFL_API_KEY
model_info:
mode: image_generation
- model_name: flux-dev
litellm_params:
model: black_forest_labs/flux-dev
api_key: os.environ/BFL_API_KEY
model_info:
mode: image_generation
general_settings:
master_key: sk-1234
```
#### 2. Start LiteLLM Proxy Server
```bash showLineNumbers title="Start LiteLLM Proxy Server"
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
#### 3. Make image generation requests
<Tabs>
<TabItem value="openai-sdk" label="OpenAI SDK">
```python showLineNumbers title="Black Forest Labs via Proxy - OpenAI SDK"
from openai import OpenAI
# Initialize client with your proxy URL
client = OpenAI(
base_url="http://localhost:4000",
api_key="sk-1234"
)
# Generate image with FLUX Pro
response = client.images.generate(
model="flux-pro",
prompt="A beautiful garden with colorful flowers",
size="1024x1024",
)
print(response.data[0].url)
```
</TabItem>
<TabItem value="curl" label="cURL">
```bash showLineNumbers title="Black Forest Labs via Proxy - cURL"
curl -X POST 'http://localhost:4000/v1/images/generations' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "flux-pro",
"prompt": "A beautiful garden with colorful flowers",
"size": "1024x1024"
}'
```
</TabItem>
</Tabs>
## Supported Parameters
### OpenAI-Compatible Parameters
| Parameter | Type | Description | Mapping |
|-----------|------|-------------|---------|
| `prompt` | string | Text description of the image to generate | Direct |
| `model` | string | The FLUX model to use | Direct |
| `size` | string | Image dimensions (e.g., `1024x1024`) | Maps to `width` and `height` |
| `n` | integer | Number of images (ultra model only, up to 4) | Maps to `num_images` |
| `quality` | string | `hd` for natural look | Maps to `raw=True` for ultra |
| `response_format` | string | `url` or `b64_json` | Direct |
### Black Forest Labs Specific Parameters
| Parameter | Type | Description | Default |
|-----------|------|-------------|---------|
| `width` | integer | Image width (256-1920, multiples of 16) | 1024 |
| `height` | integer | Image height (256-1920, multiples of 16) | 1024 |
| `aspect_ratio` | string | Alternative to width/height (e.g., `16:9`, `1:1`) | - |
| `seed` | integer | Seed for reproducible results | Random |
| `output_format` | string | Output format: `png` or `jpeg` | `png` |
| `safety_tolerance` | integer | Safety filter tolerance (0-6, higher = more permissive) | 2 |
| `prompt_upsampling` | boolean | Enhance prompt for better results | `false` |
### Ultra Model Specific Parameters
| Parameter | Type | Description | Default |
|-----------|------|-------------|---------|
| `raw` | boolean | Raw mode for more natural, less synthetic look | `false` |
| `num_images` | integer | Number of images to generate (1-4) | 1 |
## How It Works
Black Forest Labs uses a polling-based API:
1. **Submit Request**: LiteLLM sends your prompt to BFL
2. **Get Task ID**: BFL returns a task ID and polling URL
3. **Poll for Result**: LiteLLM automatically polls until the image is ready
4. **Return Result**: The generated image URL is returned
This polling is handled automatically by LiteLLM - you just call `image_generation()` and get the result.
## Getting Started
1. Create an account at [Black Forest Labs](https://blackforestlabs.ai/)
2. Get your API key from the dashboard
3. Set your `BFL_API_KEY` environment variable
4. Use `litellm.image_generation()` with any supported model
## Additional Resources
- [Black Forest Labs Documentation](https://docs.bfl.ai/)
- [Black Forest Labs Image Editing](./black_forest_labs_img_edit.md) - For editing existing images
- [FLUX Model Information](https://blackforestlabs.ai/)
@@ -0,0 +1,301 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Black Forest Labs Image Editing
Black Forest Labs provides powerful image editing capabilities using their FLUX models to modify existing images based on text descriptions.
## Overview
| Property | Details |
|----------|---------|
| Description | Black Forest Labs Image Editing uses FLUX Kontext and other models to modify, inpaint, and expand images based on text prompts. |
| Provider Route on LiteLLM | `black_forest_labs/` |
| Provider Doc | [Black Forest Labs API ↗](https://docs.bfl.ai/) |
| Supported Operations | [`/images/edits`](#image-editing) |
## Setup
### API Key
```python showLineNumbers
import os
# Set your Black Forest Labs API key
os.environ["BFL_API_KEY"] = "your-api-key-here"
```
Get your API key from [Black Forest Labs](https://blackforestlabs.ai/).
## Supported Models
| Model Name | Description | Use Case |
|------------|-------------|----------|
| `black_forest_labs/flux-kontext-pro` | FLUX Kontext Pro - General image editing with prompts | General editing, style transfer |
| `black_forest_labs/flux-kontext-max` | FLUX Kontext Max - Premium quality editing | High-quality edits |
| `black_forest_labs/flux-pro-1.0-fill` | FLUX Pro Fill - Inpainting with mask | Remove/replace objects |
| `black_forest_labs/flux-pro-1.0-expand` | FLUX Pro Expand - Outpainting | Expand image borders |
## Image Editing
### Usage - LiteLLM Python SDK
<Tabs>
<TabItem value="basic-edit" label="Basic Usage">
```python showLineNumbers title="Basic Image Editing"
import os
import litellm
# Set your API key
os.environ["BFL_API_KEY"] = "your-api-key-here"
# Edit an image with a prompt
response = litellm.image_edit(
model="black_forest_labs/flux-kontext-pro",
image=open("path/to/your/image.png", "rb"),
prompt="Add a green leaf to the scene",
)
# BFL returns URLs
print(response.data[0].url)
```
</TabItem>
<TabItem value="async-edit" label="Async Usage">
```python showLineNumbers title="Async Image Editing"
import os
import asyncio
import litellm
# Set your API key
os.environ["BFL_API_KEY"] = "your-api-key-here"
async def edit_image():
response = await litellm.aimage_edit(
model="black_forest_labs/flux-kontext-pro",
image=open("path/to/your/image.png", "rb"),
prompt="Make this image look like a watercolor painting",
)
print(response.data[0].url)
# Run the async function
asyncio.run(edit_image())
```
</TabItem>
<TabItem value="inpainting" label="Inpainting (Fill)">
```python showLineNumbers title="Inpainting with Mask"
import os
import litellm
# Set your API key
os.environ["BFL_API_KEY"] = "your-api-key-here"
# Use flux-pro-1.0-fill for inpainting
response = litellm.image_edit(
model="black_forest_labs/flux-pro-1.0-fill",
image=open("path/to/your/image.png", "rb"),
mask=open("path/to/mask.png", "rb"), # White areas will be edited
prompt="Replace with a beautiful garden",
steps=50, # BFL-specific parameter
guidance=30, # BFL-specific parameter
)
print(response.data[0].url)
```
</TabItem>
<TabItem value="outpainting" label="Outpainting (Expand)">
```python showLineNumbers title="Outpainting - Expand Image Borders"
import os
import litellm
# Set your API key
os.environ["BFL_API_KEY"] = "your-api-key-here"
# Use flux-pro-1.0-expand to extend image borders
response = litellm.image_edit(
model="black_forest_labs/flux-pro-1.0-expand",
image=open("path/to/your/image.png", "rb"),
prompt="Continue the scene with a mountain landscape",
top=256, # Expand 256 pixels at top
bottom=256, # Expand 256 pixels at bottom
left=128, # Expand 128 pixels at left
right=128, # Expand 128 pixels at right
)
print(response.data[0].url)
```
</TabItem>
<TabItem value="advanced" label="Advanced Parameters">
```python showLineNumbers title="Advanced Image Editing with BFL Parameters"
import os
import litellm
# Set your API key
os.environ["BFL_API_KEY"] = "your-api-key-here"
# Edit image with BFL-specific parameters
response = litellm.image_edit(
model="black_forest_labs/flux-kontext-pro",
image=open("path/to/your/image.png", "rb"),
prompt="Transform into cyberpunk style with neon lights",
seed=42, # For reproducible results
output_format="png", # png or jpeg
safety_tolerance=2, # 0-6, higher = more permissive
aspect_ratio="16:9", # Output aspect ratio
)
print(response.data[0].url)
```
</TabItem>
</Tabs>
### Usage - LiteLLM Proxy Server
#### 1. Configure your config.yaml
```yaml showLineNumbers title="Black Forest Labs Image Editing Configuration"
model_list:
- model_name: bfl-kontext-pro
litellm_params:
model: black_forest_labs/flux-kontext-pro
api_key: os.environ/BFL_API_KEY
model_info:
mode: image_edit
- model_name: bfl-kontext-max
litellm_params:
model: black_forest_labs/flux-kontext-max
api_key: os.environ/BFL_API_KEY
model_info:
mode: image_edit
- model_name: bfl-fill
litellm_params:
model: black_forest_labs/flux-pro-1.0-fill
api_key: os.environ/BFL_API_KEY
model_info:
mode: image_edit
- model_name: bfl-expand
litellm_params:
model: black_forest_labs/flux-pro-1.0-expand
api_key: os.environ/BFL_API_KEY
model_info:
mode: image_edit
general_settings:
master_key: sk-1234
```
#### 2. Start LiteLLM Proxy Server
```bash showLineNumbers title="Start LiteLLM Proxy Server"
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
#### 3. Make image editing requests
<Tabs>
<TabItem value="openai-sdk" label="OpenAI SDK">
```python showLineNumbers title="Black Forest Labs via Proxy - OpenAI SDK"
from openai import OpenAI
# Initialize client with your proxy URL
client = OpenAI(
base_url="http://localhost:4000",
api_key="sk-1234"
)
# Edit image with FLUX Kontext Pro
response = client.images.edit(
model="bfl-kontext-pro",
image=open("path/to/your/image.png", "rb"),
prompt="Add magical sparkles and fairy dust",
)
print(response.data[0].url)
```
</TabItem>
<TabItem value="curl" label="cURL">
```bash showLineNumbers title="Black Forest Labs via Proxy - cURL"
curl --location 'http://localhost:4000/v1/images/edits' \
--header 'Authorization: Bearer sk-1234' \
--form 'model="bfl-kontext-pro"' \
--form 'prompt="Add a sunset in the background"' \
--form 'image=@"path/to/your/image.png"'
```
</TabItem>
</Tabs>
## Supported Parameters
### OpenAI-Compatible Parameters
| Parameter | Type | Description | Default |
|-----------|------|-------------|---------|
| `image` | file | The image file to edit | Required |
| `prompt` | string | Text description of the desired changes | Required |
| `model` | string | The FLUX model to use | Required |
| `mask` | file | Mask image for inpainting (flux-pro-1.0-fill) | Optional |
| `n` | integer | Number of images (BFL returns 1 per request) | `1` |
| `size` | string | Maps to aspect_ratio | Optional |
| `response_format` | string | `url` or `b64_json` | `url` |
### Black Forest Labs Specific Parameters
| Parameter | Type | Description | Default | Models |
|-----------|------|-------------|---------|--------|
| `seed` | integer | Seed for reproducible results | Random | All |
| `output_format` | string | Output format: `png` or `jpeg` | `png` | All |
| `safety_tolerance` | integer | Safety filter tolerance (0-6) | 2 | All |
| `aspect_ratio` | string | Output aspect ratio (e.g., `16:9`, `1:1`) | Original | Kontext models |
| `steps` | integer | Number of inference steps | Model default | Fill |
| `guidance` | float | Guidance scale | Model default | Fill |
| `grow_mask` | integer | Pixels to grow mask | 0 | Fill |
| `top` | integer | Pixels to expand at top | 0 | Expand |
| `bottom` | integer | Pixels to expand at bottom | 0 | Expand |
| `left` | integer | Pixels to expand at left | 0 | Expand |
| `right` | integer | Pixels to expand at right | 0 | Expand |
## How It Works
Black Forest Labs uses a polling-based API:
1. **Submit Request**: LiteLLM sends your image and prompt to BFL
2. **Get Task ID**: BFL returns a task ID and polling URL
3. **Poll for Result**: LiteLLM automatically polls until the image is ready
4. **Return Result**: The generated image URL is returned
This polling is handled automatically by LiteLLM - you just call `image_edit()` and get the result.
## Getting Started
1. Create an account at [Black Forest Labs](https://blackforestlabs.ai/)
2. Get your API key from the dashboard
3. Set your `BFL_API_KEY` environment variable
4. Use `litellm.image_edit()` with any supported model
## Additional Resources
- [Black Forest Labs Documentation](https://docs.bfl.ai/)
- [FLUX Model Information](https://blackforestlabs.ai/)
+15 -7
View File
@@ -1562,13 +1562,18 @@ LiteLLM Supports the following image types passed in `url`
## Media Resolution Control (Images & Videos)
For Gemini 3+ models, LiteLLM supports per-part media resolution control using OpenAI's `detail` parameter. This allows you to specify different resolution levels for individual images and videos in your request, whether using `image_url` or `file` content types.
LiteLLM supports OpenAI's `detail` parameter for specifying the image resolution when using Gemini models. The behavior differs between Gemini versions:
| Gemini Version | Resolution Control | Behavior |
|----------------|-------------------|----------|
| Gemini 3+ | Per-part | Each image/video can have its own `detail` setting |
| Gemini 2.x (2.0, 2.5) | Global | The highest `detail` from all images is applied globally via `mediaResolution` in `generationConfig` |
**Supported `detail` values:**
- `"low"` - Maps to `media_resolution: "low"` (280 tokens for images, 70 tokens per frame for videos)
- `"medium"` - Maps to `media_resolution: "medium"`
- `"high"` - Maps to `media_resolution: "high"` (1120 tokens for images)
- `"ultra_high"` - Maps to `media_resolution: "ultra_high"`
- `"low"` - Maps to `MEDIA_RESOLUTION_LOW` (280 tokens for images, 70 tokens per frame for videos)
- `"medium"` - Maps to `MEDIA_RESOLUTION_MEDIUM`
- `"high"` - Maps to `MEDIA_RESOLUTION_HIGH` (1120 tokens for images)
- `"ultra_high"` - Maps to `MEDIA_RESOLUTION_ULTRA_HIGH`
- `"auto"` or `None` - Model decides optimal resolution (no `media_resolution` set)
**Usage Examples:**
@@ -1605,8 +1610,9 @@ messages = [
}
]
# Works with both Gemini 2.x and 3+
response = completion(
model="gemini/gemini-3-pro-preview",
model="gemini/gemini-2.5-flash", # or gemini-3-pro-preview
messages=messages,
)
```
@@ -1647,7 +1653,9 @@ response = completion(
</Tabs>
:::info
**Per-Part Resolution:** Each image or video in your request can have its own `detail` setting, allowing mixed-resolution requests (e.g., a high-res chart alongside a low-res icon). This feature works with both `image_url` and `file` content types, and is only available for Gemini 3+ models.
**Gemini 3+ Per-Part Resolution:** Each image or video can have its own `detail` setting, allowing mixed-resolution requests (e.g., a high-res chart alongside a low-res icon). This works with both `image_url` and `file` content types.
**Gemini 2.x Global Resolution:** When multiple images have different `detail` values, LiteLLM uses the highest resolution found and applies it globally via `mediaResolution` in `generationConfig` (e.g., if one image has `"low"` and another has `"high"`, all images will use `"high"`).
:::
## Video Metadata Control
+73
View File
@@ -311,6 +311,79 @@ print(response)
- **Model Compatibility**: Reasoning parameters only work with magistral models
- **Backward Compatibility**: Non-magistral models will ignore reasoning parameters and work normally
## Audio Transcription
Use Mistral's Voxtral models for audio transcription via `litellm.transcription()`.
### SDK Usage
```python
from litellm import transcription
import os
os.environ["MISTRAL_API_KEY"] = ""
audio_file = open("path/to/audio.wav", "rb")
response = transcription(
model="mistral/voxtral-mini-latest",
file=audio_file,
)
print(response.text)
```
### With Optional Parameters
```python
response = transcription(
model="mistral/voxtral-mini-latest",
file=audio_file,
language="en",
temperature=0.0,
response_format="json",
)
```
### Mistral-Specific Parameters
Mistral supports additional parameters beyond the OpenAI-compatible ones:
| Parameter | Type | Description |
|-----------|------|-------------|
| `diarize` | `bool` | Enable speaker diarization |
```python
response = transcription(
model="mistral/voxtral-mini-latest",
file=audio_file,
diarize=True,
)
```
### Usage with LiteLLM Proxy
```yaml
model_list:
- model_name: voxtral
litellm_params:
model: mistral/voxtral-mini-latest
api_key: os.environ/MISTRAL_API_KEY
model_info:
mode: audio_transcription
```
```bash
litellm --config /path/to/config.yaml
```
```bash
curl --location 'http://0.0.0.0:4000/v1/audio/transcriptions' \
--header 'Authorization: Bearer sk-1234' \
--form 'file=@"audio.wav"' \
--form 'model="voxtral"'
```
## Sample Usage - Embedding
```python
from litellm import embedding
+79 -9
View File
@@ -632,7 +632,9 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
## OpenAI Chat Completion to Responses API Bridge
Call any Responses API model from OpenAI's `/chat/completions` endpoint.
LiteLLM offers a chat completion to Responses API bridge. This lets you use the completion interface while calling the Responses API under the hood.
This is useful when you want to use [Responses API](https://platform.openai.com/docs/api-reference/responses) specific features (like built-in tools, web search preview, or code interpreter).
:::tip gpt-5.4 + reasoning_effort + function tools
@@ -649,12 +651,54 @@ response = litellm.completion(
:::
### When to use the `openai/responses/` prefix
Each model has a `mode` property defined in [`model_prices_and_context_window.json`](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json) that determines which API endpoint it uses by default:
- **`mode: responses`** - Model automatically uses the Responses API
- **`mode: chat`** - Model defaults to the Chat Completions API
**Models with `mode: responses`** (automatic Responses API):
- `o3-deep-research`, `o4-mini-deep-research`
- `o1-pro`, `o3-pro`
- `gpt-5.1-codex`, `gpt-5.1-codex-mini`, `gpt-5.1-codex-max`
- `codex-mini-latest`
**Models with `mode: chat`** (require `openai/responses/` prefix for built-in tools):
- `gpt-4o`, `gpt-4o-mini`, `gpt-4.1`, `gpt-4.1-mini`
- `gpt-5`, `gpt-5-mini`
- `o3`, `o4-mini`
To use built-in tools like `web_search_preview` with `mode: chat` models, add the `openai/responses/` prefix:
```python
# This will FAIL - gpt-4o has mode: chat, uses Chat Completions API
response = litellm.completion(
model="gpt-4o",
messages=[{"role": "user", "content": "What is the weather in Paris today?"}],
tools=[{"type": "web_search_preview"}], # Not supported in Chat Completions
# ... other kwargs
)
# This will WORK - prefix forces Responses API
response = litellm.completion(
model="openai/responses/gpt-4o",
messages=[{"role": "user", "content": "What is the weather in Paris today?"}],
tools=[{"type": "web_search_preview"}], # Supported in Responses API
# ... other kwargs
)
```
### Examples
<Tabs>
<TabItem value="sdk" label="SDK">
**Using a model with `mode: responses` (automatic):**
```python
import litellm
import os
import os
os.environ["OPENAI_API_KEY"] = "sk-1234"
@@ -668,6 +712,26 @@ response = litellm.completion(
)
print(response)
```
**Using a model with `mode: chat` (requires prefix):**
```python
import litellm
import os
os.environ["OPENAI_API_KEY"] = "sk-1234"
# Use the openai/responses/ prefix to enable built-in tools
response = litellm.completion(
model="openai/responses/gpt-4o",
messages=[{"role": "user", "content": "What is the weather in Paris today?"}],
tools=[
{"type": "web_search_preview"},
],
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
@@ -675,10 +739,17 @@ print(response)
```yaml
model_list:
- model_name: openai-model
# Model with mode: responses (automatic)
- model_name: o3-deep-research
litellm_params:
model: o3-deep-research-2025-06-26
api_key: os.environ/OPENAI_API_KEY
# Model with mode: chat (use prefix for built-in tools)
- model_name: gpt-4o-with-tools
litellm_params:
model: openai/responses/gpt-4o
api_key: os.environ/OPENAI_API_KEY
```
2. Start the proxy
@@ -693,15 +764,14 @@ litellm --config config.yaml
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "openai-model",
-d '{
"model": "gpt-4o-with-tools",
"messages": [
{"role": "user", "content": "What is the capital of France?"}
{"role": "user", "content": "What is the weather in Paris today?"}
],
"tools": [
{"type": "web_search_preview"},
{"type": "code_interpreter", "container": {"type": "auto"}},
],
{"type": "web_search_preview"}
]
}'
```
@@ -79,6 +79,7 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
| textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` |
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
| gemini-embedding-2-preview | `embedding(model="vertex_ai/gemini-embedding-2-preview", input)` | [Multimodal docs](#gemini-embedding-2-preview-multimodal) |
| Fine-tuned OR Custom Embedding models | `embedding(model="vertex_ai/<your-model-id>", input)` |
### Supported OpenAI (Unified) Params
@@ -257,6 +258,71 @@ model_list:
## **Multi-Modal Embeddings**
### Gemini Embedding 2 Preview (Multimodal)
`gemini-embedding-2-preview` supports **unified multimodal embeddings**—text, images, audio, video, and PDF in a single request. See [blog post](/blog/gemini_embedding_2_multimodal) for details.
**Input formats:**
- **Data URIs:** `data:image/png;base64,<encoded_data>`
- **GCS URLs:** `gs://bucket/path/to/file.png` (MIME type inferred from extension)
**Supported MIME types:** `image/png`, `image/jpeg`, `audio/mpeg`, `audio/wav`, `video/mp4`, `video/quicktime`, `application/pdf`
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import litellm
from litellm import embedding
litellm.vertex_project = "your-project-id"
litellm.vertex_location = "us-central1"
# Text + Image (GCS URL)
response = embedding(
model="vertex_ai/gemini-embedding-2-preview",
input=[
"Describe this image",
"gs://my-bucket/images/photo.png"
],
)
# Text + Image (base64)
response = embedding(
model="vertex_ai/gemini-embedding-2-preview",
input=[
"The food was delicious",
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAIAQMAAAD+wSzIAAAABlBMVEX///+/v7+jQ3Y5AAAADklEQVQI12P4AIX8EAgALgAD/aNpbtEAAAAASUVORK5CYII"
],
)
```
</TabItem>
<TabItem value="proxy" label="LiteLLM PROXY">
```yaml
model_list:
- model_name: vertex-gemini-embedding-2-preview
litellm_params:
model: vertex_ai/gemini-embedding-2-preview
vertex_project: "your-project-id"
vertex_location: "us-central1"
```
```bash
curl -X POST http://localhost:4000/embeddings \
-H "Authorization: Bearer sk-1234" \
-H "Content-Type: application/json" \
-d '{
"model": "vertex-gemini-embedding-2-preview",
"input": ["Describe this", "gs://bucket/image.png"]
}'
```
</TabItem>
</Tabs>
### multimodalembedding@001 (Legacy)
Known Limitations:
- Only supports 1 image / video / image per request
@@ -1,24 +1,15 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# PANW Prisma AIRS
LiteLLM supports PANW Prisma AIRS (AI Runtime Security) guardrails via the [Prisma AIRS Scan API](https://pan.dev/prisma-airs/api/airuntimesecurity/airuntimesecurityapi//). This integration provides **Security-as-Code** for AI applications using Palo Alto Networks' AI security platform.
LiteLLM supports PANW Prisma AIRS (AI Runtime Security) guardrails via the [Prisma AIRS Scan API](https://pan.dev/prisma-airs/api/airuntimesecurity/airuntimesecurityapi/). This integration provides Security-as-Code for AI applications using Palo Alto Networks' AI security platform.
## Features
- **Prompt injection and malicious URL detection** — real-time scanning before or after LLM calls
- **Data loss prevention (DLP)** — detect and block sensitive data in prompts and responses
- **Sensitive content masking** — automatically mask PII, credit cards, SSNs instead of blocking
- **MCP tool call scanning** — scan tool name and arguments on direct MCP tool invocations
- **Configurable fail-open / fail-closed** — choose between maximum security or high availability
- ✅ **Real-time prompt injection detection**
- ✅ **Malicious URL detection**
- ✅ **Data loss prevention (DLP)**
- ✅ **Sensitive content masking** - Automatically mask PII, credit cards, SSNs instead of blocking
- ✅ **Comprehensive threat detection** for AI models and datasets
- ✅ **Model-agnostic protection** across public and private models
- ✅ **Synchronous scanning** with immediate response
- ✅ **Configurable security profiles**
- ✅ **Streaming support** - Real-time masking for streaming responses
- ✅ **Multi-turn conversation tracking** - Automatic session grouping in Prisma AIRS SCM logs
- ✅ **Configurable fail-open/fail-closed** - Choose between maximum security (block on API errors) or high availability (allow on transient errors)
## Quick Start
@@ -32,7 +23,14 @@ For detailed setup instructions, see the [Prisma AIRS API Overview](https://docs
### 2. Define Guardrails on your LiteLLM config.yaml
Define your guardrails under the `guardrails` section:
Set `api_base` to the regional endpoint for your Prisma AIRS deployment profile:
| Region | Endpoint |
|--------|----------|
| US | `https://service.api.aisecurity.paloaltonetworks.com` |
| EU (Germany) | `https://service-de.api.aisecurity.paloaltonetworks.com` |
| India | `https://service-in.api.aisecurity.paloaltonetworks.com` |
| Singapore | `https://service-sg.api.aisecurity.paloaltonetworks.com` |
```yaml
model_list:
@@ -45,21 +43,15 @@ guardrails:
- guardrail_name: "panw-prisma-airs-guardrail"
litellm_params:
guardrail: panw_prisma_airs
mode: "pre_call" # Run before LLM call
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY # Your Prisma AIRS API key
profile_name: os.environ/PANW_PRISMA_AIRS_PROFILE_NAME # Security profile from Strata Cloud Manager
api_base: "https://service.api.aisecurity.paloaltonetworks.com"
mode: "pre_call"
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
profile_name: os.environ/PANW_PRISMA_AIRS_PROFILE_NAME
api_base: "https://service.api.aisecurity.paloaltonetworks.com" # US — change to your region
```
#### Supported values for `mode`
- `pre_call` Run **before** LLM call, on **input**
- `post_call` Run **after** LLM call, on **input & output**
- `during_call` Run **during** LLM call, on **input**. Same as `pre_call` but runs in parallel with LLM call
### 3. Start LiteLLM Gateway
```bash title="Set environment variables"
```bash
export PANW_PRISMA_AIRS_API_KEY="your-panw-api-key"
export PANW_PRISMA_AIRS_PROFILE_NAME="your-security-profile"
export OPENAI_API_KEY="sk-proj-..."
@@ -69,15 +61,8 @@ export OPENAI_API_KEY="sk-proj-..."
litellm --config config.yaml --detailed_debug
```
### 4. Test Request
**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)**
<Tabs>
<TabItem label="Blocked request" value="blocked">
Expect this to fail due to prompt injection attempt:
```shell
curl -i http://localhost:4000/v1/chat/completions \
@@ -92,254 +77,57 @@ curl -i http://localhost:4000/v1/chat/completions \
}'
```
Expected response on failure:
Expected response when the guardrail blocks:
```json
{
"error": {
"message": {
"error": "Violated PANW Prisma AIRS guardrail policy",
"panw_response": {
"action": "block",
"category": "malicious",
"profile_id": "03b32734-d06d-4bb7-a8df-ac5147630ce8",
"profile_name": "dev-block-all-profile",
"prompt_detected": {
"dlp": false,
"injection": true,
"toxic_content": false,
"url_cats": false
},
"report_id": "Rbd251eac-6e67-433b-b3ef-8eb42d2c7d2c",
"response_detected": {
"dlp": false,
"toxic_content": false,
"url_cats": false
},
"scan_id": "bd251eac-6e67-433b-b3ef-8eb42d2c7d2c",
"tr_id": "string"
}
},
"type": "None",
"param": "None",
"code": "400"
"message": "Prompt blocked by PANW Prisma AI Security policy (Category: malicious)",
"type": "guardrail_violation",
"code": "panw_prisma_airs_blocked",
"guardrail": "panw-prisma-airs-guardrail",
"category": "malicious"
}
}
```
</TabItem>
<TabItem label="Successful Call" value="allowed">
LiteLLM wraps this detail in an endpoint-specific HTTP error envelope. Optional fields that may also appear: `scan_id`, `report_id`, `profile_name`, `profile_id`, `tr_id`, `prompt_detected`.
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-your-api-key" \
-d '{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "What is the weather like today?"}
],
"guardrails": ["panw-prisma-airs-guardrail"]
}'
```
On success, the guardrail name appears in the `x-litellm-applied-guardrails` response header.
Expected successful response:
## Configuration
```json
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
"content": "I don't have access to real-time weather data, but I can help you find weather information through various weather services or apps...",
"role": "assistant",
"tool_calls": null,
"function_call": null,
"annotations": []
}
}
],
"created": 1736028456,
"id": "chatcmpl-AqQj8example",
"model": "gpt-4o",
"object": "chat.completion",
"usage": {
"completion_tokens": 25,
"prompt_tokens": 12,
"total_tokens": 37
},
"x-litellm-panw-scan": {
"action": "allow",
"category": "benign",
"profile_id": "03b32734-d06d-4bb7-a8df-ac5147630ce8",
"profile_name": "dev-block-all-profile",
"prompt_detected": {
"dlp": false,
"injection": false,
"toxic_content": false,
"url_cats": false
},
"report_id": "Rbd251eac-6e67-433b-b3ef-8eb42d2c7d2c",
"response_detected": {
"dlp": false,
"toxic_content": false,
"url_cats": false
},
"scan_id": "bd251eac-6e67-433b-b3ef-8eb42d2c7d2c",
"tr_id": "string"
}
}
```
### Supported Modes
</TabItem>
</Tabs>
| Mode | Timing | What is scanned |
|------|--------|-----------------|
| `pre_call` | Before LLM call | Request input |
| `during_call` | Parallel with LLM call | Request input |
| `post_call` | After LLM call | Response output |
| `pre_mcp_call` | Before MCP tool execution | MCP tool input |
| `during_mcp_call` | Parallel with MCP tool execution | MCP tool input |
## Configuration Parameters
### Configuration Parameters
| Parameter | Required | Description | Default |
|-----------|----------|-------------|---------|
| `api_key` | Yes | Your PANW Prisma AIRS API key from Strata Cloud Manager | - |
| `profile_name` | No | Security profile name configured in Strata Cloud Manager. Optional if API key has linked profile | - |
| `app_name` | No | Application identifier for tracking in Prisma AIRS analytics (will be prefixed with "LiteLLM-") | `LiteLLM` |
| `api_base` | No | Regional API endpoint (see [Regional Endpoints](#regional-endpoints) below) | `https://service.api.aisecurity.paloaltonetworks.com` (US) |
| `mode` | No | When to run the guardrail | `pre_call` |
| `fallback_on_error` | No | Action when PANW API is unavailable: `"block"` (fail-closed, default) or `"allow"` (fail-open). Config errors always block. | `block` |
| `timeout` | No | PANW API call timeout in seconds (1-60) | `10.0` |
| `violation_message_template` | No | Custom template for error message when request is blocked. Supports `{guardrail_name}`, `{category}`, `{action_type}`, `{default_message}` placeholders. | - |
| `app_name` | No | Application identifier for tracking in Prisma AIRS analytics (prefixed with "LiteLLM-") | `LiteLLM` |
| `api_base` | No | Regional API endpoint. US: `https://service.api.aisecurity.paloaltonetworks.com`, EU: `https://service-de.api.aisecurity.paloaltonetworks.com`, India: `https://service-in.api.aisecurity.paloaltonetworks.com`, Singapore: `https://service-sg.api.aisecurity.paloaltonetworks.com` | US |
| `mode` | No | When to run the guardrail (see mode table above) | `pre_call` |
| `fallback_on_error` | No | Action when PANW API is unavailable: `"block"` (fail-closed) or `"allow"` (fail-open). Config errors always block. | `block` |
| `timeout` | No | PANW API call timeout in seconds (recommended: 1-60) | `10.0` |
| `violation_message_template` | No | Custom template for blocked requests. Supports `{guardrail_name}`, `{category}`, `{action_type}`, `{default_message}` placeholders. | - |
| `mask_request_content` | No | Mask sensitive data in prompts instead of blocking | `false` |
| `mask_response_content` | No | Mask sensitive data in responses instead of blocking | `false` |
| `mask_on_block` | No | Backwards-compatible flag that enables both request and response masking | `false` |
| `experimental_use_latest_role_message_only` | No | Anthropic `/v1/messages` only. When unset: scans only latest user message on request side. Set `false` to scan all user/system/developer messages. Non-Anthropic unaffected. | Unset (true for Anthropic) |
### Regional Endpoints
Use the regional `api_base` that matches your Prisma AIRS deployment profile region for lower latency and data residency compliance.
PANW Prisma AIRS supports multiple regional endpoints based on your deployment profile region:
| Region | API Base URL |
|--------|--------------|
| **US** (default) | `https://service.api.aisecurity.paloaltonetworks.com` |
| **EU (Germany)** | `https://service-de.api.aisecurity.paloaltonetworks.com` |
| **India** | `https://service-in.api.aisecurity.paloaltonetworks.com` |
**Example configuration for EU region:**
```yaml
guardrails:
- guardrail_name: "panw-eu"
litellm_params:
guardrail: panw_prisma_airs
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
api_base: "https://service-de.api.aisecurity.paloaltonetworks.com"
profile_name: "production"
```
:::tip Region Selection
Use the regional endpoint that matches your Prisma AIRS deployment profile region configured in Strata Cloud Manager. Using the correct region ensures:
- Lower latency (requests stay in-region)
- Compliance with data residency requirements
- Optimal performance
:::
## Per-Request Metadata Overrides
You can override guardrail settings on a per-request basis using the `metadata` field:
```json
{
"model": "gpt-4",
"messages": [...],
"metadata": {
"profile_name": "dev-allow-all", // Override profile name
"profile_id": "uuid-here", // Override profile ID (takes precedence)
"user_ip": "192.168.1.100", // Track user IP
"app_name": "MyApp" // Custom app name (becomes "LiteLLM-MyApp")
}
}
```
**Supported Metadata Fields:**
| Field | Description | Priority |
|-------|-------------|----------|
| `profile_name` | PANW AI security profile name | Per-request > config |
| `profile_id` | PANW AI security profile ID (takes precedence over profile_name) | Per-request only |
| `user_ip` | User IP address for tracking in Prisma AIRS | Per-request only |
| `app_name` | Application identifier (prefixed with "LiteLLM-") | Per-request > config > "LiteLLM" |
| `app_user` | Custom user identifier for tracking in Prisma AIRS | `app_user` > `user` > "litellm_user" |
:::info Profile Resolution
- If both `profile_id` and `profile_name` are provided, PANW API uses `profile_id` (it takes precedence)
- If no profile is specified in metadata, uses the config `profile_name`
- If no profile is specified at all, PANW API will use the profile linked to your API key in Strata Cloud Manager
- **Note:** If your API key is not linked to a profile, you must provide `profile_name` or `profile_id`
:::
## Multi-Turn Conversation Tracking
PANW Prisma AIRS automatically tracks multi-turn conversations using LiteLLM's `litellm_trace_id`. This enables you to:
- **Group related requests** - All requests in a conversation share the same AI Session ID in Prisma AIRS SCM logs
- **Track conversation context** - See the full history of prompts and responses for a user session
- **Analyze attack patterns** - Identify sophisticated multi-turn attacks across conversation history
### How It Works
LiteLLM automatically generates a unique `litellm_trace_id` for each conversation session. The PANW guardrail uses this as the PANW transaction ID (which maps to "AI Session ID" in Strata Cloud Manager):
```
Conversation Session: litellm_trace_id = "abc-123-def-456"
Turn 1 (User): "What's the capital of France?"
→ Scan ID: scan_001 | Prisma AIRS AI Session ID: abc-123-def-456
Turn 2 (Assistant): "Paris is the capital of France."
→ Scan ID: scan_002 | Prisma AIRS AI Session ID: abc-123-def-456
Turn 3 (User): "What's the population?"
→ Scan ID: scan_003 | Prisma AIRS AI Session ID: abc-123-def-456
Turn 4 (Assistant): "Paris has approximately 2.1 million residents."
→ Scan ID: scan_004 | Prisma AIRS AI Session ID: abc-123-def-456
```
All scans appear under the same AI Session ID in Prisma AIRS logs, making it easy to:
- Review complete conversation history (all 4 turns grouped together)
- Identify patterns across multiple turns
- Correlate security events within a session
- Track the flow of user prompts and AI responses
### Session Tracking
LiteLLM automatically generates a unique `litellm_trace_id` for each request, which the PANW guardrail uses as the AI Session ID in Strata Cloud Manager. All prompt and response scans for a request are automatically grouped under the same session.
#### Custom Session IDs (Per-App Tracking)
You can provide your own `litellm_trace_id` to track sessions on a per-app or per-conversation basis:
```bash
curl -X POST http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "capital of France"}],
"litellm_trace_id": "my-app-session-123", # Custom AI Session ID
"metadata": {
"profile_name": "dev-allow-all-profile", # Override security profile
"user_ip": "192.168.1.1", # Track user IP
"app_name": "eng" # Custom app identifier
},
"guardrails": ["panw-prisma-airs-pre-guard", "panw-prisma-airs-post-guard"]
}'
```
**Result in PANW SCM:**
- AI Session ID: `my-app-session-123`
- All prompt and response scans will be grouped under this custom session ID
- Perfect for tracking multi-turn conversations or per-application sessions
:::tip Viewing Sessions in Prisma AIRS SCM Logs
In Strata Cloud Manager, navigate to **AI Runtime > Sessions** to view all AI Session IDs and their associated scans. Click on a session to see the complete conversation history with security analysis.
:::
## Environment Variables
### Environment Variables
```bash
export PANW_PRISMA_AIRS_API_KEY="your-panw-api-key"
@@ -348,12 +136,31 @@ export PANW_PRISMA_AIRS_PROFILE_NAME="your-security-profile"
export PANW_PRISMA_AIRS_API_BASE="https://custom-endpoint.com"
```
## Advanced Configuration
### Per-Request Metadata Overrides
| Field | Description | Priority |
|-------|-------------|----------|
| `profile_name` | PANW AI security profile name | Per-request > config |
| `profile_id` | PANW AI security profile ID (takes precedence over `profile_name`) | Per-request only |
| `user_ip` | User IP address for tracking in Prisma AIRS | Per-request only |
| `app_name` | Application identifier (prefixed with "LiteLLM-") | Per-request > config > "LiteLLM" |
| `app_user` | Custom user identifier for tracking in Prisma AIRS | `app_user` > `user` > "litellm_user" |
```json
{
"model": "gpt-4",
"messages": [...],
"metadata": {
"profile_name": "dev-allow-all",
"profile_id": "uuid-here",
"user_ip": "192.168.1.100",
"app_name": "MyApp"
}
}
```
### Multiple Security Profiles
You can configure different security profiles for different use cases:
```yaml
guardrails:
- guardrail_name: "panw-strict-security"
@@ -361,126 +168,40 @@ guardrails:
guardrail: panw_prisma_airs
mode: "pre_call"
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
profile_name: "strict-policy" # High security profile
- guardrail_name: "panw-permissive-security"
profile_name: "strict-policy"
- guardrail_name: "panw-permissive-security"
litellm_params:
guardrail: panw_prisma_airs
mode: "post_call"
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
profile_name: "permissive-policy" # Lower security profile
profile_name: "permissive-policy"
```
### Multiple API Keys (Multi-Tenant)
For multi-tenant deployments where different customers need different PANW API keys, create separate guardrail instances:
```yaml
guardrails:
- guardrail_name: "panw-customer-a"
litellm_params:
guardrail: panw_prisma_airs
mode: "pre_call"
api_key: os.environ/PANW_CUSTOMER_A_KEY # Linked to Customer A profile in SCM
- guardrail_name: "panw-customer-b"
litellm_params:
guardrail: panw_prisma_airs
mode: "pre_call"
api_key: os.environ/PANW_CUSTOMER_B_KEY # Linked to Customer B profile in SCM
```
Then route requests to the appropriate guardrail:
```bash
curl -X POST http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "gpt-4",
"messages": [{"role": "user", "content": "Hello"}],
"guardrails": ["panw-customer-a"]
}'
```
**Use Cases:**
- **Multi-tenant deployments**: Different customers with different security policies
- **Environment-specific policies**: Dev/staging/prod with different API keys and profiles
- **A/B testing**: Compare different security profiles side-by-side
### Content Masking
PANW Prisma AIRS can automatically mask sensitive content (PII, credit cards, SSNs, etc.) instead of blocking requests. This allows your application to continue functioning while protecting sensitive data.
#### How It Works
1. **Detection**: PANW scans content and identifies sensitive data
2. **Masking**: Sensitive data is replaced with placeholders (e.g., `XXXXXXXXXX` or `{PHONE}`)
3. **Pass-through**: Masked content is sent to the LLM or returned to the user
#### Configuration Options
:::warning Important: Masking is Controlled by PANW Security Profile
The actual masking behavior (what content gets masked and how) is controlled by your PANW Prisma AIRS security profile in Strata Cloud Manager. The LiteLLM flags (`mask_request_content`, `mask_response_content`) only control whether to apply the masked content and allow the request to continue, or block entirely.
:::
```yaml
guardrails:
- guardrail_name: "panw-with-masking"
litellm_params:
guardrail: panw_prisma_airs
mode: "post_call" # Scan response output
mode: "post_call"
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
profile_name: "default"
mask_request_content: true # Mask sensitive data in prompts
mask_response_content: true # Mask sensitive data in responses
mask_request_content: true
mask_response_content: true
```
**Masking Parameters:**
- `mask_request_content: true` - When PANW detects sensitive data in prompts, mask it instead of blocking
- `mask_response_content: true` - When PANW detects sensitive data in responses, mask it instead of blocking
- `mask_on_block: true` - Backwards compatible flag that enables both request and response masking
:::warning Important: Masking is Controlled by PANW Security Profile
The **actual masking behavior** (what content gets masked and how) is controlled by your **PANW Prisma AIRS security profile** configured in Strata Cloud Manager. The LiteLLM config settings (`mask_request_content`, `mask_response_content`) only control whether to:
- **Apply the masked content** returned by PANW and allow the request to continue, OR
- **Block the request** entirely when sensitive data is detected
LiteLLM does not alter or configure your PANW security profile. To change what content gets masked, update your profile settings in Strata Cloud Manager.
:::
:::info Security Posture
The guardrail is **fail-closed** by default - if the PANW API is unavailable, requests are blocked to ensure no unscanned content reaches your LLM. This provides maximum security.
:::
### Custom Violation Messages
You can customize the error message returned to the user when a request is blocked by configuring the `violation_message_template` parameter. This is useful for providing user-friendly feedback instead of technical details.
```yaml
guardrails:
- guardrail_name: "panw-custom-message"
litellm_params:
guardrail: panw_prisma_airs
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
# Simple message
violation_message_template: "Your request was blocked by our AI Security Policy."
- guardrail_name: "panw-detailed-message"
litellm_params:
guardrail: panw_prisma_airs
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
# Message with placeholders
violation_message_template: "{action_type} blocked due to {category} violation. Please contact support."
```
**Supported Placeholders:**
- `{guardrail_name}`: Name of the guardrail (e.g. "panw-custom-message")
- `{category}`: Violation category (e.g. "malicious", "injection", "dlp")
- `{action_type}`: "Prompt" or "Response"
- `{default_message}`: The original technical error message
- `mask_request_content: true` — mask sensitive data in prompts instead of blocking
- `mask_response_content: true` — mask sensitive data in responses instead of blocking
- `mask_on_block: true` — backwards-compatible flag that enables both request and response masking
### Fail-Open Configuration
By default, the PANW guardrail operates in **fail-closed** mode for maximum security. If the PANW API is unavailable (timeout, rate limit, network error), requests are blocked. You can configure **fail-open** mode for high-availability scenarios where service continuity is critical.
```yaml
guardrails:
- guardrail_name: "panw-high-availability"
@@ -488,135 +209,86 @@ guardrails:
guardrail: panw_prisma_airs
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
profile_name: "production"
fallback_on_error: "allow" # Enable fail-open mode
timeout: 5.0 # Shorter timeout for fail-open
fallback_on_error: "allow"
timeout: 5.0
```
**Configuration Options:**
| Parameter | Value | Behavior |
|-----------|-------|----------|
| `fallback_on_error` | `"block"` (default) | **Fail-closed**: Block requests when API unavailable (maximum security) |
| `fallback_on_error` | `"allow"` | **Fail-open**: Allow requests when API unavailable (high availability) |
| `timeout` | `1.0` - `60.0` | API call timeout in seconds (default: `10.0`) |
**Error Handling Matrix:**
| Error Type | `fallback_on_error="block"` | `fallback_on_error="allow"` |
|------------|----------------------------|----------------------------|
| 401 Unauthorized | Block (500) | Block (500) ⚠️ |
| 403 Forbidden | Block (500) | Block (500) ⚠️ |
| Profile Error | Block (500) | Block (500) ⚠️ |
| 401 Unauthorized | Block (500) | Block (500) |
| 403 Forbidden | Block (500) | Block (500) |
| Profile Error | Block (500) | Block (500) |
| 429 Rate Limit | Block (500) | Allow (`:unscanned`) |
| Timeout | Block (500) | Allow (`:unscanned`) |
| Network Error | Block (500) | Allow (`:unscanned`) |
| 5xx Server Error | Block (500) | Allow (`:unscanned`) |
| Content Blocked | Block (400) | Block (400) |
⚠️ = Always blocks regardless of fail-open setting
Authentication and configuration errors (401, 403, invalid profile) always block. Only transient errors (429, timeout, network) trigger fail-open.
:::warning Security Trade-Off
Enabling `fallback_on_error="allow"` reduces security in exchange for availability. Requests may proceed **without scanning** when the PANW API is unavailable. Use only when:
- Service availability is more critical than security scanning
- You have other security controls in place
- You monitor the `:unscanned` header for audit trails
When fail-open is triggered, the response includes a tracking header: `X-LiteLLM-Applied-Guardrails: panw-airs:unscanned`
**Authentication and configuration errors (401, 403, invalid profile) always block** - only transient errors (429, timeout, network) trigger fail-open behavior.
:::
**Observability:**
When fail-open is triggered, the response includes a special header for tracking:
```
X-LiteLLM-Applied-Guardrails: panw-airs:unscanned
```
This allows you to:
- Track which requests bypassed scanning
- Alert on unscanned request volumes
- Audit compliance requirements
#### Example: Masking Credit Card Numbers
<Tabs>
<TabItem label="Without Masking" value="no-mask">
**Request:**
```json
{
"messages": [
{"role": "user", "content": "My credit card is 4929-3813-3266-4295"}
]
}
```
**Response:** ❌ **Blocked with 400 error**
</TabItem>
<TabItem label="With Masking" value="with-mask">
**Request:**
```json
{
"messages": [
{"role": "user", "content": "My credit card is 4929-3813-3266-4295"}
]
}
```
**Masked prompt sent to LLM:**
```json
{
"messages": [
{"role": "user", "content": "My credit card is XXXXXXXXXXXXXXXXXX"}
]
}
```
**Response:** ✅ **Allowed with masked content**
</TabItem>
</Tabs>
#### Masking Capabilities
The guardrail masks sensitive content in:
- ✅ **Chat messages** - User prompts and assistant responses
- ✅ **Streaming responses** - Real-time masking of streamed content
- ✅ **Multi-choice responses** - All choices in the response
- ✅ **Tool/function calls** - Arguments passed to tools and functions
- ✅ **Content lists** - Mixed content types (text, images, etc.)
#### Complete Example
### Custom Violation Messages
```yaml
guardrails:
- guardrail_name: "panw-production-security"
- guardrail_name: "panw-custom-message"
litellm_params:
guardrail: panw_prisma_airs
mode: "post_call" # Scan input and output
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
profile_name: "production-profile"
mask_request_content: true # Mask sensitive prompts
mask_response_content: true # Mask sensitive responses
violation_message_template: "Your request was blocked by our AI Security Policy."
- guardrail_name: "panw-detailed-message"
litellm_params:
guardrail: panw_prisma_airs
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
violation_message_template: "{action_type} blocked due to {category} violation. Please contact support."
```
## Use Cases
**Supported Placeholders:** `{guardrail_name}`, `{category}`, `{action_type}`, `{default_message}`
From [official Prisma AIRS documentation](https://docs.paloaltonetworks.com/ai-runtime-security/activation-and-onboarding/ai-runtime-security-api-intercept-overview):
## Behavior and Limitations
- **Secure AI models in production**: Validate prompt requests and responses to protect deployed AI models
- **Detect data poisoning**: Identify contaminated training data before fine-tuning
- **Protect against adversarial input**: Safeguard AI agents from malicious inputs and outputs
- **Prevent sensitive data leakage**: Use API-based threat detection to block sensitive data leaks
### Transaction Tracking
For standard request/response scans, `tr_id` maps to `litellm_call_id`. MCP tool scans use the parent `litellm_call_id` when available; if missing, PANW synthesizes a fallback MCP transaction ID. The real limitation is correlation loss — synthesized MCP `tr_id` values are not grouped with the parent request's prompt/response scans in AIRS dashboards.
By default, LiteLLM generates a UUID for `litellm_call_id`. To provide your own:
```bash
curl -X POST http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-H "x-litellm-call-id: my-custom-call-id-789" \
-d '{
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "capital of France"}],
"guardrails": ["panw-prisma-airs-guardrail"]
}'
```
The `x-litellm-call-id` is also returned in response headers. If you pass `litellm_trace_id` in request metadata (or via the `x-litellm-trace-id` header), it is included in the PANW API payload metadata but does not affect `tr_id` or appear in Prisma AIRS.
### Streaming
- Response masking works on OpenAI chat streaming (`mask_response_content: true`)
- `/v1/messages` and `/v1/responses` raw streaming blocks instead of masking when violations are detected
- Request-side masking (`mask_request_content`) is unaffected by endpoint type
- When `fallback_on_error: "allow"` is set, streaming responses fail open on transient PANW API errors (timeout, 5xx, network) — original chunks are yielded unchanged
## MCP Tool Security
Tool invocations are sent to AIRS as structured `tool_event` payloads containing tool name, ecosystem, and serialized arguments. Tool-event scans always use request mode.
**What is scanned:** LLM-driven `tool_calls` (name + arguments) and MCP request-side invocations when `mcp_tool_name` (or fallback `name`) is present. Response-side OpenAI-compatible `tool_calls` are also scanned when surfaced into `apply_guardrail()`.
**What is not scanned:** Tool definitions in `inputs["tools"]` and post-MCP tool results (no `post_mcp_call` hook exists yet).
## Next Steps
### Current Limitations
- Configure your security policies in [Strata Cloud Manager](https://apps.paloaltonetworks.com/)
- Review the [Prisma AIRS API documentation](https://pan.dev/airs/) for advanced features
- Set up monitoring and alerting for threat detections in your PANW dashboard
- Consider implementing both pre_call and post_call guardrails for comprehensive protection
- Monitor detection events and tune your security profiles based on your application needs
- **No post-MCP response scanning.** Actual post-MCP tool-result scanning is not supported because there is no `post_mcp_call` hook in the framework. Response-side MCP events are only scanned when they appear as regular `tool_calls` in the LLM response.
- **Guardrail selection not inherited by MCP sub-calls.** With `default_on: false`, MCP request-side child-call scans can be skipped because the parent request's guardrail selection is not propagated to the synthetic MCP payload. Workaround: use a dedicated guardrail with `mode: pre_mcp_call` and `default_on: true`.
- **MCP transaction correlation.** MCP tool scans use the parent `litellm_call_id` when available; otherwise a fallback ID is synthesized and will not be grouped with the parent request in AIRS dashboards.
Binary file not shown.

After

Width:  |  Height:  |  Size: 72 KiB

+2
View File
@@ -814,6 +814,8 @@ const sidebars = {
"providers/anyscale",
"providers/apertis",
"providers/baseten",
"providers/black_forest_labs",
"providers/black_forest_labs_img_edit",
"providers/bytez",
"providers/cerebras",
"providers/chutes",
@@ -0,0 +1,13 @@
-- SkipTransactionBlock
-- Drop invalid indexes left behind by failed CONCURRENTLY builds
DROP INDEX CONCURRENTLY IF EXISTS "LiteLLM_VerificationToken_key_alias_idx";
-- CreateIndex
CREATE INDEX CONCURRENTLY "LiteLLM_VerificationToken_key_alias_idx" ON "LiteLLM_VerificationToken"("key_alias");
-- Drop invalid indexes left behind by failed CONCURRENTLY builds
DROP INDEX CONCURRENTLY IF EXISTS "LiteLLM_SpendLogs_user_startTime_idx";
-- CreateIndex
CREATE INDEX CONCURRENTLY "LiteLLM_SpendLogs_user_startTime_idx" ON "LiteLLM_SpendLogs"("user", "startTime");
@@ -388,6 +388,9 @@ model LiteLLM_VerificationToken {
// SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE (("public"."LiteLLM_VerificationToken"."expires" IS NULL OR "public"."LiteLLM_VerificationToken"."expires" > $1) AND "public"."LiteLLM_VerificationToken"."budget_reset_at" < $2) OFFSET $3
@@index([budget_reset_at, expires])
// SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE (...) ORDER BY "public"."LiteLLM_VerificationToken"."key_alias" ASC
@@index([key_alias])
}
model LiteLLM_JWTKeyMapping {
@@ -553,6 +556,9 @@ model LiteLLM_SpendLogs {
@@index([startTime, request_id])
@@index([end_user])
@@index([session_id])
// SELECT ... FROM "LiteLLM_SpendLogs" WHERE ("startTime" >= $1 AND "startTime" <= $2 AND "user" = $3) GROUP BY ...
@@index([user, startTime])
}
// View spend, model, api_key per request
+5
View File
@@ -577,6 +577,7 @@ v0_models: Set = set()
morph_models: Set = set()
lambda_ai_models: Set = set()
hyperbolic_models: Set = set()
black_forest_labs_models: Set = set()
recraft_models: Set = set()
cometapi_models: Set = set()
oci_models: Set = set()
@@ -824,6 +825,8 @@ def add_known_models(model_cost_map: Optional[Dict] = None):
lambda_ai_models.add(key)
elif value.get("litellm_provider") == "hyperbolic":
hyperbolic_models.add(key)
elif value.get("litellm_provider") == "black_forest_labs":
black_forest_labs_models.add(key)
elif value.get("litellm_provider") == "recraft":
recraft_models.add(key)
elif value.get("litellm_provider") == "cometapi":
@@ -957,6 +960,7 @@ model_list = list(
| v0_models
| morph_models
| lambda_ai_models
| black_forest_labs_models
| recraft_models
| cometapi_models
| oci_models
@@ -1055,6 +1059,7 @@ models_by_provider: dict = {
"morph": morph_models,
"lambda_ai": lambda_ai_models,
"hyperbolic": hyperbolic_models,
"black_forest_labs": black_forest_labs_models,
"recraft": recraft_models,
"cometapi": cometapi_models,
"oci": oci_models,
-4
View File
@@ -348,8 +348,6 @@ class DualCache(BaseCache):
)
try:
if self.in_memory_cache is not None:
if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
kwargs["ttl"] = self.default_in_memory_ttl
await self.in_memory_cache.async_set_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only is False:
@@ -371,8 +369,6 @@ class DualCache(BaseCache):
)
try:
if self.in_memory_cache is not None:
if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
kwargs["ttl"] = self.default_in_memory_ttl
await self.in_memory_cache.async_set_cache_pipeline(
cache_list=cache_list, **kwargs
)
@@ -398,9 +398,6 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge):
ResponseOutputMessage,
ResponseReasoningItem,
)
from openai.types.responses.response_output_item import (
ResponseApplyPatchToolCall,
)
from litellm.types.utils import Choices, Message
@@ -457,18 +454,6 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge):
accumulated_tool_calls.append(tool_call_dict)
tool_call_index += 1
elif isinstance(item, ResponseApplyPatchToolCall):
from litellm.responses.litellm_completion_transformation.transformation import (
LiteLLMCompletionResponsesConfig,
)
tool_call_dict = LiteLLMCompletionResponsesConfig.convert_apply_patch_tool_call_to_chat_completion_tool_call(
tool_call_item=item,
index=tool_call_index,
)
accumulated_tool_calls.append(tool_call_dict)
tool_call_index += 1
elif isinstance(item, dict) and handle_raw_dict_callback is not None:
# Handle raw dict responses (e.g., from GPT-5 Codex)
choice, index = handle_raw_dict_callback(item=item, index=index)
+1 -5
View File
@@ -1212,12 +1212,8 @@ OPENAI_FINISH_REASONS = [
"stop",
"length",
"function_call",
"tool_calls",
"content_filter",
"null",
"finish_reason_unspecified",
"malformed_function_call",
"guardrail_intervened",
"eos",
]
HUMANLOOP_PROMPT_CACHE_TTL_SECONDS = int(
os.getenv("HUMANLOOP_PROMPT_CACHE_TTL_SECONDS", 60)
@@ -770,8 +770,6 @@ class GoogleGenAIAdapter:
"content_filter": "SAFETY",
"tool_calls": "STOP",
"function_call": "STOP",
"finish_reason_unspecified": "FINISH_REASON_UNSPECIFIED",
"malformed_function_call": "MALFORMED_FUNCTION_CALL",
}
return mapping.get(finish_reason, "STOP")
+42 -5
View File
@@ -50,6 +50,10 @@ from litellm.main import (
openai_image_variations,
)
# BFL handlers
from litellm.llms.black_forest_labs.image_edit.handler import bfl_image_edit
from litellm.llms.black_forest_labs.image_generation.handler import bfl_image_generation
###########################################
from litellm.secret_managers.main import get_secret_str
from litellm.types.images.main import ImageEditOptionalRequestParams
@@ -426,6 +430,22 @@ def image_generation( # noqa: PLR0915
timeout=timeout,
client=client,
)
elif custom_llm_provider == "black_forest_labs":
# Route to BFL-specific handler (polling required)
if model is None:
raise Exception("Model needs to be set for black_forest_labs")
return bfl_image_generation.image_generation(
model=model,
prompt=prompt,
model_response=model_response,
optional_params=optional_params,
litellm_params=litellm_params_dict,
logging_obj=litellm_logging_obj,
timeout=timeout,
extra_headers=extra_headers,
client=client,
aimg_generation=aimg_generation,
)
elif custom_llm_provider == "azure_ai":
from litellm.llms.azure_ai.common_utils import AzureFoundryModelInfo
@@ -909,19 +929,36 @@ def image_edit( # noqa: PLR0915
elif custom_llm_provider == "stability":
image_edit_request_params.update(non_default_params)
return base_llm_http_handler.image_edit_handler(
model=model,
image=images,
prompt=prompt,
image_edit_provider_config=image_edit_provider_config,
image_edit_optional_request_params=image_edit_request_params,
custom_llm_provider=custom_llm_provider,
litellm_params=litellm_params,
logging_obj=litellm_logging_obj,
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout or DEFAULT_REQUEST_TIMEOUT,
_is_async=_is_async,
client=kwargs.get("client"),
)
elif custom_llm_provider == "black_forest_labs":
# Route to BFL-specific handler (polling required)
if model is None:
raise Exception("Model needs to be set for black_forest_labs")
image_edit_request_params.update(non_default_params)
return bfl_image_edit.image_edit(
model=model,
image=images,
prompt=prompt,
image_edit_provider_config=image_edit_provider_config,
image_edit_optional_request_params=image_edit_request_params,
custom_llm_provider=custom_llm_provider,
litellm_params=litellm_params,
logging_obj=litellm_logging_obj,
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout or DEFAULT_REQUEST_TIMEOUT,
_is_async=_is_async,
extra_headers=extra_headers,
client=kwargs.get("client"),
aimage_edit=_is_async,
)
# Call the handler with _is_async flag instead of directly calling the async handler
return base_llm_http_handler.image_edit_handler(
+50 -40
View File
@@ -1,11 +1,11 @@
# What is this?
## Helper utilities
from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Union, get_args
import httpx
from litellm._logging import verbose_logger
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.openai import AllMessageValues, OpenAIChatCompletionFinishReason
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
@@ -58,45 +58,55 @@ def safe_divide(
return numerator / denominator
def map_finish_reason(
finish_reason: str,
): # openai supports 5 stop sequences - 'stop', 'length', 'function_call', 'content_filter', 'null'
# anthropic mapping
if finish_reason == "stop_sequence":
_FINISH_REASON_MAP: dict[str, OpenAIChatCompletionFinishReason] = {
# Anthropic
"stop_sequence": "stop",
"end_turn": "stop",
"max_tokens": "length",
"tool_use": "tool_calls",
"compaction": "length",
# Cohere
"COMPLETE": "stop",
"ERROR_TOXIC": "content_filter",
"ERROR": "stop",
# HuggingFace / Together AI
"eos_token": "stop",
"eos": "stop",
# Gemini / Vertex AI
"STOP": "stop",
"MAX_TOKENS": "length",
"SAFETY": "content_filter",
"RECITATION": "content_filter",
"FINISH_REASON_UNSPECIFIED": "stop",
"MALFORMED_FUNCTION_CALL": "stop",
"LANGUAGE": "content_filter",
"OTHER": "content_filter",
"BLOCKLIST": "content_filter",
"PROHIBITED_CONTENT": "content_filter",
"SPII": "content_filter",
"IMAGE_SAFETY": "content_filter",
"IMAGE_PROHIBITED_CONTENT": "content_filter",
"TOO_MANY_TOOL_CALLS": "stop",
"MALFORMED_RESPONSE": "stop",
# Bedrock
"guardrail_intervened": "content_filter",
# OpenAI passthrough
"stop": "stop",
"length": "length",
"tool_calls": "tool_calls",
"function_call": "function_call",
"content_filter": "content_filter",
}
def map_finish_reason(finish_reason: str) -> OpenAIChatCompletionFinishReason:
mapped = _FINISH_REASON_MAP.get(finish_reason)
if mapped is None:
verbose_logger.warning(
"Unmapped finish_reason '%s', defaulting to 'stop'", finish_reason
)
return "stop"
# cohere mapping - https://docs.cohere.com/reference/generate
elif finish_reason == "COMPLETE":
return "stop"
elif finish_reason == "MAX_TOKENS": # cohere + vertex ai
return "length"
elif finish_reason == "ERROR_TOXIC":
return "content_filter"
elif (
finish_reason == "ERROR"
): # openai currently doesn't support an 'error' finish reason
return "stop"
# huggingface mapping https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/generate_stream
elif finish_reason == "eos_token" or finish_reason == "stop_sequence":
return "stop"
elif (
finish_reason == "FINISH_REASON_UNSPECIFIED"
): # vertex ai - got from running `print(dir(response_obj.candidates[0].finish_reason))`: ['FINISH_REASON_UNSPECIFIED', 'MAX_TOKENS', 'OTHER', 'RECITATION', 'SAFETY', 'STOP',]
return "finish_reason_unspecified"
elif finish_reason == "MALFORMED_FUNCTION_CALL":
return "malformed_function_call"
elif finish_reason == "SAFETY" or finish_reason == "RECITATION": # vertex ai
return "content_filter"
elif finish_reason == "STOP": # vertex ai
return "stop"
elif finish_reason == "end_turn" or finish_reason == "stop_sequence": # anthropic
return "stop"
elif finish_reason == "max_tokens": # anthropic
return "length"
elif finish_reason == "tool_use": # anthropic
return "tool_calls"
elif finish_reason == "compaction":
return "length"
return finish_reason
return mapped
def remove_index_from_tool_calls(
@@ -64,10 +64,12 @@ def duration_in_seconds(duration: str) -> int:
now = time.time()
current_time = datetime.fromtimestamp(now)
# Calculate target month and year, handling overflow past December
total_months = current_time.month - 1 + value # 0-indexed months
target_year = current_time.year + total_months // 12
target_month = total_months % 12 + 1 # back to 1-indexed
if current_time.month == 12:
target_year = current_time.year + 1
target_month = 1
else:
target_year = current_time.year
target_month = current_time.month + value
# Determine the day to set for next month
target_day = current_time.day
@@ -11,7 +11,7 @@ export LITELLM_LOCAL_MODEL_COST_MAP=True
import json
import os
from importlib.resources import files
from typing import Optional
from typing import Dict, List, Optional
import httpx
@@ -186,6 +186,61 @@ def get_model_cost_map_source_info() -> dict:
}
def _expand_model_aliases(model_cost: dict) -> dict:
"""
Expand ``aliases`` lists in model cost entries into top-level entries.
Each alias gets a reference to the **same** dict object as the canonical
entry (zero memory overhead). The ``aliases`` key is removed from the
entry so downstream code never sees it.
If an alias collides with an existing canonical entry the alias is
skipped and a warning is logged.
"""
aliases_to_add: Dict[str, dict] = {}
keys_with_aliases: List[str] = []
for model_name, model_info in model_cost.items():
aliases: Optional[list] = model_info.get("aliases")
if aliases is None:
continue
keys_with_aliases.append(model_name)
if not isinstance(aliases, list):
verbose_logger.warning(
"LiteLLM model alias field for '%s' is not a list (got %s) — skipping.",
model_name,
type(aliases).__name__,
)
continue
if not aliases:
continue
for alias in aliases:
if alias in model_cost:
verbose_logger.warning(
"LiteLLM model alias conflict: alias '%s' (from '%s') "
"already exists as a canonical entry — skipping.",
alias,
model_name,
)
continue
if alias in aliases_to_add:
verbose_logger.warning(
"LiteLLM model alias conflict: alias '%s' (from '%s') "
"was already claimed by another entry — skipping.",
alias,
model_name,
)
continue
aliases_to_add[alias] = model_info # same dict reference
# Remove the ``aliases`` key from entries so it doesn't pollute model info
for key in keys_with_aliases:
model_cost[key].pop("aliases", None)
model_cost.update(aliases_to_add)
return model_cost
def get_model_cost_map(url: str) -> dict:
"""
Public entry point returns the model cost map dict.
@@ -205,7 +260,7 @@ def get_model_cost_map(url: str) -> dict:
_cost_map_source_info.url = None
_cost_map_source_info.is_env_forced = True
_cost_map_source_info.fallback_reason = None
return GetModelCostMap.load_local_model_cost_map()
return _expand_model_aliases(GetModelCostMap.load_local_model_cost_map())
_cost_map_source_info.url = url
_cost_map_source_info.is_env_forced = False
@@ -221,7 +276,7 @@ def get_model_cost_map(url: str) -> dict:
)
_cost_map_source_info.source = "local"
_cost_map_source_info.fallback_reason = f"Remote fetch failed: {str(e)}"
return GetModelCostMap.load_local_model_cost_map()
return _expand_model_aliases(GetModelCostMap.load_local_model_cost_map())
# Validate using cached count (cheap int comparison, no file I/O)
if not GetModelCostMap.validate_model_cost_map(
@@ -234,11 +289,9 @@ def get_model_cost_map(url: str) -> dict:
url,
)
_cost_map_source_info.source = "local"
_cost_map_source_info.fallback_reason = (
"Remote data failed integrity validation"
)
return GetModelCostMap.load_local_model_cost_map()
_cost_map_source_info.fallback_reason = "Remote data failed integrity validation"
return _expand_model_aliases(GetModelCostMap.load_local_model_cost_map())
_cost_map_source_info.source = "remote"
_cost_map_source_info.fallback_reason = None
return content
return _expand_model_aliases(content)
@@ -150,6 +150,14 @@ def get_supported_openai_params( # noqa: PLR0915
return litellm.MistralConfig().get_supported_openai_params(model=model)
elif request_type == "embeddings":
return litellm.MistralEmbeddingConfig().get_supported_openai_params()
elif request_type == "transcription":
from litellm.llms.mistral.audio_transcription.transformation import (
MistralAudioTranscriptionConfig,
)
return MistralAudioTranscriptionConfig().get_supported_openai_params(
model=model
)
elif custom_llm_provider == "text-completion-codestral":
return litellm.CodestralTextCompletionConfig().get_supported_openai_params(
model=model
@@ -2500,74 +2500,257 @@ def anthropic_messages_pt( # noqa: PLR0915
assistant_content.extend(_compaction_blocks) # type: ignore
thinking_blocks = assistant_content_block.get("thinking_blocks", None)
# Check if tool_calls contain server tool calls (web search, etc.)
# If so, we need to interleave thinking blocks with tool call groups
# to preserve the original content block ordering.
# Fixes: https://github.com/BerriAI/litellm/issues/23047
assistant_tool_calls = assistant_content_block.get("tool_calls")
_has_server_tool_calls = False
if assistant_tool_calls is not None:
for _tc in assistant_tool_calls:
_tc_id = (
_tc.get("id")
if isinstance(_tc, dict)
else getattr(_tc, "id", None)
)
if _tc_id and isinstance(_tc_id, str) and _tc_id.startswith("srvtoolu_"):
_has_server_tool_calls = True
break
if (
thinking_blocks is not None
): # IMPORTANT: ADD THIS FIRST, ELSE ANTHROPIC WILL RAISE AN ERROR
assistant_content.extend(thinking_blocks)
if "content" in assistant_content_block and isinstance(
assistant_content_block["content"], list
and _has_server_tool_calls
and isinstance(
assistant_content_block.get("content", None), (str, type(None))
)
):
for m in assistant_content_block["content"]:
# handle thinking blocks
thinking_block = cast(str, m.get("thinking", ""))
text_block = cast(str, m.get("text", ""))
if (
m.get("type", "") == "thinking" and len(thinking_block) > 0
): # don't pass empty text blocks. anthropic api raises errors.
anthropic_message: Union[
ChatCompletionThinkingBlock,
AnthropicMessagesTextParam,
] = cast(ChatCompletionThinkingBlock, m)
assistant_content.append(anthropic_message)
# handle text
elif (
m.get("type", "") == "text" and len(text_block) > 0
): # don't pass empty text blocks. anthropic api raises errors.
anthropic_message = AnthropicMessagesTextParam(
type="text", text=text_block
)
_cached_message = add_cache_control_to_content(
anthropic_content_element=anthropic_message,
original_content_element=dict(m),
)
# INTERLEAVED MODE: When we have both thinking blocks and server
# tool calls (e.g. web search), Anthropic's original response
# interleaves them: [thinking_1, server_tool_use_1, result_1,
# thinking_2, text, server_tool_use_2, result_2, ...].
# We must preserve this interleaved order because Anthropic
# verifies thinking block signatures based on position.
assistant_content.append(
cast(AnthropicMessagesTextParam, _cached_message)
)
# handle server_tool_use blocks (tool search, web search, etc.)
# Pass through as-is since these are Anthropic-native content types
elif m.get("type", "") == "server_tool_use":
assistant_content.append(m) # type: ignore
# handle all *_tool_result blocks (tool_search_tool_result,
# web_search_tool_result, bash_code_execution_tool_result, etc.)
# Pass through as-is since these are Anthropic-native content types
elif m.get("type", "").endswith("_tool_result"):
assistant_content.append(m) # type: ignore
elif (
"content" in assistant_content_block
and isinstance(assistant_content_block["content"], str)
and assistant_content_block[
"content"
] # don't pass empty text blocks. anthropic api raises errors.
):
_anthropic_text_content_element = AnthropicMessagesTextParam(
type="text",
text=assistant_content_block["content"],
# Build the tool call groups (server_tool_use + its result)
_provider_specific_fields_raw_tc = assistant_content_block.get(
"provider_specific_fields"
)
_provider_specific_fields_tc: Dict[str, Any] = {}
if isinstance(_provider_specific_fields_raw_tc, dict):
_provider_specific_fields_tc = cast(
Dict[str, Any], _provider_specific_fields_raw_tc
)
_web_search_results_tc = _provider_specific_fields_tc.get(
"web_search_results"
)
_tool_results_tc = _provider_specific_fields_tc.get("tool_results")
tool_invoke_results = convert_to_anthropic_tool_invoke(
assistant_tool_calls, # type: ignore
web_search_results=_web_search_results_tc,
tool_results=_tool_results_tc,
)
_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_text_content_element,
original_content_element=dict(assistant_content_block),
# Group tool invoke results into (server_tool_use, result) pairs
# and separate regular tool_use blocks
server_tool_groups: List[List[Any]] = []
regular_tool_uses: List[Any] = []
_current_group: List[Any] = []
for item in tool_invoke_results:
item_type = (
item.get("type", "")
if isinstance(item, dict)
else getattr(item, "type", "")
)
if item_type == "server_tool_use":
if _current_group:
server_tool_groups.append(_current_group)
_current_group = [item]
elif item_type.endswith("_tool_result"):
_current_group.append(item)
elif item_type == "tool_use":
regular_tool_uses.append(item)
else:
_current_group.append(item)
if _current_group:
server_tool_groups.append(_current_group)
# Build the text block if content is a non-empty string
text_element = None
if (
isinstance(assistant_content_block.get("content"), str)
and assistant_content_block["content"]
):
_anthropic_text_content_element = AnthropicMessagesTextParam(
type="text",
text=assistant_content_block["content"],
)
_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_text_content_element,
original_content_element=dict(assistant_content_block),
)
if "cache_control" in _content_element:
_anthropic_text_content_element["cache_control"] = (
_content_element["cache_control"]
)
text_element = _anthropic_text_content_element
# Interleave: each thinking block precedes its server tool group.
# Pattern: thinking[0], group[0], thinking[1], group[1], ...
# Any remaining thinking blocks (after all groups) go before text.
# Any remaining groups (after all thinking blocks) go after.
tb_idx = 0
grp_idx = 0
num_tb = len(thinking_blocks) if thinking_blocks else 0
num_grp = len(server_tool_groups)
while tb_idx < num_tb or grp_idx < num_grp:
if tb_idx < num_tb and grp_idx < num_grp:
# Emit thinking block then its tool group
assistant_content.append(thinking_blocks[tb_idx])
tb_idx += 1
for block in server_tool_groups[grp_idx]:
item_id = (
block.get("id")
if isinstance(block, dict)
else getattr(block, "id", None)
)
if item_id and item_id in unique_tool_ids:
continue
if item_id:
unique_tool_ids.add(item_id)
assistant_content.append(
cast(AnthropicMessagesAssistantMessageValues, block)
)
grp_idx += 1
elif tb_idx < num_tb:
# More thinking blocks than tool groups - emit before text
assistant_content.append(thinking_blocks[tb_idx])
tb_idx += 1
else:
# More tool groups than thinking blocks - emit remaining
for block in server_tool_groups[grp_idx]:
item_id = (
block.get("id")
if isinstance(block, dict)
else getattr(block, "id", None)
)
if item_id and item_id in unique_tool_ids:
continue
if item_id:
unique_tool_ids.add(item_id)
assistant_content.append(
cast(AnthropicMessagesAssistantMessageValues, block)
)
grp_idx += 1
# Add text block (if any)
if text_element is not None:
assistant_content.append(text_element)
# Add regular (non-server) tool calls at the end
for item in regular_tool_uses:
item_id = (
item.get("id")
if isinstance(item, dict)
else getattr(item, "id", None)
)
if item_id and item_id in unique_tool_ids:
continue
if item_id:
unique_tool_ids.add(item_id)
assistant_content.append(
cast(AnthropicMessagesAssistantMessageValues, item)
)
# Mark tool_calls as already processed so they are not added again
assistant_tool_calls = None
else:
# SEQUENTIAL MODE: No server tool calls, or no thinking blocks,
# or content is a list. Use the original sequential approach.
# When content is a list, check if it already contains thinking
# blocks inline. If so, skip prepending thinking_blocks to avoid
# duplication and preserve the original interleaved order.
# Fixes the gap where list-content messages bypass INTERLEAVED
# MODE and still get thinking blocks prepended out of order.
_content_is_list = "content" in assistant_content_block and isinstance(
assistant_content_block["content"], list
)
_list_has_thinking = False
if _content_is_list:
for _item in assistant_content_block["content"]:
if isinstance(_item, dict) and _item.get("type") in ("thinking", "redacted_thinking"):
_list_has_thinking = True
break
if "cache_control" in _content_element:
_anthropic_text_content_element["cache_control"] = _content_element[
"cache_control"
]
if (
thinking_blocks is not None
and not _list_has_thinking
): # IMPORTANT: ADD THIS FIRST, ELSE ANTHROPIC WILL RAISE AN ERROR
assistant_content.extend(thinking_blocks)
if _content_is_list:
for m in assistant_content_block["content"]:
# handle thinking blocks
thinking_block = cast(str, m.get("thinking", ""))
text_block = cast(str, m.get("text", ""))
if (
m.get("type", "") == "thinking" and len(thinking_block) > 0
): # don't pass empty text blocks. anthropic api raises errors.
anthropic_message: Union[
ChatCompletionThinkingBlock,
AnthropicMessagesTextParam,
] = cast(ChatCompletionThinkingBlock, m)
assistant_content.append(anthropic_message)
# handle text
elif (
m.get("type", "") == "text" and len(text_block) > 0
): # don't pass empty text blocks. anthropic api raises errors.
anthropic_message = AnthropicMessagesTextParam(
type="text", text=text_block
)
_cached_message = add_cache_control_to_content(
anthropic_content_element=anthropic_message,
original_content_element=dict(m),
)
assistant_content.append(_anthropic_text_content_element)
assistant_content.append(
cast(AnthropicMessagesTextParam, _cached_message)
)
# handle server_tool_use blocks (tool search, web search, etc.)
# Pass through as-is since these are Anthropic-native content types
elif m.get("type", "") == "server_tool_use":
assistant_content.append(m) # type: ignore
# handle all *_tool_result blocks (tool_search_tool_result,
# web_search_tool_result, bash_code_execution_tool_result, etc.)
# Pass through as-is since these are Anthropic-native content types
elif m.get("type", "").endswith("_tool_result"):
assistant_content.append(m) # type: ignore
elif (
"content" in assistant_content_block
and isinstance(assistant_content_block["content"], str)
and assistant_content_block[
"content"
] # don't pass empty text blocks. anthropic api raises errors.
):
_anthropic_text_content_element = AnthropicMessagesTextParam(
type="text",
text=assistant_content_block["content"],
)
_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_text_content_element,
original_content_element=dict(assistant_content_block),
)
if "cache_control" in _content_element:
_anthropic_text_content_element["cache_control"] = _content_element[
"cache_control"
]
assistant_content.append(_anthropic_text_content_element)
assistant_tool_calls = assistant_content_block.get("tool_calls")
if (
assistant_tool_calls is not None
): # support assistant tool invoke conversion
@@ -75,53 +75,6 @@ def _redact_responses_api_output(output_items):
summary_item.text = "redacted-by-litellm"
def _redact_standard_logging_object(model_call_details: dict):
"""Redact messages and response inside standard_logging_object if present."""
standard_logging_object = model_call_details.get("standard_logging_object")
if standard_logging_object is None:
return
redacted_str = "redacted-by-litellm"
if standard_logging_object.get("messages") is not None:
standard_logging_object["messages"] = [
{"role": "user", "content": redacted_str}
]
response = standard_logging_object.get("response")
if response is not None:
if isinstance(response, dict) and "output" in response:
# ResponsesAPIResponse format - redact content in output items
if isinstance(response.get("output"), list):
for output_item in response["output"]:
if isinstance(output_item, dict) and "content" in output_item:
if isinstance(output_item["content"], list):
for content_item in output_item["content"]:
if (
isinstance(content_item, dict)
and "text" in content_item
):
content_item["text"] = redacted_str
elif isinstance(response, dict) and "choices" in response:
# ModelResponse dict format - redact content in choices
if isinstance(response.get("choices"), list):
for choice in response["choices"]:
if isinstance(choice, dict):
if "message" in choice and isinstance(choice["message"], dict):
choice["message"]["content"] = redacted_str
if "audio" in choice["message"]:
choice["message"]["audio"] = None
elif "delta" in choice and isinstance(choice["delta"], dict):
choice["delta"]["content"] = redacted_str
if "audio" in choice["delta"]:
choice["delta"]["audio"] = None
elif isinstance(response, str):
standard_logging_object["response"] = redacted_str
else:
# For other formats (empty dict, None, etc.), use simple text format
standard_logging_object["response"] = {"text": redacted_str}
def perform_redaction(model_call_details: dict, result):
"""
Performs the actual redaction on the logging object and result.
@@ -4,10 +4,7 @@ from typing import List
import litellm
from litellm.exceptions import UnsupportedParamsError
from litellm.llms.openai.chat.gpt_5_transformation import (
OpenAIGPT5Config,
_get_effort_level,
)
from litellm.llms.openai.chat.gpt_5_transformation import OpenAIGPT5Config
from litellm.types.llms.openai import AllMessageValues
from .gpt_transformation import AzureOpenAIConfig
@@ -84,27 +81,24 @@ class AzureOpenAIGPT5Config(AzureOpenAIConfig, OpenAIGPT5Config):
drop_params: bool,
api_version: str = "",
) -> dict:
reasoning_effort_value = non_default_params.get(
"reasoning_effort"
) or optional_params.get("reasoning_effort")
effective_effort = _get_effort_level(reasoning_effort_value)
reasoning_effort_value = (
non_default_params.get("reasoning_effort")
or optional_params.get("reasoning_effort")
)
# gpt-5.1/5.2/5.4 support reasoning_effort='none', but other gpt-5 models don't
# See: https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/reasoning
supports_none = self._supports_reasoning_effort_level(model, "none")
if effective_effort == "none" and not supports_none:
if reasoning_effort_value == "none" and not supports_none:
if litellm.drop_params is True or (
drop_params is not None and drop_params is True
):
non_default_params = non_default_params.copy()
optional_params = optional_params.copy()
if (
_get_effort_level(non_default_params.get("reasoning_effort"))
== "none"
):
if non_default_params.get("reasoning_effort") == "none":
non_default_params.pop("reasoning_effort")
if _get_effort_level(optional_params.get("reasoning_effort")) == "none":
if optional_params.get("reasoning_effort") == "none":
optional_params.pop("reasoning_effort")
else:
raise UnsupportedParamsError(
@@ -127,19 +121,9 @@ class AzureOpenAIGPT5Config(AzureOpenAIConfig, OpenAIGPT5Config):
)
# Only drop reasoning_effort='none' for models that don't support it
result_effort = _get_effort_level(result.get("reasoning_effort"))
if result_effort == "none" and not supports_none:
if result.get("reasoning_effort") == "none" and not supports_none:
result.pop("reasoning_effort")
# Azure Chat Completions: gpt-5.4+ does not support tools + reasoning together.
# Drop reasoning_effort when both are present (OpenAI routes to Responses API; Azure does not).
if self.is_model_gpt_5_4_plus_model(model):
has_tools = bool(
non_default_params.get("tools") or optional_params.get("tools")
)
if has_tools and result_effort not in (None, "none"):
result.pop("reasoning_effort", None)
return result
def transform_request(
@@ -51,7 +51,6 @@ from litellm.types.llms.openai import (
)
from litellm.types.utils import (
ChatCompletionMessageToolCall,
CompletionTokensDetailsWrapper,
Function,
Message,
ModelResponse,
@@ -64,7 +63,6 @@ from litellm.utils import (
has_tool_call_blocks,
last_assistant_with_tool_calls_has_no_thinking_blocks,
supports_reasoning,
token_counter,
)
from ..common_utils import (
@@ -1641,11 +1639,7 @@ class AmazonConverseConfig(BaseConfig):
thinking_blocks_list.append(_redacted_block)
return thinking_blocks_list
def _transform_usage(
self,
usage: ConverseTokenUsageBlock,
reasoning_content: Optional[str] = None,
) -> Usage:
def _transform_usage(self, usage: ConverseTokenUsageBlock) -> Usage:
input_tokens = usage["inputTokens"]
output_tokens = usage["outputTokens"]
total_tokens = usage["totalTokens"]
@@ -1662,19 +1656,6 @@ class AmazonConverseConfig(BaseConfig):
prompt_tokens_details = PromptTokensDetailsWrapper(
cached_tokens=cache_read_input_tokens
)
reasoning_tokens = (
token_counter(text=reasoning_content, count_response_tokens=True)
if reasoning_content
else 0
)
completion_tokens_details = CompletionTokensDetailsWrapper(
reasoning_tokens=reasoning_tokens,
text_tokens=(
output_tokens - reasoning_tokens
if reasoning_tokens > 0
else output_tokens
),
)
openai_usage = Usage(
prompt_tokens=input_tokens,
completion_tokens=output_tokens,
@@ -1682,7 +1663,6 @@ class AmazonConverseConfig(BaseConfig):
prompt_tokens_details=prompt_tokens_details,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
completion_tokens_details=completion_tokens_details,
)
return openai_usage
@@ -2019,10 +1999,7 @@ class AmazonConverseConfig(BaseConfig):
chat_completion_message["tool_calls"] = filtered_tools
## CALCULATING USAGE - bedrock returns usage in the headers
usage = self._transform_usage(
completion_response["usage"],
reasoning_content=chat_completion_message.get("reasoning_content"),
)
usage = self._transform_usage(completion_response["usage"])
## HANDLE TOOL CALLS
_message = Message(**chat_completion_message)
@@ -0,0 +1,21 @@
from .common_utils import (
DEFAULT_API_BASE,
DEFAULT_MAX_POLLING_TIME,
DEFAULT_POLLING_INTERVAL,
IMAGE_EDIT_MODELS,
IMAGE_GENERATION_MODELS,
BlackForestLabsError,
)
from .image_edit import BlackForestLabsImageEditConfig
from .image_generation import BlackForestLabsImageGenerationConfig
__all__ = [
"BlackForestLabsError",
"BlackForestLabsImageEditConfig",
"BlackForestLabsImageGenerationConfig",
"DEFAULT_API_BASE",
"DEFAULT_MAX_POLLING_TIME",
"DEFAULT_POLLING_INTERVAL",
"IMAGE_EDIT_MODELS",
"IMAGE_GENERATION_MODELS",
]
@@ -0,0 +1,42 @@
"""
Black Forest Labs Common Utilities
Common utilities, constants, and error handling for Black Forest Labs API.
"""
from typing import Dict
from litellm.llms.base_llm.chat.transformation import BaseLLMException
class BlackForestLabsError(BaseLLMException):
"""Exception class for Black Forest Labs API errors."""
pass
# API Constants
DEFAULT_API_BASE = "https://api.bfl.ai"
# Polling configuration
DEFAULT_POLLING_INTERVAL = 1.5 # seconds
DEFAULT_MAX_POLLING_TIME = 300 # 5 minutes
# Model to endpoint mapping for image edit
IMAGE_EDIT_MODELS: Dict[str, str] = {
"flux-kontext-pro": "/v1/flux-kontext-pro",
"flux-kontext-max": "/v1/flux-kontext-max",
"flux-pro-1.0-fill": "/v1/flux-pro-1.0-fill",
"flux-pro-1.0-expand": "/v1/flux-pro-1.0-expand",
}
# Model to endpoint mapping for image generation
IMAGE_GENERATION_MODELS: Dict[str, str] = {
"flux-pro-1.1": "/v1/flux-pro-1.1",
"flux-pro-1.1-ultra": "/v1/flux-pro-1.1-ultra",
"flux-dev": "/v1/flux-dev",
"flux-pro": "/v1/flux-pro",
# Kontext models support both text-to-image and image editing
"flux-kontext-pro": "/v1/flux-kontext-pro",
"flux-kontext-max": "/v1/flux-kontext-max",
}
@@ -0,0 +1,8 @@
from .handler import BlackForestLabsImageEdit, bfl_image_edit
from .transformation import BlackForestLabsImageEditConfig
__all__ = [
"BlackForestLabsImageEditConfig",
"BlackForestLabsImageEdit",
"bfl_image_edit",
]
@@ -0,0 +1,454 @@
"""
Black Forest Labs Image Edit Handler
Handles image edit requests for Black Forest Labs models.
BFL uses an async polling pattern - the initial request returns a task ID,
then we poll until the result is ready.
"""
import asyncio
import time
from typing import Any, Dict, List, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import FileTypes, ImageResponse
from ..common_utils import (
DEFAULT_MAX_POLLING_TIME,
DEFAULT_POLLING_INTERVAL,
BlackForestLabsError,
)
from .transformation import BlackForestLabsImageEditConfig
class BlackForestLabsImageEdit:
"""
Black Forest Labs Image Edit handler.
Handles the HTTP requests and polling logic, delegating data transformation
to the BlackForestLabsImageEditConfig class.
"""
def __init__(self):
self.config = BlackForestLabsImageEditConfig()
def image_edit(
self,
model: str,
image: Union[FileTypes, List[FileTypes]],
prompt: Optional[str],
image_edit_optional_request_params: Dict,
litellm_params: Union[GenericLiteLLMParams, Dict],
logging_obj: LiteLLMLoggingObj,
timeout: Optional[Union[float, httpx.Timeout]],
extra_headers: Optional[Dict[str, Any]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
aimage_edit: bool = False,
) -> Union[ImageResponse, Any]:
"""
Main entry point for image edit requests.
Args:
model: The model to use (e.g., "black_forest_labs/flux-kontext-pro")
image: The image(s) to edit
prompt: The edit instruction
image_edit_optional_request_params: Optional parameters for the request
litellm_params: LiteLLM parameters including api_key, api_base
logging_obj: Logging object
timeout: Request timeout
extra_headers: Additional headers
client: HTTP client to use
aimage_edit: If True, return async coroutine
Returns:
ImageResponse or coroutine if aimage_edit=True
"""
# Handle litellm_params as dict or object
if isinstance(litellm_params, dict):
api_key = litellm_params.get("api_key")
api_base = litellm_params.get("api_base")
litellm_params_dict = litellm_params
else:
api_key = litellm_params.api_key
api_base = litellm_params.api_base
litellm_params_dict = dict(litellm_params)
if aimage_edit:
return self.async_image_edit(
model=model,
image=image,
prompt=prompt,
image_edit_optional_request_params=image_edit_optional_request_params,
litellm_params=litellm_params,
logging_obj=logging_obj,
timeout=timeout,
extra_headers=extra_headers,
client=client if isinstance(client, AsyncHTTPHandler) else None,
)
# Sync version
if client is None or not isinstance(client, HTTPHandler):
sync_client = _get_httpx_client()
else:
sync_client = client
# Validate environment and get headers
headers = self.config.validate_environment(
api_key=api_key,
headers=image_edit_optional_request_params.get("extra_headers", {}) or {},
model=model,
)
if extra_headers:
headers.update(extra_headers)
# Get complete URL
complete_url = self.config.get_complete_url(
model=model,
api_base=api_base,
litellm_params=litellm_params_dict,
)
# Transform request
# Handle image list vs single image
if isinstance(image, list):
if not image:
raise BlackForestLabsError(status_code=400, message="No image provided")
image_input = image[0]
else:
image_input = image
data, _ = self.config.transform_image_edit_request(
model=model,
prompt=prompt or "",
image=image_input,
image_edit_optional_request_params=image_edit_optional_request_params,
litellm_params=litellm_params_dict,
headers=headers,
)
# Logging
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": complete_url,
"headers": headers,
},
)
# Make initial request
try:
response = sync_client.post(
url=complete_url,
headers=headers,
json=data,
timeout=timeout,
)
except Exception as e:
raise BlackForestLabsError(
status_code=500,
message=f"Request failed: {str(e)}",
)
# Poll for result
final_response = self._poll_for_result_sync(
initial_response=response,
headers=headers,
sync_client=sync_client,
)
# Transform response
return self.config.transform_image_edit_response(
model=model,
raw_response=final_response,
logging_obj=logging_obj,
)
async def async_image_edit(
self,
model: str,
image: Union[FileTypes, List[FileTypes]],
prompt: Optional[str],
image_edit_optional_request_params: Dict,
litellm_params: Union[GenericLiteLLMParams, Dict],
logging_obj: LiteLLMLoggingObj,
timeout: Optional[Union[float, httpx.Timeout]],
extra_headers: Optional[Dict[str, Any]] = None,
client: Optional[AsyncHTTPHandler] = None,
) -> ImageResponse:
"""
Async version of image edit.
"""
# Handle litellm_params as dict or object
if isinstance(litellm_params, dict):
api_key = litellm_params.get("api_key")
api_base = litellm_params.get("api_base")
litellm_params_dict = litellm_params
else:
api_key = litellm_params.api_key
api_base = litellm_params.api_base
litellm_params_dict = dict(litellm_params)
if client is None:
async_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.BLACK_FOREST_LABS,
)
else:
async_client = client
# Validate environment and get headers
headers = self.config.validate_environment(
api_key=api_key,
headers=image_edit_optional_request_params.get("extra_headers", {}) or {},
model=model,
)
if extra_headers:
headers.update(extra_headers)
# Get complete URL
complete_url = self.config.get_complete_url(
model=model,
api_base=api_base,
litellm_params=litellm_params_dict,
)
# Transform request
if isinstance(image, list):
if not image:
raise BlackForestLabsError(status_code=400, message="No image provided")
image_input = image[0]
else:
image_input = image
data, _ = self.config.transform_image_edit_request(
model=model,
prompt=prompt or "",
image=image_input,
image_edit_optional_request_params=image_edit_optional_request_params,
litellm_params=litellm_params_dict,
headers=headers,
)
# Logging
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": complete_url,
"headers": headers,
},
)
# Make initial request
try:
response = await async_client.post(
url=complete_url,
headers=headers,
json=data,
timeout=timeout,
)
except Exception as e:
raise BlackForestLabsError(
status_code=500,
message=f"Request failed: {str(e)}",
)
# Poll for result
final_response = await self._poll_for_result_async(
initial_response=response,
headers=headers,
async_client=async_client,
)
# Transform response
return self.config.transform_image_edit_response(
model=model,
raw_response=final_response,
logging_obj=logging_obj,
)
def _poll_for_result_sync(
self,
initial_response: httpx.Response,
headers: dict,
sync_client: HTTPHandler,
max_wait: float = DEFAULT_MAX_POLLING_TIME,
interval: float = DEFAULT_POLLING_INTERVAL,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> httpx.Response:
"""
Poll BFL API until result is ready (sync version).
Args:
initial_response: The initial response containing polling_url
headers: Headers to use for polling (must include x-key)
sync_client: HTTP client
max_wait: Maximum time to wait in seconds
interval: Polling interval in seconds
timeout: Timeout for each individual polling request
Returns:
Final response with completed result
"""
# Validate initial response status code
if initial_response.status_code >= 400:
raise BlackForestLabsError(
status_code=initial_response.status_code,
message=f"BFL initial request failed: {initial_response.text}",
)
# Parse initial response to get polling URL
try:
response_data = initial_response.json()
except Exception as e:
raise BlackForestLabsError(
status_code=initial_response.status_code,
message=f"Error parsing initial response: {e}",
)
# Check for immediate errors
if "errors" in response_data:
raise BlackForestLabsError(
status_code=initial_response.status_code,
message=f"BFL error: {response_data['errors']}",
)
polling_url = response_data.get("polling_url")
if not polling_url:
raise BlackForestLabsError(
status_code=500,
message="No polling_url in BFL response",
)
# Get just the auth header for polling
polling_headers = {"x-key": headers.get("x-key", "")}
start_time = time.time()
verbose_logger.debug(f"BFL starting sync polling at {polling_url}")
while time.time() - start_time < max_wait:
response = sync_client.get(
url=polling_url,
headers=polling_headers,
)
if response.status_code != 200:
raise BlackForestLabsError(
status_code=response.status_code,
message=f"Polling failed: {response.text}",
)
data = response.json()
status = data.get("status")
verbose_logger.debug(f"BFL poll status: {status}")
if status == "Ready":
return response
elif status in ["Error", "Failed", "Content Moderated", "Request Moderated"]:
raise BlackForestLabsError(
status_code=400,
message=f"Image generation failed: {status}",
)
time.sleep(interval)
raise BlackForestLabsError(
status_code=408,
message=f"Polling timed out after {max_wait} seconds",
)
async def _poll_for_result_async(
self,
initial_response: httpx.Response,
headers: dict,
async_client: AsyncHTTPHandler,
max_wait: float = DEFAULT_MAX_POLLING_TIME,
interval: float = DEFAULT_POLLING_INTERVAL,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> httpx.Response:
"""
Poll BFL API until result is ready (async version).
"""
# Validate initial response status code
if initial_response.status_code >= 400:
raise BlackForestLabsError(
status_code=initial_response.status_code,
message=f"BFL initial request failed: {initial_response.text}",
)
# Parse initial response to get polling URL
try:
response_data = initial_response.json()
except Exception as e:
raise BlackForestLabsError(
status_code=initial_response.status_code,
message=f"Error parsing initial response: {e}",
)
# Check for immediate errors
if "errors" in response_data:
raise BlackForestLabsError(
status_code=initial_response.status_code,
message=f"BFL error: {response_data['errors']}",
)
polling_url = response_data.get("polling_url")
if not polling_url:
raise BlackForestLabsError(
status_code=500,
message="No polling_url in BFL response",
)
# Get just the auth header for polling
polling_headers = {"x-key": headers.get("x-key", "")}
start_time = time.time()
verbose_logger.debug(f"BFL starting async polling at {polling_url}")
while time.time() - start_time < max_wait:
response = await async_client.get(
url=polling_url,
headers=polling_headers,
)
if response.status_code != 200:
raise BlackForestLabsError(
status_code=response.status_code,
message=f"Polling failed: {response.text}",
)
data = response.json()
status = data.get("status")
verbose_logger.debug(f"BFL poll status: {status}")
if status == "Ready":
return response
elif status in ["Error", "Failed", "Content Moderated", "Request Moderated"]:
raise BlackForestLabsError(
status_code=400,
message=f"Image generation failed: {status}",
)
await asyncio.sleep(interval)
raise BlackForestLabsError(
status_code=408,
message=f"Polling timed out after {max_wait} seconds",
)
# Singleton instance for use in images/main.py
bfl_image_edit = BlackForestLabsImageEdit()
@@ -0,0 +1,308 @@
"""
Black Forest Labs Image Edit Configuration
Handles transformation between OpenAI-compatible format and Black Forest Labs API format
for image editing endpoints (flux-kontext-pro, flux-kontext-max, etc.).
API Reference: https://docs.bfl.ai/
"""
import base64
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
from httpx._types import RequestFiles
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.images.main import ImageEditOptionalRequestParams
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import FileTypes, ImageObject, ImageResponse
from ..common_utils import (
DEFAULT_API_BASE,
IMAGE_EDIT_MODELS,
BlackForestLabsError,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BlackForestLabsImageEditConfig(BaseImageEditConfig):
"""
Configuration for Black Forest Labs image editing.
Supports:
- flux-kontext-pro: General image editing with prompts
- flux-kontext-max: Premium quality editing
- flux-pro-1.0-fill: Inpainting with mask
- flux-pro-1.0-expand: Outpainting (expand image borders)
Note: HTTP requests and polling are handled by the handler (handler.py).
This class only handles data transformation.
"""
def get_supported_openai_params(self, model: str) -> List[str]:
"""
Return list of OpenAI params supported by Black Forest Labs.
Note: BFL uses different parameter names, these are mapped in map_openai_params.
"""
return [
"mask",
"seed",
"output_format",
"safety_tolerance",
"prompt_upsampling",
"aspect_ratio",
"steps",
"guidance",
"grow_mask",
"top",
"bottom",
"left",
"right",
]
def map_openai_params(
self,
image_edit_optional_params: ImageEditOptionalRequestParams,
model: str,
drop_params: bool,
) -> Dict:
"""
Map OpenAI parameters to Black Forest Labs parameters.
BFL-specific params are passed through directly.
"""
optional_params: Dict[str, Any] = {}
# Pass through BFL-specific params
bfl_params = [
"seed",
"output_format",
"safety_tolerance",
"prompt_upsampling",
# Kontext-specific
"aspect_ratio",
# Fill/Inpaint-specific
"steps",
"guidance",
"grow_mask",
# Expand-specific
"top",
"bottom",
"left",
"right",
]
# Convert TypedDict to regular dict for access
params_dict = dict(image_edit_optional_params)
for param in bfl_params:
if param in params_dict:
value = params_dict[param]
if value is not None:
optional_params[param] = value
# Set default output format
if "output_format" not in optional_params:
optional_params["output_format"] = "png"
return optional_params
def validate_environment(
self,
headers: dict,
model: str,
api_key: Optional[str] = None,
) -> dict:
"""
Validate environment and set up headers for Black Forest Labs.
BFL uses x-key header for authentication.
"""
final_api_key: Optional[str] = (
api_key
or get_secret_str("BFL_API_KEY")
or get_secret_str("BLACK_FOREST_LABS_API_KEY")
)
if not final_api_key:
raise BlackForestLabsError(
status_code=401,
message="BFL_API_KEY is not set. Please set it via environment variable or pass api_key parameter.",
)
headers["x-key"] = final_api_key
headers["Content-Type"] = "application/json"
headers["Accept"] = "application/json"
return headers
def use_multipart_form_data(self) -> bool:
"""
BFL uses JSON requests, not multipart/form-data.
"""
return False
def _get_model_endpoint(self, model: str) -> str:
"""
Get the API endpoint for a given model.
"""
# Remove provider prefix if present (e.g., "black_forest_labs/flux-kontext-pro")
model_name = model.lower()
if "/" in model_name:
model_name = model_name.split("/")[-1]
# Check if model is in our mapping
if model_name in IMAGE_EDIT_MODELS:
return IMAGE_EDIT_MODELS[model_name]
raise ValueError(
f"Unknown BFL image edit model: {model_name}. "
f"Supported models: {list(IMAGE_EDIT_MODELS.keys())}"
)
def get_complete_url(
self,
model: str,
api_base: Optional[str],
litellm_params: dict,
) -> str:
"""
Get the complete URL for the Black Forest Labs API request.
"""
base_url: str = (
api_base
or get_secret_str("BFL_API_BASE")
or DEFAULT_API_BASE
)
base_url = base_url.rstrip("/")
endpoint = self._get_model_endpoint(model)
return f"{base_url}{endpoint}"
def _read_image_bytes(self, image: Any) -> bytes:
"""Read image bytes from various input types."""
if isinstance(image, bytes):
return image
elif isinstance(image, list):
# If it's a list, take the first image
return self._read_image_bytes(image[0])
elif isinstance(image, str):
if image.startswith(("http://", "https://")):
# Download image from URL
response = httpx.get(image, timeout=60.0)
response.raise_for_status()
return response.content
else:
# Assume it's a file path
with open(image, "rb") as f:
return f.read()
elif hasattr(image, "read"):
# File-like object
pos = getattr(image, "tell", lambda: 0)()
if hasattr(image, "seek"):
image.seek(0)
data = image.read()
if hasattr(image, "seek"):
image.seek(pos)
return data
else:
raise ValueError(
f"Unsupported image type: {type(image)}. "
"Expected bytes, str (URL or file path), or file-like object."
)
def transform_image_edit_request(
self,
model: str,
prompt: str,
image: FileTypes,
image_edit_optional_request_params: Dict,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Tuple[Dict, RequestFiles]:
"""
Transform OpenAI-style request to Black Forest Labs request format.
BFL uses JSON body with base64-encoded images, not multipart/form-data.
"""
# Read and encode image
image_bytes = self._read_image_bytes(image)
b64_image = base64.b64encode(image_bytes).decode("utf-8")
# Build request body
request_body: Dict[str, Any] = {
"prompt": prompt,
"input_image": b64_image,
}
# Add optional params (only BFL-recognized parameters)
bfl_request_params = [
"seed", "output_format", "safety_tolerance", "prompt_upsampling",
"aspect_ratio", "steps", "guidance", "grow_mask",
"top", "bottom", "left", "right",
]
for key, value in image_edit_optional_request_params.items():
if key in bfl_request_params and value is not None:
request_body[key] = value
# Handle mask if provided (for inpainting)
if "mask" in image_edit_optional_request_params:
mask = image_edit_optional_request_params["mask"]
mask_bytes = self._read_image_bytes(mask)
request_body["mask"] = base64.b64encode(mask_bytes).decode("utf-8")
# BFL uses JSON, not multipart - return empty files
return request_body, []
def transform_image_edit_response(
self,
model: str,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> ImageResponse:
"""
Transform Black Forest Labs response to OpenAI-compatible ImageResponse.
This is called with the FINAL polled response (after handler does polling).
The response contains: {"status": "Ready", "result": {"sample": "https://..."}}
"""
try:
response_data = raw_response.json()
except Exception as e:
raise BlackForestLabsError(
status_code=raw_response.status_code,
message=f"Error parsing BFL response: {e}",
)
# Get image URL from result
image_url = response_data.get("result", {}).get("sample")
if not image_url:
raise BlackForestLabsError(
status_code=500,
message="No image URL in BFL result",
)
# Build ImageResponse
return ImageResponse(
created=int(time.time()),
data=[ImageObject(url=image_url)],
)
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BlackForestLabsError:
"""Return the appropriate error class for Black Forest Labs."""
return BlackForestLabsError(
status_code=status_code,
message=error_message,
)
@@ -0,0 +1,12 @@
from .handler import BlackForestLabsImageGeneration, bfl_image_generation
from .transformation import (
BlackForestLabsImageGenerationConfig,
get_black_forest_labs_image_generation_config,
)
__all__ = [
"BlackForestLabsImageGenerationConfig",
"get_black_forest_labs_image_generation_config",
"BlackForestLabsImageGeneration",
"bfl_image_generation",
]
@@ -0,0 +1,440 @@
"""
Black Forest Labs Image Generation Handler
Handles image generation requests for Black Forest Labs models.
BFL uses an async polling pattern - the initial request returns a task ID,
then we poll until the result is ready.
"""
import asyncio
import time
from typing import Any, Dict, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import ImageResponse
from ..common_utils import (
DEFAULT_MAX_POLLING_TIME,
DEFAULT_POLLING_INTERVAL,
BlackForestLabsError,
)
from .transformation import BlackForestLabsImageGenerationConfig
class BlackForestLabsImageGeneration:
"""
Black Forest Labs Image Generation handler.
Handles the HTTP requests and polling logic, delegating data transformation
to the BlackForestLabsImageGenerationConfig class.
"""
def __init__(self):
self.config = BlackForestLabsImageGenerationConfig()
def image_generation(
self,
model: str,
prompt: str,
model_response: ImageResponse,
optional_params: Dict,
litellm_params: Union[GenericLiteLLMParams, Dict],
logging_obj: LiteLLMLoggingObj,
timeout: Optional[Union[float, httpx.Timeout]],
extra_headers: Optional[Dict[str, Any]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
aimg_generation: bool = False,
) -> Union[ImageResponse, Any]:
"""
Main entry point for image generation requests.
Args:
model: The model to use (e.g., "black_forest_labs/flux-pro-1.1")
prompt: The text prompt for image generation
model_response: ImageResponse object to populate
optional_params: Optional parameters for the request
litellm_params: LiteLLM parameters including api_key, api_base
logging_obj: Logging object
timeout: Request timeout
extra_headers: Additional headers
client: HTTP client to use
aimg_generation: If True, return async coroutine
Returns:
ImageResponse or coroutine if aimg_generation=True
"""
# Handle litellm_params as dict or object
if isinstance(litellm_params, dict):
api_key = litellm_params.get("api_key")
api_base = litellm_params.get("api_base")
litellm_params_dict = litellm_params
else:
api_key = litellm_params.api_key
api_base = litellm_params.api_base
litellm_params_dict = dict(litellm_params)
if aimg_generation:
return self.async_image_generation(
model=model,
prompt=prompt,
model_response=model_response,
optional_params=optional_params,
litellm_params=litellm_params,
logging_obj=logging_obj,
timeout=timeout,
extra_headers=extra_headers,
client=client if isinstance(client, AsyncHTTPHandler) else None,
)
# Sync version
if client is None or not isinstance(client, HTTPHandler):
sync_client = _get_httpx_client()
else:
sync_client = client
# Validate environment and get headers
headers = self.config.validate_environment(
api_key=api_key,
headers={},
model=model,
messages=[],
optional_params=optional_params,
litellm_params=litellm_params_dict,
)
if extra_headers:
headers.update(extra_headers)
# Get complete URL
complete_url = self.config.get_complete_url(
api_base=api_base,
api_key=api_key,
model=model,
optional_params=optional_params,
litellm_params=litellm_params_dict,
)
# Transform request
data = self.config.transform_image_generation_request(
model=model,
prompt=prompt,
optional_params=optional_params,
litellm_params=litellm_params_dict,
headers=headers,
)
# Logging
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": complete_url,
"headers": headers,
},
)
# Make initial request
try:
response = sync_client.post(
url=complete_url,
headers=headers,
json=data,
timeout=timeout,
)
except Exception as e:
raise BlackForestLabsError(
status_code=500,
message=f"Request failed: {str(e)}",
)
# Poll for result
final_response = self._poll_for_result_sync(
initial_response=response,
headers=headers,
sync_client=sync_client,
)
# Transform response
return self.config.transform_image_generation_response(
model=model,
raw_response=final_response,
model_response=model_response,
logging_obj=logging_obj,
)
async def async_image_generation(
self,
model: str,
prompt: str,
model_response: ImageResponse,
optional_params: Dict,
litellm_params: Union[GenericLiteLLMParams, Dict],
logging_obj: LiteLLMLoggingObj,
timeout: Optional[Union[float, httpx.Timeout]],
extra_headers: Optional[Dict[str, Any]] = None,
client: Optional[AsyncHTTPHandler] = None,
) -> ImageResponse:
"""
Async version of image generation.
"""
# Handle litellm_params as dict or object
if isinstance(litellm_params, dict):
api_key = litellm_params.get("api_key")
api_base = litellm_params.get("api_base")
litellm_params_dict = litellm_params
else:
api_key = litellm_params.api_key
api_base = litellm_params.api_base
litellm_params_dict = dict(litellm_params)
if client is None:
async_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.BLACK_FOREST_LABS,
)
else:
async_client = client
# Validate environment and get headers
headers = self.config.validate_environment(
api_key=api_key,
headers={},
model=model,
messages=[],
optional_params=optional_params,
litellm_params=litellm_params_dict,
)
if extra_headers:
headers.update(extra_headers)
# Get complete URL
complete_url = self.config.get_complete_url(
api_base=api_base,
api_key=api_key,
model=model,
optional_params=optional_params,
litellm_params=litellm_params_dict,
)
# Transform request
data = self.config.transform_image_generation_request(
model=model,
prompt=prompt,
optional_params=optional_params,
litellm_params=litellm_params_dict,
headers=headers,
)
# Logging
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": complete_url,
"headers": headers,
},
)
# Make initial request
try:
response = await async_client.post(
url=complete_url,
headers=headers,
json=data,
timeout=timeout,
)
except Exception as e:
raise BlackForestLabsError(
status_code=500,
message=f"Request failed: {str(e)}",
)
# Poll for result
final_response = await self._poll_for_result_async(
initial_response=response,
headers=headers,
async_client=async_client,
)
# Transform response
return self.config.transform_image_generation_response(
model=model,
raw_response=final_response,
model_response=model_response,
logging_obj=logging_obj,
)
def _poll_for_result_sync(
self,
initial_response: httpx.Response,
headers: dict,
sync_client: HTTPHandler,
max_wait: float = DEFAULT_MAX_POLLING_TIME,
interval: float = DEFAULT_POLLING_INTERVAL,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> httpx.Response:
"""
Poll BFL API until result is ready (sync version).
"""
# Validate initial response status code
if initial_response.status_code >= 400:
raise BlackForestLabsError(
status_code=initial_response.status_code,
message=f"BFL initial request failed: {initial_response.text}",
)
# Parse initial response to get polling URL
try:
response_data = initial_response.json()
except Exception as e:
raise BlackForestLabsError(
status_code=initial_response.status_code,
message=f"Error parsing initial response: {e}",
)
# Check for immediate errors
if "errors" in response_data:
raise BlackForestLabsError(
status_code=initial_response.status_code,
message=f"BFL error: {response_data['errors']}",
)
polling_url = response_data.get("polling_url")
if not polling_url:
raise BlackForestLabsError(
status_code=500,
message="No polling_url in BFL response",
)
# Get just the auth header for polling
polling_headers = {"x-key": headers.get("x-key", "")}
start_time = time.time()
verbose_logger.debug(f"BFL starting sync polling at {polling_url}")
while time.time() - start_time < max_wait:
response = sync_client.get(
url=polling_url,
headers=polling_headers,
)
if response.status_code != 200:
raise BlackForestLabsError(
status_code=response.status_code,
message=f"Polling failed: {response.text}",
)
data = response.json()
status = data.get("status")
verbose_logger.debug(f"BFL poll status: {status}")
if status == "Ready":
return response
elif status in ["Error", "Failed", "Content Moderated", "Request Moderated"]:
raise BlackForestLabsError(
status_code=400,
message=f"Image generation failed: {status}",
)
time.sleep(interval)
raise BlackForestLabsError(
status_code=408,
message=f"Polling timed out after {max_wait} seconds",
)
async def _poll_for_result_async(
self,
initial_response: httpx.Response,
headers: dict,
async_client: AsyncHTTPHandler,
max_wait: float = DEFAULT_MAX_POLLING_TIME,
interval: float = DEFAULT_POLLING_INTERVAL,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> httpx.Response:
"""
Poll BFL API until result is ready (async version).
"""
# Validate initial response status code
if initial_response.status_code >= 400:
raise BlackForestLabsError(
status_code=initial_response.status_code,
message=f"BFL initial request failed: {initial_response.text}",
)
# Parse initial response to get polling URL
try:
response_data = initial_response.json()
except Exception as e:
raise BlackForestLabsError(
status_code=initial_response.status_code,
message=f"Error parsing initial response: {e}",
)
# Check for immediate errors
if "errors" in response_data:
raise BlackForestLabsError(
status_code=initial_response.status_code,
message=f"BFL error: {response_data['errors']}",
)
polling_url = response_data.get("polling_url")
if not polling_url:
raise BlackForestLabsError(
status_code=500,
message="No polling_url in BFL response",
)
# Get just the auth header for polling
polling_headers = {"x-key": headers.get("x-key", "")}
start_time = time.time()
verbose_logger.debug(f"BFL starting async polling at {polling_url}")
while time.time() - start_time < max_wait:
response = await async_client.get(
url=polling_url,
headers=polling_headers,
)
if response.status_code != 200:
raise BlackForestLabsError(
status_code=response.status_code,
message=f"Polling failed: {response.text}",
)
data = response.json()
status = data.get("status")
verbose_logger.debug(f"BFL poll status: {status}")
if status == "Ready":
return response
elif status in ["Error", "Failed", "Content Moderated", "Request Moderated"]:
raise BlackForestLabsError(
status_code=400,
message=f"Image generation failed: {status}",
)
await asyncio.sleep(interval)
raise BlackForestLabsError(
status_code=408,
message=f"Polling timed out after {max_wait} seconds",
)
# Singleton instance for use in images/main.py
bfl_image_generation = BlackForestLabsImageGeneration()
@@ -0,0 +1,324 @@
"""
Black Forest Labs Image Generation Configuration
Handles transformation between OpenAI-compatible format and Black Forest Labs API format
for image generation endpoints (flux-pro-1.1, flux-pro-1.1-ultra, flux-dev, flux-pro).
API Reference: https://docs.bfl.ai/
"""
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import httpx
from litellm.llms.base_llm.image_generation.transformation import (
BaseImageGenerationConfig,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
AllMessageValues,
OpenAIImageGenerationOptionalParams,
)
from litellm.types.utils import ImageObject, ImageResponse
from ..common_utils import (
DEFAULT_API_BASE,
IMAGE_GENERATION_MODELS,
BlackForestLabsError,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BlackForestLabsImageGenerationConfig(BaseImageGenerationConfig):
"""
Configuration for Black Forest Labs image generation (text-to-image).
Supports:
- flux-pro-1.1: Fast & reliable standard generation
- flux-pro-1.1-ultra: Ultra high-resolution (up to 4MP)
- flux-dev: Development/open-source variant
- flux-pro: Original pro model
Note: HTTP requests and polling are handled by the handler (handler.py).
This class only handles data transformation.
"""
def get_supported_openai_params(
self, model: str
) -> List[OpenAIImageGenerationOptionalParams]:
"""
Return list of OpenAI params supported by Black Forest Labs.
Note: BFL uses different parameter names, these are mapped in map_openai_params.
"""
return [
"n", # Number of images (BFL returns 1 per request, but ultra supports up to 4)
"size", # Maps to width/height or aspect_ratio
"quality", # Maps to raw mode for ultra
"seed",
"output_format",
"safety_tolerance",
"prompt_upsampling",
"raw",
"num_images",
"image_url",
"image_prompt_strength",
"aspect_ratio",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
"""
Map OpenAI parameters to Black Forest Labs parameters.
BFL-specific params are passed through directly.
"""
supported_params = self.get_supported_openai_params(model)
for k, v in non_default_params.items():
if k in optional_params:
continue
if k in supported_params:
# Map OpenAI 'size' to BFL width/height
if k == "size" and v:
self._map_size_param(v, optional_params)
elif k == "n":
if "ultra" in model.lower():
optional_params["num_images"] = v
# non-ultra: silently skip (n=1 is BFL default)
elif k == "quality":
if v == "hd" and "ultra" in model.lower():
optional_params["raw"] = True
# other quality values have no BFL mapping
else:
optional_params[k] = v
elif not drop_params:
raise ValueError(
f"Parameter {k} is not supported for model {model}. "
f"Supported parameters are {supported_params}. "
f"Set drop_params=True to drop unsupported parameters."
)
return optional_params
def _map_size_param(self, size: str, optional_params: dict) -> None:
"""Map OpenAI size parameter to BFL width/height."""
# Common size mappings
size_mapping = {
"1024x1024": (1024, 1024),
"1792x1024": (1792, 1024),
"1024x1792": (1024, 1792),
"512x512": (512, 512),
"256x256": (256, 256),
}
if size in size_mapping:
width, height = size_mapping[size]
optional_params["width"] = width
optional_params["height"] = height
elif "x" in size:
# Parse custom size
try:
width, height = map(int, size.lower().split("x"))
optional_params["width"] = width
optional_params["height"] = height
except ValueError:
raise ValueError(
f"Invalid size format: '{size}'. Expected format 'WIDTHxHEIGHT' (e.g., '1024x1024')."
)
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
"""
Validate environment and set up headers for Black Forest Labs.
BFL uses x-key header for authentication.
"""
final_api_key: Optional[str] = (
api_key
or get_secret_str("BFL_API_KEY")
or get_secret_str("BLACK_FOREST_LABS_API_KEY")
)
if not final_api_key:
raise BlackForestLabsError(
status_code=401,
message="BFL_API_KEY is not set. Please set it via environment variable or pass api_key parameter.",
)
headers["x-key"] = final_api_key
headers["Content-Type"] = "application/json"
headers["Accept"] = "application/json"
return headers
def _get_model_endpoint(self, model: str) -> str:
"""
Get the API endpoint for a given model.
"""
# Remove provider prefix if present (e.g., "black_forest_labs/flux-pro-1.1")
model_name = model.lower()
if "/" in model_name:
model_name = model_name.split("/")[-1]
# Check if model is in our mapping
if model_name in IMAGE_GENERATION_MODELS:
return IMAGE_GENERATION_MODELS[model_name]
raise ValueError(
f"Unknown BFL image generation model: {model_name}. "
f"Supported models: {list(IMAGE_GENERATION_MODELS.keys())}"
)
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
Get the complete URL for the Black Forest Labs API request.
"""
base_url: str = (
api_base or get_secret_str("BFL_API_BASE") or DEFAULT_API_BASE
)
base_url = base_url.rstrip("/")
endpoint = self._get_model_endpoint(model)
return f"{base_url}{endpoint}"
def transform_image_generation_request(
self,
model: str,
prompt: str,
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
"""
Transform OpenAI-style request to Black Forest Labs request format.
https://docs.bfl.ai/flux_models/flux_1_1_pro
"""
# Build request body with prompt
request_body: Dict[str, Any] = {
"prompt": prompt,
}
# BFL-specific params that can be passed through
bfl_params = [
"width",
"height",
"aspect_ratio",
"seed",
"output_format",
"safety_tolerance",
"prompt_upsampling",
# Ultra-specific
"raw",
"num_images",
"image_url",
"image_prompt_strength",
]
for param in bfl_params:
if param in optional_params and optional_params[param] is not None:
request_body[param] = optional_params[param]
# Set default output format if not specified
if "output_format" not in request_body:
request_body["output_format"] = "png"
return request_body
def transform_image_generation_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ImageResponse,
logging_obj: LiteLLMLoggingObj,
**kwargs,
) -> ImageResponse:
"""
Transform Black Forest Labs response to OpenAI-compatible ImageResponse.
This is called with the FINAL polled response (after handler does polling).
The response contains: {"status": "Ready", "result": {"sample": "https://..."}}
"""
try:
response_data = raw_response.json()
except Exception as e:
raise BlackForestLabsError(
status_code=raw_response.status_code,
message=f"Error parsing BFL response: {e}",
)
result = response_data.get("result", {})
if not model_response.data:
model_response.data = []
# Handle single image (sample) or multiple images
if isinstance(result, dict) and "sample" in result:
model_response.data.append(ImageObject(url=result["sample"]))
elif isinstance(result, list):
# Multiple images returned
for img in result:
if isinstance(img, str):
model_response.data.append(ImageObject(url=img))
elif isinstance(img, dict) and "url" in img:
model_response.data.append(ImageObject(url=img["url"]))
if not model_response.data:
raise BlackForestLabsError(
status_code=500,
message="No image URL in BFL result",
)
model_response.created = int(time.time())
return model_response
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BlackForestLabsError:
"""Return the appropriate error class for Black Forest Labs."""
return BlackForestLabsError(
status_code=status_code,
message=error_message,
)
def get_black_forest_labs_image_generation_config(
model: str,
) -> BlackForestLabsImageGenerationConfig:
"""
Get the appropriate image generation config for a Black Forest Labs model.
Currently returns a single config class, but can be extended
for model-specific configurations if needed.
"""
return BlackForestLabsImageGenerationConfig()
@@ -429,11 +429,8 @@ class FireworksAIConfig(OpenAIGPTConfig):
"FIREWORKS_ACCOUNT_ID is not set. Please set the environment variable, to query Fireworks AI's `/models` endpoint."
)
base = api_base.rstrip("/")
if base.endswith("/v1"):
base = base[: -len("/v1")]
response = litellm.module_level_client.get(
url=f"{base}/v1/accounts/{account_id}/models",
url=f"{api_base}/v1/accounts/{account_id}/models",
headers={"Authorization": f"Bearer {api_key}"},
)
@@ -0,0 +1,152 @@
"""
Support for Mistral Voxtral audio transcription via ``/v1/audio/transcriptions``.
API reference: https://docs.mistral.ai/api/#tag/audio/operation/audio_transcriptions_v1_audio_transcriptions_post
"""
from typing import List, Optional, Union
import httpx
from litellm.litellm_core_utils.audio_utils.utils import process_audio_file
from litellm.llms.base_llm.audio_transcription.transformation import (
AudioTranscriptionRequestData,
BaseAudioTranscriptionConfig,
)
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
AllMessageValues,
OpenAIAudioTranscriptionOptionalParams,
)
from litellm.types.utils import FileTypes, TranscriptionResponse
class MistralAudioTranscriptionException(BaseLLMException):
pass
class MistralAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
def get_supported_openai_params(
self, model: str
) -> List[OpenAIAudioTranscriptionOptionalParams]:
return [
"language",
"temperature",
"timestamp_granularities",
"response_format",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
supported_params = self.get_supported_openai_params(model)
for k, v in non_default_params.items():
if k in supported_params:
optional_params[k] = v
return optional_params
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
api_base = (
"https://api.mistral.ai/v1"
if api_base is None
else api_base.rstrip("/")
)
return f"{api_base}/audio/transcriptions"
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return MistralAudioTranscriptionException(
message=error_message,
status_code=status_code,
headers=headers,
)
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
if api_key is None:
api_key = get_secret_str("MISTRAL_API_KEY")
default_headers = {
"Authorization": f"Bearer {api_key}",
"accept": "application/json",
}
default_headers.update(headers or {})
return default_headers
def transform_audio_transcription_request(
self,
model: str,
audio_file: FileTypes,
optional_params: dict,
litellm_params: dict,
) -> AudioTranscriptionRequestData:
processed_audio = process_audio_file(audio_file)
form_fields: dict = {
"model": model,
}
# OpenAI-compatible params
for key in self.get_supported_openai_params(model):
value = optional_params.get(key)
if value is not None:
form_fields[key] = value
# Mistral-specific params (e.g. diarize)
provider_specific_params = self.get_provider_specific_params(
model=model,
optional_params=optional_params,
openai_params=self.get_supported_openai_params(model),
)
for key, value in provider_specific_params.items():
form_fields[key] = str(value).lower() if isinstance(value, bool) else str(value)
files = {
"file": (
processed_audio.filename,
processed_audio.file_content,
processed_audio.content_type,
)
}
return AudioTranscriptionRequestData(data=form_fields, files=files)
def transform_audio_transcription_response(
self,
raw_response: httpx.Response,
) -> TranscriptionResponse:
try:
response_json = raw_response.json()
except Exception:
raise MistralAudioTranscriptionException(
message=raw_response.text,
status_code=raw_response.status_code,
headers=raw_response.headers,
)
text = response_json.get("text") or ""
response = TranscriptionResponse(text=text)
response._hidden_params = response_json
return response
@@ -25,22 +25,6 @@ def _normalize_reasoning_effort_for_chat_completion(
return None
def _get_effort_level(value: Union[str, dict, None]) -> Optional[str]:
"""Extract the effective effort level from reasoning_effort (string or dict).
Use this for guards that compare effort level (e.g. xhigh validation, "none" checks).
Ensures dict inputs like {"effort": "none", "summary": "detailed"} are correctly
treated as effort="none" for validation purposes.
"""
if value is None:
return None
if isinstance(value, str):
return value
if isinstance(value, dict) and "effort" in value:
return value["effort"]
return None
class OpenAIGPT5Config(OpenAIGPTConfig):
"""Configuration for gpt-5 models including GPT-5-Codex variants.
@@ -86,19 +70,6 @@ class OpenAIGPT5Config(OpenAIGPTConfig):
model_name = model.split("/")[-1]
return model_name.startswith("gpt-5.4")
@classmethod
def is_model_gpt_5_4_plus_model(cls, model: str) -> bool:
"""Check if the model is gpt-5.4 or newer (5.4, 5.5, 5.6, etc., including pro)."""
model_name = model.split("/")[-1]
if not model_name.startswith("gpt-5."):
return False
try:
version_str = model_name.replace("gpt-5.", "").split("-")[0]
major = version_str.split(".")[0]
return int(major) >= 4
except (ValueError, IndexError):
return False
@classmethod
def _supports_reasoning_effort_level(cls, model: str, level: str) -> bool:
"""Check if the model supports a specific reasoning_effort level.
@@ -179,35 +150,21 @@ class OpenAIGPT5Config(OpenAIGPTConfig):
drop_params=drop_params,
)
# Get raw reasoning_effort and effective effort level for all guards.
# Use effective_effort (extracted string) for xhigh validation, "none" checks, and
# tool/sampling guards — dict inputs like {"effort": "none", "summary": "detailed"}
# must be treated as effort="none" to avoid incorrect tool-drop or sampling errors.
raw_reasoning_effort = non_default_params.get(
"reasoning_effort"
) or optional_params.get("reasoning_effort")
effective_effort = _get_effort_level(raw_reasoning_effort)
# Normalize to string for Chat Completions API when dict has only "effort".
# Preserve full dict (e.g. {"effort": "high", "summary": "detailed"}) for Responses API.
if isinstance(raw_reasoning_effort, dict) and set(
raw_reasoning_effort.keys()
) <= {"effort"}:
normalized = _normalize_reasoning_effort_for_chat_completion(
raw_reasoning_effort
)
if normalized is not None:
if "reasoning_effort" in non_default_params:
non_default_params["reasoning_effort"] = normalized
if "reasoning_effort" in optional_params:
optional_params["reasoning_effort"] = normalized
reasoning_effort = (
# Normalize reasoning_effort: chat completion API expects a string, not a dict
# (e.g. {'effort': 'high', 'summary': 'detailed'} -> 'high')
raw_reasoning_effort = (
non_default_params.get("reasoning_effort")
or optional_params.get("reasoning_effort")
or raw_reasoning_effort
)
if effective_effort is not None and effective_effort == "xhigh":
normalized = _normalize_reasoning_effort_for_chat_completion(raw_reasoning_effort)
if raw_reasoning_effort is not None and normalized is not None:
if "reasoning_effort" in non_default_params:
non_default_params["reasoning_effort"] = normalized
if "reasoning_effort" in optional_params:
optional_params["reasoning_effort"] = normalized
reasoning_effort = normalized or raw_reasoning_effort
if reasoning_effort is not None and reasoning_effort == "xhigh":
if not self._supports_reasoning_effort_level(model, "xhigh"):
if litellm.drop_params or drop_params:
non_default_params.pop("reasoning_effort", None)
@@ -234,20 +191,17 @@ class OpenAIGPT5Config(OpenAIGPTConfig):
has_tools = bool(
non_default_params.get("tools") or optional_params.get("tools")
)
if has_tools and effective_effort not in (None, "none"):
# Check if this will be routed to Responses API
# If so, keep reasoning_effort; otherwise drop it for chat completions API
if not self.is_model_gpt_5_4_plus_model(model):
non_default_params.pop("reasoning_effort", None)
optional_params.pop("reasoning_effort", None)
reasoning_effort = None
if has_tools and reasoning_effort not in (None, "none"):
non_default_params.pop("reasoning_effort", None)
optional_params.pop("reasoning_effort", None)
reasoning_effort = None
# gpt-5.1/5.2 support logprobs, top_p, top_logprobs only when reasoning_effort="none"
supports_none = self._supports_reasoning_effort_level(model, "none")
if supports_none:
sampling_params = ["logprobs", "top_logprobs", "top_p"]
has_sampling = any(p in non_default_params for p in sampling_params)
if has_sampling and effective_effort not in (None, "none"):
if has_sampling and reasoning_effort not in (None, "none"):
if litellm.drop_params or drop_params:
for p in sampling_params:
non_default_params.pop(p, None)
@@ -257,7 +211,7 @@ class OpenAIGPT5Config(OpenAIGPTConfig):
"gpt-5.1/5.2/5.4 only support logprobs, top_p, top_logprobs when "
"reasoning_effort='none'. Current reasoning_effort='{}'. "
"To drop unsupported params set `litellm.drop_params = True`"
).format(effective_effort),
).format(reasoning_effort),
status_code=400,
)
@@ -265,9 +219,7 @@ class OpenAIGPT5Config(OpenAIGPTConfig):
temperature_value: Optional[float] = non_default_params.pop("temperature")
if temperature_value is not None:
# models supporting reasoning_effort="none" also support flexible temperature
if supports_none and (
effective_effort == "none" or effective_effort is None
):
if supports_none and (reasoning_effort == "none" or reasoning_effort is None):
optional_params["temperature"] = temperature_value
elif temperature_value == 1:
optional_params["temperature"] = temperature_value
@@ -58,6 +58,7 @@ from ..common_utils import OpenAIError
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from litellm.llms.base_llm.base_utils import BaseTokenCounter
from litellm.types.llms.openai import ChatCompletionToolParam
LiteLLMLoggingObj = _LiteLLMLoggingObj
@@ -759,6 +760,13 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
def get_base_model(model: Optional[str] = None) -> Optional[str]:
return model
def get_token_counter(self) -> Optional["BaseTokenCounter"]:
from litellm.llms.openai.responses.count_tokens.token_counter import (
OpenAITokenCounter,
)
return OpenAITokenCounter()
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
@@ -40,6 +40,7 @@ class OpenAIImageEditConfig(BaseImageEditConfig):
"image",
"prompt",
"background",
"input_fidelity",
"mask",
"model",
"n",
@@ -0,0 +1,19 @@
"""
OpenAI Responses API token counting implementation.
"""
from litellm.llms.openai.responses.count_tokens.handler import (
OpenAICountTokensHandler,
)
from litellm.llms.openai.responses.count_tokens.token_counter import (
OpenAITokenCounter,
)
from litellm.llms.openai.responses.count_tokens.transformation import (
OpenAICountTokensConfig,
)
__all__ = [
"OpenAICountTokensHandler",
"OpenAICountTokensConfig",
"OpenAITokenCounter",
]
@@ -0,0 +1,105 @@
"""
OpenAI Responses API token counting handler.
Uses httpx for HTTP requests to OpenAI's /v1/responses/input_tokens endpoint.
"""
import json
from typing import Any, Dict, List, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.llms.openai.common_utils import OpenAIError
from litellm.llms.openai.responses.count_tokens.transformation import (
OpenAICountTokensConfig,
)
class OpenAICountTokensHandler(OpenAICountTokensConfig):
"""
Handler for OpenAI Responses API token counting requests.
"""
async def handle_count_tokens_request(
self,
model: str,
input: Union[str, List[Any]],
api_key: str,
api_base: Optional[str] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
instructions: Optional[str] = None,
) -> Dict[str, Any]:
"""
Handle a token counting request to OpenAI's Responses API.
Returns:
Dictionary containing {"input_tokens": <number>}
Raises:
OpenAIError: If the API request fails
"""
try:
self.validate_request(model, input)
verbose_logger.debug(
f"Processing OpenAI CountTokens request for model: {model}"
)
request_body = self.transform_request_to_count_tokens(
model=model,
input=input,
tools=tools,
instructions=instructions,
)
endpoint_url = self.get_openai_count_tokens_endpoint(api_base)
verbose_logger.debug(f"Making request to: {endpoint_url}")
headers = self.get_required_headers(api_key)
async_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.OPENAI
)
request_timeout = timeout if timeout is not None else litellm.request_timeout
response = await async_client.post(
endpoint_url,
headers=headers,
json=request_body,
timeout=request_timeout,
)
verbose_logger.debug(f"Response status: {response.status_code}")
if response.status_code != 200:
error_text = response.text
verbose_logger.error(f"OpenAI API error: {error_text}")
raise OpenAIError(
status_code=response.status_code,
message=error_text,
)
openai_response = response.json()
verbose_logger.debug(f"OpenAI response: {openai_response}")
return openai_response
except OpenAIError:
raise
except httpx.HTTPStatusError as e:
verbose_logger.error(f"HTTP error in CountTokens handler: {str(e)}")
raise OpenAIError(
status_code=e.response.status_code,
message=e.response.text,
)
except (httpx.RequestError, json.JSONDecodeError, ValueError) as e:
verbose_logger.error(f"Error in CountTokens handler: {str(e)}")
raise OpenAIError(
status_code=500,
message=f"CountTokens processing error: {str(e)}",
)
@@ -0,0 +1,118 @@
"""
OpenAI Token Counter implementation using the Responses API /input_tokens endpoint.
"""
import os
from typing import Any, Dict, List, Optional
from litellm._logging import verbose_logger
from litellm.llms.base_llm.base_utils import BaseTokenCounter
from litellm.llms.openai.common_utils import OpenAIError
from litellm.llms.openai.responses.count_tokens.handler import (
OpenAICountTokensHandler,
)
from litellm.llms.openai.responses.count_tokens.transformation import (
OpenAICountTokensConfig,
)
from litellm.types.utils import LlmProviders, TokenCountResponse
# Global handler instance - reuse across all token counting requests
openai_count_tokens_handler = OpenAICountTokensHandler()
class OpenAITokenCounter(BaseTokenCounter):
"""Token counter implementation for OpenAI provider using the Responses API."""
def should_use_token_counting_api(
self,
custom_llm_provider: Optional[str] = None,
) -> bool:
return custom_llm_provider == LlmProviders.OPENAI.value
async def count_tokens(
self,
model_to_use: str,
messages: Optional[List[Dict[str, Any]]],
contents: Optional[List[Dict[str, Any]]],
deployment: Optional[Dict[str, Any]] = None,
request_model: str = "",
tools: Optional[List[Dict[str, Any]]] = None,
system: Optional[Any] = None,
) -> Optional[TokenCountResponse]:
"""
Count tokens using OpenAI's Responses API /input_tokens endpoint.
"""
if not messages:
return None
deployment = deployment or {}
litellm_params = deployment.get("litellm_params", {})
# Get OpenAI API key from deployment config or environment
api_key = litellm_params.get("api_key")
if not api_key:
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
verbose_logger.warning("No OpenAI API key found for token counting")
return None
api_base = litellm_params.get("api_base")
# Convert chat messages to Responses API input format
input_items, instructions = OpenAICountTokensConfig.messages_to_responses_input(
messages
)
# Use system param if instructions not extracted from messages
if instructions is None and system is not None:
instructions = system if isinstance(system, str) else str(system)
# If no input items were produced (e.g., system-only messages), fall back to local counting
if not input_items:
return None
try:
result = await openai_count_tokens_handler.handle_count_tokens_request(
model=model_to_use,
input=input_items if input_items is not None else [],
api_key=api_key,
api_base=api_base,
tools=tools,
instructions=instructions,
)
if result is not None:
return TokenCountResponse(
total_tokens=result.get("input_tokens", 0),
request_model=request_model,
model_used=model_to_use,
tokenizer_type="openai_api",
original_response=result,
)
except OpenAIError as e:
verbose_logger.warning(
f"OpenAI CountTokens API error: status={e.status_code}, message={e.message}"
)
return TokenCountResponse(
total_tokens=0,
request_model=request_model,
model_used=model_to_use,
tokenizer_type="openai_api",
error=True,
error_message=e.message,
status_code=e.status_code,
)
except Exception as e:
verbose_logger.warning(f"Error calling OpenAI CountTokens API: {e}")
return TokenCountResponse(
total_tokens=0,
request_model=request_model,
model_used=model_to_use,
tokenizer_type="openai_api",
error=True,
error_message=str(e),
status_code=500,
)
return None
@@ -0,0 +1,158 @@
"""
OpenAI Responses API token counting transformation logic.
This module handles the transformation of requests to OpenAI's /v1/responses/input_tokens endpoint.
"""
from typing import Any, Dict, List, Optional, Union
class OpenAICountTokensConfig:
"""
Configuration and transformation logic for OpenAI Responses API token counting.
OpenAI Responses API Token Counting Specification:
- Endpoint: POST https://api.openai.com/v1/responses/input_tokens
- Response: {"input_tokens": <number>}
"""
def get_openai_count_tokens_endpoint(self, api_base: Optional[str] = None) -> str:
base = api_base or "https://api.openai.com/v1"
base = base.rstrip("/")
return f"{base}/responses/input_tokens"
def transform_request_to_count_tokens(
self,
model: str,
input: Union[str, List[Any]],
tools: Optional[List[Dict[str, Any]]] = None,
instructions: Optional[str] = None,
) -> Dict[str, Any]:
"""
Transform request to OpenAI Responses API token counting format.
The Responses API uses `input` (not `messages`) and `instructions` (not `system`).
"""
request: Dict[str, Any] = {
"model": model,
"input": input,
}
if instructions is not None:
request["instructions"] = instructions
if tools is not None:
request["tools"] = self._transform_tools_for_responses_api(tools)
return request
def get_required_headers(self, api_key: str) -> Dict[str, str]:
return {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}
def validate_request(
self, model: str, input: Union[str, List[Any]]
) -> None:
if not model:
raise ValueError("model parameter is required")
if not input:
raise ValueError("input parameter is required")
@staticmethod
def _transform_tools_for_responses_api(
tools: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""
Transform OpenAI chat tools format to Responses API tools format.
Chat format: {"type": "function", "function": {"name": "...", "parameters": {...}}}
Responses format: {"type": "function", "name": "...", "parameters": {...}}
"""
transformed = []
for tool in tools:
if tool.get("type") == "function" and "function" in tool:
func = tool["function"]
item: Dict[str, Any] = {
"type": "function",
"name": func.get("name", ""),
"description": func.get("description", ""),
"parameters": func.get("parameters", {}),
}
if "strict" in func:
item["strict"] = func["strict"]
transformed.append(item)
else:
# Pass through non-function tools (e.g., web_search, file_search)
transformed.append(tool)
return transformed
@staticmethod
def messages_to_responses_input(
messages: List[Dict[str, Any]],
) -> tuple:
"""
Convert standard chat messages format to OpenAI Responses API input format.
Returns:
(input_items, instructions) tuple where instructions is extracted
from system/developer messages.
"""
input_items: List[Dict[str, Any]] = []
instructions_parts: List[str] = []
for msg in messages:
role = msg.get("role", "")
content = msg.get("content") or ""
if role in ("system", "developer"):
# Extract system/developer messages as instructions
if isinstance(content, str):
instructions_parts.append(content)
elif isinstance(content, list):
# Handle content blocks - extract text
text_parts = []
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
text_parts.append(block.get("text", ""))
elif isinstance(block, str):
text_parts.append(block)
instructions_parts.append("\n".join(text_parts))
elif role == "user":
if isinstance(content, list):
# Extract text from content blocks for Responses API
text_parts = []
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
text_parts.append(block.get("text", ""))
elif isinstance(block, str):
text_parts.append(block)
content = "\n".join(text_parts)
input_items.append({"role": "user", "content": content})
elif role == "assistant":
# Map tool_calls to Responses API function_call items
tool_calls = msg.get("tool_calls")
if content:
input_items.append({"role": "assistant", "content": content})
if tool_calls:
for tc in tool_calls:
func = tc.get("function", {})
input_items.append({
"type": "function_call",
"call_id": tc.get("id", ""),
"name": func.get("name", ""),
"arguments": func.get("arguments", ""),
})
elif not content:
input_items.append({"role": "assistant", "content": content})
elif role == "tool":
input_items.append({
"type": "function_call_output",
"call_id": msg.get("tool_call_id", ""),
"output": content if isinstance(content, str) else str(content),
})
instructions = "\n".join(instructions_parts) if instructions_parts else None
return input_items, instructions
+35 -3
View File
@@ -10,8 +10,9 @@ Instead of creating a full Python module for simple OpenAI-compatible providers,
- `providers.json` - Configuration file for all JSON-based providers
- `json_loader.py` - Loads and parses the JSON configuration
- `dynamic_config.py` - Generates Python config classes from JSON
- `chat/` - Existing OpenAI-like chat completion handlers
- `dynamic_config.py` - Generates Python config classes from JSON (chat + responses)
- `chat/` - OpenAI-like chat completion handlers
- `responses/` - OpenAI-like Responses API handlers
## Adding a New Provider
@@ -96,6 +97,32 @@ response = litellm.completion(
)
```
## Responses API Support
Providers that support the OpenAI Responses API (`/v1/responses`) can declare it via `supported_endpoints`:
```json
{
"your_provider": {
"base_url": "https://api.yourprovider.com/v1",
"api_key_env": "YOUR_PROVIDER_API_KEY",
"supported_endpoints": ["/v1/chat/completions", "/v1/responses"]
}
}
```
This enables `litellm.responses(model="your_provider/model-name", ...)` with zero Python code.
The provider inherits all request/response handling from OpenAI's Responses API config.
If `supported_endpoints` is omitted, it defaults to `[]` (only chat completions, which is always enabled for JSON providers).
### How It Works
1. `json_loader.py` checks `supported_endpoints` for `/v1/responses`
2. `dynamic_config.py` generates a responses config class (inherits from `OpenAIResponsesAPIConfig`)
3. `ProviderConfigManager.get_provider_responses_api_config()` returns the generated config
4. Request/response transformation is inherited from OpenAI — no custom code needed
## Benefits
- **Simple**: 2-5 lines of JSON vs 100+ lines of Python
@@ -112,6 +139,10 @@ Use a Python config class if you need:
- Provider-specific streaming logic
- Advanced tool calling transformations
For providers that are *mostly* OpenAI-compatible but need small overrides (e.g. preset model handling),
you can inherit from `OpenAIResponsesAPIConfig` and override only what's needed — see
`litellm/llms/perplexity/responses/transformation.py` for a minimal example (~40 lines).
## Implementation Details
### How It Works
@@ -125,5 +156,6 @@ Use a Python config class if you need:
The JSON system is integrated at:
- `litellm/litellm_core_utils/get_llm_provider_logic.py` - Provider resolution
- `litellm/utils.py` - ProviderConfigManager
- `litellm/utils.py` - ProviderConfigManager (chat + responses)
- `litellm/responses/main.py` - Responses API routing
- `litellm/constants.py` - openai_compatible_providers list
@@ -172,3 +172,63 @@ def create_config_class(provider: SimpleProviderConfig):
return provider.slug
return JSONProviderConfig
_responses_config_cache: dict = {}
def create_responses_config_class(provider: SimpleProviderConfig):
"""Generate a Responses API config class dynamically from JSON configuration.
Parallel to create_config_class() but for /v1/responses endpoints.
Classes are cached per provider slug to avoid regeneration on every request.
"""
if provider.slug in _responses_config_cache:
return _responses_config_cache[provider.slug]
from litellm.llms.openai_like.responses.transformation import (
OpenAILikeResponsesConfig,
)
from litellm.types.router import GenericLiteLLMParams
class JSONProviderResponsesConfig(OpenAILikeResponsesConfig):
@property
def custom_llm_provider(self): # type: ignore[override]
return provider.slug
def validate_environment(
self,
headers: dict,
model: str,
litellm_params: Optional[GenericLiteLLMParams],
) -> dict:
litellm_params = litellm_params or GenericLiteLLMParams()
api_key = (
litellm_params.api_key
or get_secret_str(provider.api_key_env)
)
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
def get_complete_url(
self,
api_base: Optional[str],
litellm_params: dict,
) -> str:
if not api_base:
if provider.api_base_env:
api_base = get_secret_str(provider.api_base_env)
if not api_base:
api_base = provider.base_url
if api_base is None:
raise ValueError(
f"api_base is required for provider {provider.slug}"
)
api_base = api_base.rstrip("/")
return f"{api_base}/responses"
_responses_config_cache[provider.slug] = JSONProviderResponsesConfig
return JSONProviderResponsesConfig
+9
View File
@@ -21,6 +21,7 @@ class SimpleProviderConfig:
self.param_mappings = data.get("param_mappings", {})
self.constraints = data.get("constraints", {})
self.special_handling = data.get("special_handling", {})
self.supported_endpoints = data.get("supported_endpoints", [])
class JSONProviderRegistry:
@@ -66,6 +67,14 @@ class JSONProviderRegistry:
"""Check if a provider is defined via JSON"""
return slug in cls._providers
@classmethod
def supports_responses_api(cls, slug: str) -> bool:
"""Check if a JSON provider supports the Responses API"""
provider = cls._providers.get(slug)
if provider is None:
return False
return "/v1/responses" in provider.supported_endpoints
@classmethod
def list_providers(cls) -> list:
"""List all registered provider slugs"""
@@ -0,0 +1,5 @@
from litellm.llms.openai_like.responses.transformation import (
OpenAILikeResponsesConfig,
)
__all__ = ["OpenAILikeResponsesConfig"]
@@ -0,0 +1,51 @@
"""
OpenAI-like Responses API transformation.
Base class for JSON-declared providers that support the /v1/responses endpoint.
Inherits everything from OpenAIResponsesAPIConfig; subclasses only override
provider-specific resolution (slug, API key env var, base URL).
"""
from typing import Optional, Union
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LlmProviders
class OpenAILikeResponsesConfig(OpenAIResponsesAPIConfig):
"""
Responses API config for OpenAI-compatible providers declared via JSON.
Concrete per-provider classes are generated dynamically in dynamic_config.py.
This base provides the three overridable hooks that the dynamic generator
fills in: custom_llm_provider, validate_environment, get_complete_url.
"""
@property
def custom_llm_provider(self) -> Union[str, LlmProviders]: # type: ignore[override]
return "openai_like"
def validate_environment(
self,
headers: dict,
model: str,
litellm_params: Optional[GenericLiteLLMParams],
) -> dict:
litellm_params = litellm_params or GenericLiteLLMParams()
api_key = litellm_params.api_key or get_secret_str("OPENAI_LIKE_API_KEY")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
def get_complete_url(
self,
api_base: Optional[str],
litellm_params: dict,
) -> str:
api_base = api_base or get_secret_str("OPENAI_LIKE_API_BASE")
if not api_base:
raise ValueError("api_base is required for openai_like provider")
api_base = api_base.rstrip("/")
return f"{api_base}/responses"
@@ -1,54 +1,31 @@
"""
Transformation logic for Perplexity Agent API (Responses API)
Perplexity Responses API OpenAI-compatible.
This module handles the translation between OpenAI's Responses API format
and Perplexity's Responses API format, which supports:
- Third-party model access (OpenAI, Anthropic, Google, xAI, etc.)
- Presets for optimized configurations
- Web search and URL fetching tools
- Reasoning effort control
- Instructions parameter for system-level guidance
The only provider quirks:
- cost returned as dict handled by ResponseAPIUsage.parse_cost validator
- preset models (preset/pro-search) handled by transform_responses_api_request
- HTTP 200 with status:"failed" raised as exception in transform_response_api_response
Ref: https://docs.perplexity.ai/api-reference/responses-post
"""
from typing import Any, Dict, List, Optional, Union
import httpx
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
ResponseAPIUsage,
ResponseInputParam,
ResponsesAPIOptionalRequestParams,
ResponsesAPIResponse,
ResponsesAPIStreamingResponse,
)
from litellm.types.llms.openai import ResponseInputParam, ResponsesAPIResponse
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LlmProviders
class PerplexityResponsesConfig(OpenAIResponsesAPIConfig):
"""
Configuration for Perplexity Agent API (Responses API)
Reference: https://docs.perplexity.ai/docs/agent-api/overview
"""
@property
def custom_llm_provider(self) -> LlmProviders:
return LlmProviders.PERPLEXITY
def get_supported_openai_params(self, model: str) -> list:
"""
Perplexity Responses API supports a different set of parameters
Ref: https://docs.perplexity.ai/api-reference/responses-post
Params aligned with response-echo fields and Open Responses spec.
"""
"""Ref: https://docs.perplexity.ai/api-reference/responses-post"""
return [
"max_output_tokens",
"stream",
@@ -56,200 +33,45 @@ class PerplexityResponsesConfig(OpenAIResponsesAPIConfig):
"top_p",
"tools",
"reasoning",
"preset",
"instructions",
"models", # Model fallback support
"tool_choice",
"parallel_tool_calls",
"max_tool_calls",
"text",
"previous_response_id",
"store",
"background",
"truncation",
"metadata",
"safety_identifier",
"user",
"stream_options",
"top_logprobs",
"prompt_cache_key",
"frequency_penalty",
"presence_penalty",
"service_tier",
"models",
]
@property
def custom_llm_provider(self) -> LlmProviders:
return LlmProviders.PERPLEXITY
def validate_environment(
self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
) -> dict:
"""Validate environment and set up headers"""
# Get API key from environment
api_key = get_secret_str("PERPLEXITYAI_API_KEY") or get_secret_str(
"PERPLEXITY_API_KEY"
litellm_params = litellm_params or GenericLiteLLMParams()
api_key = (
litellm_params.api_key
or get_secret_str("PERPLEXITYAI_API_KEY")
or get_secret_str("PERPLEXITY_API_KEY")
)
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
headers["Content-Type"] = "application/json"
return headers
def get_complete_url(
self,
api_base: Optional[str],
litellm_params: dict,
) -> str:
"""Get the complete URL for the Perplexity Responses API"""
if api_base is None:
api_base = (
get_secret_str("PERPLEXITY_API_BASE") or "https://api.perplexity.ai"
)
def get_complete_url(self, api_base: Optional[str], litellm_params: dict) -> str:
api_base = api_base or get_secret_str("PERPLEXITY_API_BASE") or "https://api.perplexity.ai"
return f"{api_base.rstrip('/')}/v1/responses"
# Ensure api_base doesn't end with a slash
api_base = api_base.rstrip("/")
# Add the responses endpoint
return f"{api_base}/v1/responses"
def map_openai_params( # noqa: PLR0915
self,
response_api_optional_params: ResponsesAPIOptionalRequestParams,
model: str,
drop_params: bool,
) -> Dict:
"""
Map OpenAI Responses API parameters to Perplexity format
Key differences:
- Supports 'preset' parameter for predefined configurations
- Supports 'instructions' parameter for system-level guidance
- Tools are specified differently (web_search, fetch_url)
"""
mapped_params: Dict[str, Any] = {}
# Map standard parameters
if response_api_optional_params.get("max_output_tokens"):
mapped_params["max_output_tokens"] = response_api_optional_params[
"max_output_tokens"
]
if response_api_optional_params.get("temperature"):
mapped_params["temperature"] = response_api_optional_params["temperature"]
if response_api_optional_params.get("top_p"):
mapped_params["top_p"] = response_api_optional_params["top_p"]
if response_api_optional_params.get("stream"):
mapped_params["stream"] = response_api_optional_params["stream"]
if response_api_optional_params.get("stream_options"):
mapped_params["stream_options"] = response_api_optional_params[
"stream_options"
]
# Map Perplexity-specific parameters (using .get() with Any dict access)
preset = response_api_optional_params.get("preset") # type: ignore
if preset:
mapped_params["preset"] = preset
instructions = response_api_optional_params.get("instructions") # type: ignore
if instructions:
mapped_params["instructions"] = instructions
if response_api_optional_params.get("reasoning"):
mapped_params["reasoning"] = response_api_optional_params["reasoning"]
tools = response_api_optional_params.get("tools")
if tools:
# Convert tools to list of dicts for transformation
tools_list = [dict(tool) if hasattr(tool, "__dict__") else tool for tool in tools] # type: ignore
mapped_params["tools"] = self._transform_tools(tools_list) # type: ignore
# Tool control
if response_api_optional_params.get("tool_choice"):
mapped_params["tool_choice"] = response_api_optional_params["tool_choice"]
if response_api_optional_params.get("parallel_tool_calls") is not None:
mapped_params["parallel_tool_calls"] = response_api_optional_params[
"parallel_tool_calls"
]
if response_api_optional_params.get("max_tool_calls"):
mapped_params["max_tool_calls"] = response_api_optional_params[
"max_tool_calls"
]
# Structured outputs
text_param = response_api_optional_params.get("text")
if text_param:
mapped_params["text"] = text_param
# Conversation continuity
if response_api_optional_params.get("previous_response_id"):
mapped_params["previous_response_id"] = response_api_optional_params[
"previous_response_id"
]
# Storage and lifecycle
if response_api_optional_params.get("store") is not None:
mapped_params["store"] = response_api_optional_params["store"]
if response_api_optional_params.get("background") is not None:
mapped_params["background"] = response_api_optional_params["background"]
if response_api_optional_params.get("truncation"):
mapped_params["truncation"] = response_api_optional_params["truncation"]
# Metadata
if response_api_optional_params.get("metadata"):
mapped_params["metadata"] = response_api_optional_params["metadata"]
if response_api_optional_params.get("safety_identifier"):
mapped_params["safety_identifier"] = response_api_optional_params[
"safety_identifier"
]
if response_api_optional_params.get("user"):
mapped_params["user"] = response_api_optional_params["user"]
# Additional
if response_api_optional_params.get("top_logprobs") is not None:
mapped_params["top_logprobs"] = response_api_optional_params["top_logprobs"]
if response_api_optional_params.get("prompt_cache_key"):
mapped_params["prompt_cache_key"] = response_api_optional_params[
"prompt_cache_key"
]
if response_api_optional_params.get("frequency_penalty") is not None:
mapped_params["frequency_penalty"] = response_api_optional_params[
"frequency_penalty" # type: ignore[typeddict-item]
]
if response_api_optional_params.get("presence_penalty") is not None:
mapped_params["presence_penalty"] = response_api_optional_params[
"presence_penalty" # type: ignore[typeddict-item]
]
if response_api_optional_params.get("service_tier"):
mapped_params["service_tier"] = response_api_optional_params["service_tier"]
return mapped_params
def _transform_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Transform tools to Perplexity format.
Perplexity supports (per public OpenAPI spec):
- web_search: Performs web searches
- fetch_url: Fetches content from URLs
- function: Function Calling
"""
perplexity_tools = []
for tool in tools:
if isinstance(tool, dict):
tool_type = tool.get("type", "")
# Direct Perplexity tool format
if tool_type in ["web_search", "fetch_url"]:
perplexity_tools.append(tool)
# Function tools: Perplexity supports them natively
elif tool_type == "function":
perplexity_tools.append(tool)
return perplexity_tools
def _ensure_message_type(
self, input: Union[str, ResponseInputParam]
) -> Union[str, List[Dict[str, Any]]]:
"""Ensure list input items have type='message' (required by Perplexity)."""
if isinstance(input, str):
return input
if isinstance(input, list):
result = []
for item in input:
if isinstance(item, dict) and "type" not in item:
item = {**item, "type": "message"}
result.append(item)
return result
return input
def transform_responses_api_request(
self,
@@ -259,62 +81,23 @@ class PerplexityResponsesConfig(OpenAIResponsesAPIConfig):
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Dict:
"""
Transform request to Perplexity Responses API format
"""
# Check if the model is a preset (format: preset/preset-name)
"""Handle preset/ model prefix: send as {"preset": name} instead of {"model": name}."""
input = self._ensure_message_type(input)
if model.startswith("preset/"):
preset_name = model.replace("preset/", "")
data = {
"preset": preset_name,
"input": self._format_input(input),
input = self._validate_input_param(input)
data: Dict = {
"preset": model[len("preset/"):],
"input": input,
}
# Check if preset is explicitly provided in params
elif response_api_optional_request_params.get("preset"):
data = {
"preset": response_api_optional_request_params.pop("preset"),
"input": self._format_input(input),
}
else:
# Full request format for third-party models
data = {
"model": model,
"input": self._format_input(input),
}
# Add all optional parameters
for key, value in response_api_optional_request_params.items():
data[key] = value
return data
def _format_input(
self, input: Union[str, ResponseInputParam]
) -> Union[str, List[Dict[str, Any]]]:
"""
Format input for Perplexity Responses API
The API accepts either:
- A simple string for single-turn queries
- An array of message objects for multi-turn conversations
"""
if isinstance(input, str):
return input
# Handle ResponseInputParam format
if isinstance(input, list):
formatted_messages = []
for item in input:
if isinstance(item, dict):
formatted_message = {
"type": "message",
"role": item.get("role"),
"content": item.get("content", ""),
}
formatted_messages.append(formatted_message)
return formatted_messages
return str(input)
data.update(response_api_optional_request_params)
return data
return super().transform_responses_api_request(
model=model,
input=input,
response_api_optional_request_params=response_api_optional_request_params,
litellm_params=litellm_params,
headers=headers,
)
def transform_response_api_response(
self,
@@ -322,174 +105,27 @@ class PerplexityResponsesConfig(OpenAIResponsesAPIConfig):
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> ResponsesAPIResponse:
"""
Transform Perplexity Responses API response to OpenAI Responses API format
"""
"""Check for Perplexity's status:'failed' on HTTP 200 before delegating to base."""
try:
raw_response_json = raw_response.json()
except Exception as e:
raise BaseLLMException(
status_code=raw_response.status_code,
message=f"Failed to parse response: {str(e)}",
)
# Check for error status
status = raw_response_json.get("status")
if status == "failed":
error = raw_response_json.get("error", {})
error_message = error.get("message", "Unknown error")
raise BaseLLMException(
status_code=raw_response.status_code,
message=error_message,
)
# Transform usage to handle Perplexity's cost structure
usage_data = raw_response_json.get("usage", {})
transformed_usage_dict = self._transform_usage(usage_data)
# Convert usage dict to ResponseAPIUsage object
usage_obj = (
ResponseAPIUsage(**transformed_usage_dict)
if transformed_usage_dict
else None
)
# Map Perplexity response to OpenAI Responses API format
response = ResponsesAPIResponse(
id=raw_response_json.get("id", ""),
object="response",
created_at=raw_response_json.get("created_at", 0),
status=raw_response_json.get("status", "completed"),
model=raw_response_json.get("model", model),
output=raw_response_json.get("output", []),
usage=usage_obj,
)
return response
def _transform_usage(self, usage_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Transform Perplexity usage data to OpenAI format
Perplexity returns:
{
"input_tokens": 100,
"output_tokens": 200,
"total_tokens": 300,
"cost": {
"currency": "USD",
"input_cost": 0.0001,
"output_cost": 0.0002,
"total_cost": 0.0003
}
}
OpenAI expects:
{
"input_tokens": 100,
"output_tokens": 200,
"total_tokens": 300,
"cost": 0.0003
}
"""
transformed = {
"input_tokens": usage_data.get("input_tokens", 0),
"output_tokens": usage_data.get("output_tokens", 0),
"total_tokens": usage_data.get("total_tokens", 0),
}
# Transform cost from Perplexity format (dict) to OpenAI format (float)
cost_obj = usage_data.get("cost")
if isinstance(cost_obj, dict) and "total_cost" in cost_obj:
transformed["cost"] = cost_obj["total_cost"]
verbose_logger.debug(
"Transformed Perplexity cost object to float: %s -> %s",
cost_obj,
cost_obj["total_cost"],
)
elif cost_obj is not None:
# If cost is already a float/number, use it as-is
transformed["cost"] = cost_obj
# Add input_tokens_details if present
if "input_tokens_details" in usage_data:
transformed["input_tokens_details"] = usage_data["input_tokens_details"]
# Add output_tokens_details if present
if "output_tokens_details" in usage_data:
transformed["output_tokens_details"] = usage_data["output_tokens_details"]
return transformed
def transform_streaming_response(
self,
model: str,
parsed_chunk: dict,
logging_obj: LiteLLMLoggingObj,
) -> ResponsesAPIStreamingResponse:
"""
Transform a parsed streaming response chunk into a ResponsesAPIStreamingResponse
"""
# Get the event type from the chunk
verbose_logger.debug("Raw Perplexity Chunk=%s", parsed_chunk)
event_type = str(parsed_chunk.get("type"))
event_pydantic_model = PerplexityResponsesConfig.get_event_model_class(
event_type=event_type
)
# Transform Perplexity-specific fields to OpenAI format
parsed_chunk = self._transform_perplexity_chunk(parsed_chunk)
# Defensive: Handle error.code being null (similar to OpenAI implementation)
try:
error_obj = parsed_chunk.get("error")
if isinstance(error_obj, dict) and error_obj.get("code") is None:
# Preserve other fields, but ensure `code` is a non-null string
parsed_chunk = dict(parsed_chunk)
parsed_chunk["error"] = dict(error_obj)
parsed_chunk["error"]["code"] = "unknown_error"
except Exception:
# If anything unexpected happens here, fall back to attempting
# instantiation and let higher-level handlers manage errors.
verbose_logger.debug("Failed to coalesce error.code in parsed_chunk")
raw_response_json = None
return event_pydantic_model(**parsed_chunk)
if (
isinstance(raw_response_json, dict)
and raw_response_json.get("status") == "failed"
):
error = raw_response_json.get("error", {})
raise BaseLLMException(
status_code=raw_response.status_code,
message=error.get("message", "Unknown Perplexity error"),
)
def _transform_perplexity_chunk(self, chunk: dict) -> dict:
"""
Transform Perplexity-specific fields in a streaming chunk to OpenAI format.
This handles:
- Converting Perplexity's cost object to a simple float
"""
# Make a copy to avoid modifying the original
chunk = dict(chunk)
# Transform usage.cost from Perplexity format to OpenAI format
# Perplexity: {"currency": "USD", "input_cost": 0.0001, "output_cost": 0.0002, "total_cost": 0.0003}
# OpenAI: 0.0003 (just the total_cost as a float)
try:
response_obj = chunk.get("response")
if isinstance(response_obj, dict):
usage_obj = response_obj.get("usage")
if isinstance(usage_obj, dict):
cost_obj = usage_obj.get("cost")
if isinstance(cost_obj, dict) and "total_cost" in cost_obj:
# Replace the cost object with just the total_cost value
chunk = dict(chunk)
chunk["response"] = dict(response_obj)
chunk["response"]["usage"] = dict(usage_obj)
chunk["response"]["usage"]["cost"] = cost_obj["total_cost"]
verbose_logger.debug(
"Transformed Perplexity cost object to float: %s -> %s",
cost_obj,
cost_obj["total_cost"],
)
except Exception as e:
# If transformation fails, log and continue with original chunk
verbose_logger.debug("Failed to transform Perplexity cost object: %s", e)
return chunk
return super().transform_response_api_response(
model=model,
raw_response=raw_response,
logging_obj=logging_obj,
)
def supports_native_websocket(self) -> bool:
"""Perplexity does not support native WebSocket for Responses API"""
+36 -20
View File
@@ -583,17 +583,35 @@ class SagemakerLLM(BaseAWSLLM):
### BOTO3 INIT
import boto3
# Use _load_credentials to support role assumption (aws_role_name, aws_session_name)
credentials, aws_region_name = self._load_credentials(optional_params)
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
# Create boto3 session with the loaded credentials
session = boto3.Session(
aws_access_key_id=credentials.access_key,
aws_secret_access_key=credentials.secret_key,
aws_session_token=credentials.token,
region_name=aws_region_name,
)
client = session.client(service_name="sagemaker-runtime")
if aws_access_key_id is not None:
# uses auth params passed to completion
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
client = boto3.client(
service_name="sagemaker-runtime",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
else:
# aws_access_key_id is None, assume user is trying to auth using env variables
# boto3 automaticaly reads env variables
# we need to read region name from env
# I assume majority of users use .env for auth
region_name = (
get_secret("AWS_REGION_NAME")
or aws_region_name # get region from config file if specified
or "us-west-2" # default to us-west-2 if region not specified
)
client = boto3.client(
service_name="sagemaker-runtime",
region_name=region_name,
)
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
inference_params = deepcopy(optional_params)
@@ -610,9 +628,7 @@ class SagemakerLLM(BaseAWSLLM):
#### EMBEDDING LOGIC
# Transform request based on model type
provider_config = SagemakerEmbeddingConfig.get_model_config(model)
request_data = provider_config.transform_embedding_request(
model, input, optional_params, {}
)
request_data = provider_config.transform_embedding_request(model, input, optional_params, {})
data = json.dumps(request_data).encode("utf-8")
## LOGGING
@@ -657,19 +673,19 @@ class SagemakerLLM(BaseAWSLLM):
)
print_verbose(f"raw model_response: {response}")
# Transform response based on model type
from httpx import Response as HttpxResponse
# Create a mock httpx Response object for the transformation
mock_response = HttpxResponse(
status_code=200,
content=json.dumps(response).encode("utf-8"),
headers={"content-type": "application/json"},
content=json.dumps(response).encode('utf-8'),
headers={"content-type": "application/json"}
)
model_response = EmbeddingResponse()
# Use the request_data that was already transformed above
return provider_config.transform_embedding_response(
model=model,
@@ -679,5 +695,5 @@ class SagemakerLLM(BaseAWSLLM):
api_key=None,
request_data=request_data,
optional_params=optional_params,
litellm_params=litellm_params or {},
litellm_params=litellm_params or {}
)
+11 -10
View File
@@ -208,27 +208,28 @@ class SnowflakeConfig(SnowflakeBaseConfig, OpenAIGPTConfig):
def _transform_tool_choice(
self, tool_choice: Union[str, Dict[str, Any]]
) -> Union[str, Dict[str, Any]]:
) -> Dict[str, Any]:
"""
Transform OpenAI tool_choice format to Snowflake format.
Snowflake requires tool_choice to be an object, not a string.
Ref: https://docs.snowflake.com/en/developer-guide/snowflake-rest-api/reference/cortex-inference#post--api-v2-cortex-inference-complete-req-body-schema
Args:
tool_choice: Tool choice in OpenAI format (str or dict)
Returns:
Tool choice in Snowflake format
Tool choice in Snowflake format (always an object)
OpenAI format:
{"type": "function", "function": {"name": "get_weather"}}
OpenAI format (string): "auto", "required", "none"
OpenAI format (object): {"type": "function", "function": {"name": "get_weather"}}
Snowflake format:
{"type": "tool", "name": ["get_weather"]}
Note: String values ("auto", "required", "none") pass through unchanged.
Snowflake format (string values become objects): {"type": "auto"}
Snowflake format (specific tool): {"type": "tool", "name": ["get_weather"]}
"""
if isinstance(tool_choice, str):
# "auto", "required", "none" pass through as-is
return tool_choice
# Snowflake requires object format: {"type": "auto"} not string "auto"
return {"type": tool_choice}
if isinstance(tool_choice, dict):
if tool_choice.get("type") == "function":
+36 -9
View File
@@ -249,23 +249,27 @@ def _get_embedding_url(
- bge/endpoint_id -> strips to endpoint_id for endpoints/ routing
- numeric model -> routes to endpoints/
- regular model -> routes to publishers/google/models/
- models with uses_embed_content flag -> use embedContent endpoint instead of predict
"""
endpoint = "predict"
# Strip routing prefixes (bge/, gemma/, etc.) for endpoint URL construction
original_model = model
model = get_vertex_base_model_name(model=model)
# Get base URL (handles global vs regional)
try:
model_info = litellm.get_model_info(
model=original_model,
custom_llm_provider="vertex_ai",
)
uses_embed_content = model_info.get("uses_embed_content", False)
except Exception:
uses_embed_content = False
endpoint = "embedContent" if uses_embed_content else "predict"
base_url = get_vertex_base_url(vertex_location)
if model.isdigit():
# https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/endpoints/$ENDPOINT_ID:predict
# https://aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/global/endpoints/$ENDPOINT_ID:predict
url = f"{base_url}/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
else:
# Regular model -> publisher model
# https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/publishers/google/models/{model}:predict
# https://aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/global/publishers/google/models/{model}:predict
url = f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
return url, endpoint
@@ -518,6 +522,29 @@ def _build_vertex_schema(parameters: dict, add_property_ordering: bool = False):
return parameters
def _build_vertex_schema_for_gemini_2(parameters: dict) -> dict:
"""
Minimal schema builder for Gemini 2.0+ tool parameters.
Gemini 2.0+ accepts standard JSON Schema natively in tool parameters,
including lowercase types, anyOf with null, and bare {} (TYPE_UNSPECIFIED).
The only transformation needed is resolving $ref/$defs, which Gemini does
NOT support in tool parameters (returns 400).
This avoids the harmful transforms in _build_vertex_schema that break
JsonValue/Any semantics by coercing {} to {"type": "object"}.
"""
valid_schema_fields = set(get_type_hints(Schema).keys())
parameters = dict(parameters) # shallow copy to avoid mutating caller's dict
defs = parameters.pop("$defs", {})
unpack_defs(parameters, defs)
parameters = filter_schema_fields(parameters, valid_schema_fields)
return parameters
def _build_json_schema(parameters: dict) -> dict:
"""
Build a JSON Schema for use with Gemini's responseJsonSchema parameter.
@@ -4,13 +4,16 @@ import httpx
import litellm
from litellm.caching.caching import Cache, LiteLLMCacheType
from litellm.constants import MINIMUM_PROMPT_CACHE_TOKEN_COUNT
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm._logging import verbose_logger
from litellm.llms.openai.openai import AllMessageValues
from litellm.utils import is_prompt_caching_valid_prompt
from litellm.types.llms.vertex_ai import (
CachedContentListAllResponseBody,
VertexAICachedContentResponseObject,
@@ -315,6 +318,20 @@ class ContextCachingEndpoints(VertexBase):
if len(cached_messages) == 0:
return messages, optional_params, None
# Gemini requires a minimum of 1024 tokens for context caching.
# Skip caching if the cached content is too small to avoid API errors.
if not is_prompt_caching_valid_prompt(
model=model,
messages=cached_messages,
custom_llm_provider=custom_llm_provider,
):
verbose_logger.debug(
"Vertex AI context caching: cached content is below minimum token "
"count (%d). Skipping context caching.",
MINIMUM_PROMPT_CACHE_TOKEN_COUNT,
)
return messages, optional_params, None
tools = optional_params.pop("tools", None)
## AUTHORIZATION ##
@@ -447,6 +464,20 @@ class ContextCachingEndpoints(VertexBase):
if len(cached_messages) == 0:
return messages, optional_params, None
# Gemini requires a minimum of 1024 tokens for context caching.
# Skip caching if the cached content is too small to avoid API errors.
if not is_prompt_caching_valid_prompt(
model=model,
messages=cached_messages,
custom_llm_provider=custom_llm_provider,
):
verbose_logger.debug(
"Vertex AI context caching: cached content is below minimum token "
"count (%d). Skipping context caching.",
MINIMUM_PROMPT_CACHE_TOKEN_COUNT,
)
return messages, optional_params, None
tools = optional_params.pop("tools", None)
## AUTHORIZATION ##
@@ -77,6 +77,60 @@ def _convert_detail_to_media_resolution_enum(
return None
def _get_highest_media_resolution(
current: Optional[str], new_detail: Optional[str]
) -> Optional[str]:
"""
Compare two media resolution values and return the highest one.
Resolution hierarchy: ultra_high > high > medium > low > None
"""
resolution_priority = {"ultra_high": 4, "high": 3, "medium": 2, "low": 1}
current_priority = resolution_priority.get(current, 0) if current else 0
new_priority = resolution_priority.get(new_detail, 0) if new_detail else 0
if new_priority > current_priority:
return new_detail
return current
def _extract_max_media_resolution_from_messages(
messages: List[AllMessageValues],
) -> Optional[str]:
"""
Extract the highest media resolution (detail) from image content in messages.
This is used to set the global media_resolution in generation_config for
Gemini 2.x models which don't support per-part media resolution.
Args:
messages: List of messages in OpenAI format
Returns:
The highest detail level found ("high", "low", or None)
"""
max_resolution: Optional[str] = None
for msg in messages:
content = msg.get("content")
if isinstance(content, list):
for item in content:
if not isinstance(item, dict):
continue
detail: Optional[str] = None
if item.get("type") == "image_url":
image_url = item.get("image_url")
if isinstance(image_url, dict):
detail = image_url.get("detail")
elif item.get("type") == "file":
file_obj = item.get("file")
if isinstance(file_obj, dict):
detail = file_obj.get("detail")
if detail:
max_resolution = _get_highest_media_resolution(
max_resolution, detail
)
return max_resolution
def _apply_gemini_3_metadata(
part: PartType,
model: Optional[str],
@@ -539,10 +593,6 @@ def _gemini_convert_messages_with_history( # noqa: PLR0915
raise e
# Keys that LiteLLM consumes internally and must never be forwarded to the
_LITELLM_INTERNAL_EXTRA_BODY_KEYS: frozenset = frozenset({"cache", "tags"})
def _pop_and_merge_extra_body(data: RequestBody, optional_params: dict) -> None:
"""Pop extra_body from optional_params and shallow-merge into data, deep-merging dict values."""
extra_body: Optional[dict] = optional_params.pop("extra_body", None)
@@ -561,7 +611,7 @@ def _pop_and_merge_extra_body(data: RequestBody, optional_params: dict) -> None:
data_dict[k] = v
def _transform_request_body(
def _transform_request_body( # noqa: PLR0915
messages: List[AllMessageValues],
model: str,
optional_params: dict,
@@ -639,6 +689,19 @@ def _transform_request_body(
generation_config: Optional[GenerationConfig] = GenerationConfig(
**filtered_params
)
# For Gemini 2.x models, add media_resolution to generation_config (global)
# Gemini 3+ supports per-part media_resolution, but 2.x only supports global
# Gemini 1.x does not support mediaResolution at all
if "gemini-2" in model:
max_media_resolution = _extract_max_media_resolution_from_messages(messages)
if max_media_resolution:
media_resolution_value = _convert_detail_to_media_resolution_enum(
max_media_resolution
)
if media_resolution_value and generation_config is not None:
generation_config["mediaResolution"] = media_resolution_value["level"]
data = RequestBody(contents=content)
if system_instructions is not None:
data["system_instruction"] = system_instructions
@@ -97,6 +97,7 @@ from ..common_utils import (
VertexAIError,
_build_json_schema,
_build_vertex_schema,
_build_vertex_schema_for_gemini_2,
supports_response_json_schema,
)
from ..vertex_llm_base import VertexBase
@@ -467,7 +468,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
return None
def _map_function( # noqa: PLR0915
self, value: List[dict], optional_params: dict
self, value: List[dict], optional_params: dict, model: str = ""
) -> List[Tools]:
"""
Map OpenAI-style tools/functions to Vertex AI format.
@@ -510,10 +511,21 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
"parameters" in _openai_function_object
and _openai_function_object["parameters"] is not None
and isinstance(_openai_function_object["parameters"], dict)
): # OPENAI accepts JSON Schema, Google accepts OpenAPI schema.
_openai_function_object["parameters"] = _build_vertex_schema(
_openai_function_object["parameters"]
)
):
if supports_response_json_schema(model):
# Gemini 2.0+: minimal transform (resolve $ref only)
_openai_function_object["parameters"] = (
_build_vertex_schema_for_gemini_2(
_openai_function_object["parameters"]
)
)
else:
# Gemini 1.5: full OpenAPI-style transform
_openai_function_object["parameters"] = (
_build_vertex_schema(
_openai_function_object["parameters"]
)
)
openai_function_object = _openai_function_object
@@ -1048,7 +1060,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
):
# Pass optional_params so _map_function can add toolConfig if needed
mapped_tools = self._map_function(
value=value, optional_params=optional_params
value=value, optional_params=optional_params, model=model
)
optional_params = self._add_tools_to_optional_params(
optional_params, mapped_tools
@@ -1227,27 +1239,25 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
"IMAGE_PROHIBITED_CONTENT": "The token generation was stopped as the response was flagged for prohibited image content.",
}
_GEMINI_FINISH_REASON_KEYS = frozenset({
"STOP", "MAX_TOKENS", "SAFETY", "RECITATION", "FINISH_REASON_UNSPECIFIED",
"MALFORMED_FUNCTION_CALL", "LANGUAGE", "OTHER", "BLOCKLIST",
"PROHIBITED_CONTENT", "SPII", "IMAGE_SAFETY", "IMAGE_PROHIBITED_CONTENT",
"TOO_MANY_TOOL_CALLS", "MALFORMED_RESPONSE",
})
@staticmethod
def get_finish_reason_mapping() -> Dict[str, OpenAIChatCompletionFinishReason]:
"""
Return Dictionary of finish reasons which indicate response was flagged
and what it means
Return Dictionary of Gemini/Vertex AI finish reasons and their
OpenAI-compatible mappings.
"""
from litellm.litellm_core_utils.core_helpers import _FINISH_REASON_MAP
return {
"FINISH_REASON_UNSPECIFIED": "finish_reason_unspecified",
"STOP": "stop",
"MAX_TOKENS": "length",
"SAFETY": "content_filter",
"RECITATION": "content_filter",
"LANGUAGE": "content_filter",
"OTHER": "content_filter",
"BLOCKLIST": "content_filter",
"PROHIBITED_CONTENT": "content_filter",
"SPII": "content_filter",
"MALFORMED_FUNCTION_CALL": "malformed_function_call", # openai doesn't have a way of representing this
"IMAGE_SAFETY": "content_filter",
"IMAGE_PROHIBITED_CONTENT": "content_filter",
k: v
for k, v in _FINISH_REASON_MAP.items()
if k in VertexGeminiConfig._GEMINI_FINISH_REASON_KEYS
}
def translate_exception_str(self, exception_string: str):
@@ -1766,15 +1776,14 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
chat_completion_message: Optional[ChatCompletionResponseMessage],
finish_reason: Optional[str],
) -> OpenAIChatCompletionFinishReason:
mapped_finish_reason = VertexGeminiConfig.get_finish_reason_mapping()
from litellm.litellm_core_utils.core_helpers import map_finish_reason
if chat_completion_message and chat_completion_message.get("function_call"):
return "function_call"
elif chat_completion_message and chat_completion_message.get("tool_calls"):
return "tool_calls"
elif (
finish_reason and finish_reason in mapped_finish_reason.keys()
): # vertex ai
return mapped_finish_reason[finish_reason]
elif finish_reason:
return map_finish_reason(finish_reason)
else:
return "stop"
@@ -2362,7 +2371,8 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
async def make_call(
client: Optional[AsyncHTTPHandler],
client: Optional[AsyncHTTPHandler], # module-level client
gemini_client: Optional[AsyncHTTPHandler], # if passed by user
api_base: str,
headers: dict,
data: str,
@@ -2370,6 +2380,8 @@ async def make_call(
messages: list,
logging_obj,
):
if gemini_client is not None:
client = gemini_client
if client is None:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.VERTEX_AI,
@@ -2541,7 +2553,11 @@ class VertexLLM(VertexBase):
completion_stream=None,
make_call=partial(
make_call,
client=client,
gemini_client=(
client
if client is not None and isinstance(client, AsyncHTTPHandler)
else None
),
api_base=api_base,
headers=headers,
data=request_body_str,
@@ -3,12 +3,11 @@ Google AI Studio /batchEmbedContents Embeddings Endpoint
"""
import json
from typing import Any, Literal, Optional, Union
from typing import Any, Dict, Literal, Optional, Union
import httpx
import litellm
from litellm.types.utils import EmbeddingResponse
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
@@ -19,15 +18,98 @@ from litellm.types.llms.vertex_ai import (
VertexAIBatchEmbeddingsRequestBody,
VertexAIBatchEmbeddingsResponseObject,
)
from litellm.types.utils import EmbeddingResponse
from ..gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from .batch_embed_content_transformation import (
_is_file_reference,
_is_multimodal_input,
process_embed_content_response,
process_response,
transform_openai_input_gemini_content,
transform_openai_input_gemini_embed_content,
)
class GoogleBatchEmbeddings(VertexLLM):
def _resolve_file_references(
self,
input: EmbeddingInput,
api_key: str,
sync_handler: HTTPHandler,
) -> Dict[str, Dict[str, str]]:
"""
Resolve Gemini file references (files/...) to get mime_type and uri.
Args:
input: EmbeddingInput that may contain file references
api_key: Gemini API key
sync_handler: HTTP client
Returns:
Dict mapping file name to {mime_type, uri}
"""
input_list = [input] if isinstance(input, str) else input
resolved_files: Dict[str, Dict[str, str]] = {}
for element in input_list:
if isinstance(element, str) and _is_file_reference(element):
url = f"https://generativelanguage.googleapis.com/v1beta/{element}"
headers = {"x-goog-api-key": api_key}
response = sync_handler.get(url=url, headers=headers)
if response.status_code != 200:
raise Exception(
f"Error fetching file {element}: {response.status_code} {response.text}"
)
file_data = response.json()
resolved_files[element] = {
"mime_type": file_data.get("mimeType", ""),
"uri": file_data.get("uri", element),
}
return resolved_files
async def _async_resolve_file_references(
self,
input: EmbeddingInput,
api_key: str,
async_handler: AsyncHTTPHandler,
) -> Dict[str, Dict[str, str]]:
"""
Async version of _resolve_file_references.
Args:
input: EmbeddingInput that may contain file references
api_key: Gemini API key
async_handler: Async HTTP client
Returns:
Dict mapping file name to {mime_type, uri}
"""
input_list = [input] if isinstance(input, str) else input
resolved_files: Dict[str, Dict[str, str]] = {}
for element in input_list:
if isinstance(element, str) and _is_file_reference(element):
url = f"https://generativelanguage.googleapis.com/v1beta/{element}"
headers = {"x-goog-api-key": api_key}
response = await async_handler.get(url=url, headers=headers)
if response.status_code != 200:
raise Exception(
f"Error fetching file {element}: {response.status_code} {response.text}"
)
file_data = response.json()
resolved_files[element] = {
"mime_type": file_data.get("mimeType", ""),
"uri": file_data.get("uri", element),
}
return resolved_files
def batch_embeddings(
self,
model: str,
@@ -54,20 +136,6 @@ class GoogleBatchEmbeddings(VertexLLM):
custom_llm_provider=custom_llm_provider,
)
auth_header, url = self._get_token_and_url(
model=model,
auth_header=_auth_header,
gemini_api_key=api_key,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
stream=None,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
should_use_v1beta1_features=False,
mode="batch_embedding",
)
if client is None:
_params = {}
if timeout is not None:
@@ -83,9 +151,25 @@ class GoogleBatchEmbeddings(VertexLLM):
optional_params = optional_params or {}
### TRANSFORMATION ###
request_data = transform_openai_input_gemini_content(
input=input, model=model, optional_params=optional_params
is_multimodal = _is_multimodal_input(input)
use_embed_content = is_multimodal or (custom_llm_provider == "vertex_ai")
if use_embed_content:
mode = "embedding"
else:
mode = "batch_embedding"
auth_header, url = self._get_token_and_url(
model=model,
auth_header=_auth_header,
gemini_api_key=api_key,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
stream=None,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
should_use_v1beta1_features=False,
mode=mode,
)
headers = {
@@ -93,14 +177,46 @@ class GoogleBatchEmbeddings(VertexLLM):
}
if auth_header is not None:
if isinstance(auth_header, dict):
# For Gemini with custom api_base: auth_header is {"x-goog-api-key": "..."}
headers.update(auth_header)
else:
# For Vertex AI: auth_header is a Bearer token string
headers["Authorization"] = f"Bearer {auth_header}"
if extra_headers is not None:
headers.update(extra_headers)
if aembedding is True:
return self.async_batch_embeddings( # type: ignore
model=model,
api_base=api_base,
url=url,
data=None,
model_response=model_response,
timeout=timeout,
headers=headers,
input=input,
use_embed_content=use_embed_content,
api_key=api_key,
optional_params=optional_params,
logging_obj=logging_obj,
)
### TRANSFORMATION (sync path) ###
if use_embed_content:
resolved_files = {}
if api_key:
resolved_files = self._resolve_file_references(
input=input, api_key=api_key, sync_handler=sync_handler
)
request_data = transform_openai_input_gemini_embed_content(
input=input,
model=model,
optional_params=optional_params,
resolved_files=resolved_files,
)
else:
request_data = transform_openai_input_gemini_content(
input=input, model=model, optional_params=optional_params
)
## LOGGING
logging_obj.pre_call(
input=input,
@@ -112,18 +228,6 @@ class GoogleBatchEmbeddings(VertexLLM):
},
)
if aembedding is True:
return self.async_batch_embeddings( # type: ignore
model=model,
api_base=api_base,
url=url,
data=request_data,
model_response=model_response,
timeout=timeout,
headers=headers,
input=input,
)
response = sync_handler.post(
url=url,
headers=headers,
@@ -134,26 +238,38 @@ class GoogleBatchEmbeddings(VertexLLM):
raise Exception(f"Error: {response.status_code} {response.text}")
_json_response = response.json()
_predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response) # type: ignore
return process_response(
model=model,
model_response=model_response,
_predictions=_predictions,
input=input,
)
if use_embed_content:
return process_embed_content_response(
input=input,
model_response=model_response,
model=model,
response_json=_json_response,
)
else:
_predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response) # type: ignore
return process_response(
model=model,
model_response=model_response,
_predictions=_predictions,
input=input,
)
async def async_batch_embeddings(
self,
model: str,
api_base: Optional[str],
url: str,
data: VertexAIBatchEmbeddingsRequestBody,
data: Optional[Union[VertexAIBatchEmbeddingsRequestBody, dict]],
model_response: EmbeddingResponse,
input: EmbeddingInput,
timeout: Optional[Union[float, httpx.Timeout]],
headers={},
client: Optional[AsyncHTTPHandler] = None,
use_embed_content: bool = False,
api_key: Optional[str] = None,
optional_params: Optional[dict] = None,
logging_obj: Optional[Any] = None,
) -> EmbeddingResponse:
if client is None:
_params = {}
@@ -171,6 +287,36 @@ class GoogleBatchEmbeddings(VertexLLM):
else:
async_handler = client # type: ignore
### TRANSFORMATION (async path) ###
if use_embed_content:
resolved_files = {}
if api_key:
resolved_files = await self._async_resolve_file_references(
input=input, api_key=api_key, async_handler=async_handler
)
data = transform_openai_input_gemini_embed_content(
input=input,
model=model,
optional_params=optional_params or {},
resolved_files=resolved_files,
)
else:
data = transform_openai_input_gemini_content(
input=input, model=model, optional_params=optional_params or {}
)
## LOGGING
if logging_obj is not None:
logging_obj.pre_call(
input=input,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": url,
"headers": headers,
},
)
response = await async_handler.post(
url=url,
headers=headers,
@@ -181,11 +327,19 @@ class GoogleBatchEmbeddings(VertexLLM):
raise Exception(f"Error: {response.status_code} {response.text}")
_json_response = response.json()
_predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response) # type: ignore
return process_response(
model=model,
model_response=model_response,
_predictions=_predictions,
input=input,
)
if use_embed_content:
return process_embed_content_response(
input=input,
model_response=model_response,
model=model,
response_json=_json_response,
)
else:
_predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response) # type: ignore
return process_response(
model=model,
model_response=model_response,
_predictions=_predictions,
input=input,
)
@@ -4,20 +4,142 @@ Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batc
Why separate file? Make it easy to see how transformation works
"""
from typing import List
from typing import Dict, List, Optional, Tuple
from litellm.types.utils import EmbeddingResponse
from litellm.types.llms.openai import EmbeddingInput
from litellm.types.llms.vertex_ai import (
BlobType,
ContentType,
EmbedContentRequest,
FileDataType,
PartType,
VertexAIBatchEmbeddingsRequestBody,
VertexAIBatchEmbeddingsResponseObject,
)
from litellm.types.utils import Embedding, Usage
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
from litellm.utils import get_formatted_prompt, token_counter
SUPPORTED_EMBEDDING_MIME_TYPES = {
"image/png",
"image/jpeg",
"audio/mpeg",
"audio/wav",
"video/mp4",
"video/quicktime",
"application/pdf",
}
def _is_file_reference(s: str) -> bool:
"""Check if string is a Gemini file reference (files/...)."""
return isinstance(s, str) and s.startswith("files/")
def _is_gcs_url(s: str) -> bool:
"""Check if string is a GCS URL (gs://...)."""
return isinstance(s, str) and s.startswith("gs://")
def _infer_mime_type_from_gcs_url(gcs_url: str) -> str:
"""
Infer MIME type from GCS URL file extension.
Args:
gcs_url: GCS URL like gs://bucket/path/to/file.png
Returns:
str: Inferred MIME type
Raises:
ValueError: If file extension is not supported
"""
extension_to_mime = {
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".mp3": "audio/mpeg",
".wav": "audio/wav",
".mp4": "video/mp4",
".mov": "video/quicktime",
".pdf": "application/pdf",
}
gcs_url_lower = gcs_url.lower()
for ext, mime_type in extension_to_mime.items():
if gcs_url_lower.endswith(ext):
return mime_type
raise ValueError(
f"Unable to infer MIME type from GCS URL: {gcs_url}. "
f"Supported extensions: {', '.join(extension_to_mime.keys())}"
)
def _parse_data_url(data_url: str) -> Tuple[str, str]:
"""
Parse a data URL to extract the media type and base64 data.
Args:
data_url: Data URL in format: data:image/jpeg;base64,/9j/4AAQ...
Returns:
tuple: (media_type, base64_data)
media_type: e.g., "image/jpeg", "video/mp4", "audio/mpeg"
base64_data: The base64-encoded data without the prefix
Raises:
ValueError: If data URL format is invalid or MIME type is unsupported
"""
if not data_url.startswith("data:"):
raise ValueError(f"Invalid data URL format: {data_url[:50]}...")
if "," not in data_url:
raise ValueError(f"Invalid data URL format (missing comma): {data_url[:50]}...")
metadata, base64_data = data_url.split(",", 1)
metadata = metadata[5:]
if ";" in metadata:
media_type = metadata.split(";")[0]
else:
media_type = metadata
if media_type not in SUPPORTED_EMBEDDING_MIME_TYPES:
raise ValueError(
f"Unsupported MIME type for embedding: {media_type}. "
f"Supported types: {', '.join(sorted(SUPPORTED_EMBEDDING_MIME_TYPES))}"
)
return media_type, base64_data
def _is_multimodal_input(input: EmbeddingInput) -> bool:
"""
Check if the input contains multimodal data (data URIs, file references, or GCS URLs).
Args:
input: EmbeddingInput (str or List[str])
Returns:
bool: True if any element is a data URI, file reference, or GCS URL
"""
if isinstance(input, str):
input_list = [input]
else:
input_list = input
for element in input_list:
if isinstance(element, str):
if element.startswith("data:") and ";base64," in element:
return True
if _is_file_reference(element):
return True
if _is_gcs_url(element):
return True
return False
def transform_openai_input_gemini_content(
input: EmbeddingInput, model: str, optional_params: dict
@@ -26,12 +148,17 @@ def transform_openai_input_gemini_content(
The content to embed. Only the parts.text fields will be counted.
"""
gemini_model_name = "models/{}".format(model)
gemini_params = optional_params.copy()
if "dimensions" in gemini_params:
gemini_params["outputDimensionality"] = gemini_params.pop("dimensions")
requests: List[EmbedContentRequest] = []
if isinstance(input, str):
request = EmbedContentRequest(
model=gemini_model_name,
content=ContentType(parts=[PartType(text=input)]),
**optional_params
**gemini_params
)
requests.append(request)
else:
@@ -39,13 +166,119 @@ def transform_openai_input_gemini_content(
request = EmbedContentRequest(
model=gemini_model_name,
content=ContentType(parts=[PartType(text=i)]),
**optional_params
**gemini_params
)
requests.append(request)
return VertexAIBatchEmbeddingsRequestBody(requests=requests)
def transform_openai_input_gemini_embed_content(
input: EmbeddingInput,
model: str,
optional_params: dict,
resolved_files: Optional[Dict[str, Dict[str, str]]] = None,
) -> dict:
"""
Transform OpenAI embedding input to Gemini embedContent format (multimodal).
Args:
input: EmbeddingInput (str or List[str]) with text, data URIs, or file references
model: Model name
optional_params: Additional parameters (taskType, outputDimensionality, etc.)
resolved_files: Dict mapping file names (files/abc) to {mime_type, uri}
Returns:
dict: Gemini embedContent request body with content.parts
"""
resolved_files = resolved_files or {}
gemini_params = optional_params.copy()
if "dimensions" in gemini_params:
gemini_params["outputDimensionality"] = gemini_params.pop("dimensions")
input_list = [input] if isinstance(input, str) else input
parts: List[PartType] = []
for element in input_list:
if not isinstance(element, str):
raise ValueError(f"Unsupported input type: {type(element)}")
if element.startswith("data:") and ";base64," in element:
mime_type, base64_data = _parse_data_url(element)
blob: BlobType = {"mime_type": mime_type, "data": base64_data}
parts.append(PartType(inline_data=blob))
elif _is_gcs_url(element):
mime_type = _infer_mime_type_from_gcs_url(element)
file_data: FileDataType = {
"mime_type": mime_type,
"file_uri": element,
}
parts.append(PartType(file_data=file_data))
elif _is_file_reference(element):
if element not in resolved_files:
raise ValueError(f"File reference {element} not resolved")
file_info = resolved_files[element]
file_data_ref: FileDataType = {
"mime_type": file_info["mime_type"],
"file_uri": file_info["uri"],
}
parts.append(PartType(file_data=file_data_ref))
else:
parts.append(PartType(text=element))
request_body: dict = {
"content": ContentType(parts=parts),
**gemini_params,
}
return request_body
def process_embed_content_response(
input: EmbeddingInput,
model_response: EmbeddingResponse,
model: str,
response_json: dict,
) -> EmbeddingResponse:
"""
Process Gemini embedContent response (single embedding for multimodal input).
Args:
input: Original input
model_response: EmbeddingResponse to populate
model: Model name
response_json: Raw JSON response from embedContent endpoint
Returns:
EmbeddingResponse with single embedding
"""
if "embedding" not in response_json:
raise ValueError(f"embedContent response missing 'embedding' field: {response_json}")
embedding_data = response_json["embedding"]
openai_embedding = Embedding(
embedding=embedding_data["values"],
index=0,
object="embedding",
)
model_response.data = [openai_embedding]
model_response.model = model
if _is_multimodal_input(input):
prompt_tokens = 0
else:
input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
prompt_tokens = token_counter(model=model, text=input_text)
model_response.usage = Usage(
prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
)
return model_response
def process_response(
input: EmbeddingInput,
model_response: EmbeddingResponse,
+132 -2
View File
@@ -132,6 +132,7 @@ from litellm.utils import (
create_tokenizer,
get_api_key,
get_llm_provider,
get_model_info,
get_non_default_completion_params,
get_non_default_transcription_params,
get_optional_params_embeddings,
@@ -5194,13 +5195,37 @@ def embedding( # noqa: PLR0915
or get_secret_str("VERTEX_API_BASE")
)
if (
try:
model_info = get_model_info(model=model, custom_llm_provider="vertex_ai")
uses_embed_content = model_info.get("uses_embed_content", False)
except Exception:
uses_embed_content = False
if uses_embed_content:
response = google_batch_embeddings.batch_embeddings( # type: ignore
model=model,
input=input,
encoding=_get_encoding(),
logging_obj=logging,
optional_params=optional_params,
model_response=EmbeddingResponse(),
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
aembedding=aembedding,
print_verbose=print_verbose,
custom_llm_provider="vertex_ai",
api_key=None,
api_base=api_base,
client=client,
extra_headers=headers,
)
elif (
"image" in optional_params
or "video" in optional_params
or model
in vertex_multimodal_embedding.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
):
# multimodal embedding is supported on vertex httpx
response = vertex_multimodal_embedding.multimodal_embedding(
model=model,
input=input,
@@ -7575,6 +7600,111 @@ def stream_chunk_builder( # noqa: PLR0915
)
########## Token Counting API ##########
async def acount_tokens(
model: str,
messages: Optional[List[Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
system: Optional[str] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> "TokenCountResponse":
"""
Count tokens for a given model and messages using provider-specific APIs.
Routes to the appropriate provider's token counting API (OpenAI, Anthropic, etc.)
for exact token counts. Falls back to local tiktoken-based counting for unsupported providers.
Args:
model: The model identifier (e.g., "openai/gpt-4o", "anthropic/claude-3-5-sonnet-20241022")
messages: The messages to count tokens for (standard chat format)
tools: Optional tools/functions to include in token count
system: Optional system message/instructions
api_key: Optional API key (falls back to environment variable)
api_base: Optional custom API base URL
Returns:
TokenCountResponse with total_tokens and metadata
"""
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.types.utils import LlmProviders, TokenCountResponse
from litellm.utils import ProviderConfigManager
# Determine provider from model string
resolved_model, custom_llm_provider, dynamic_api_key, dynamic_api_base = (
get_llm_provider(
model=model,
api_base=api_base,
api_key=api_key,
)
)
# Use dynamic key/base if not explicitly provided
if api_key is None:
api_key = dynamic_api_key
if api_base is None:
api_base = dynamic_api_base
# Build deployment dict for the token counter
deployment: Dict[str, Any] = {
"litellm_params": {
"model": model,
"api_key": api_key,
"api_base": api_base,
}
}
# Try to get provider-specific token counter
try:
llm_provider_enum = LlmProviders(custom_llm_provider)
provider_model_info = ProviderConfigManager.get_provider_model_info(
model=model, provider=llm_provider_enum
)
if provider_model_info is not None:
token_counter_instance = provider_model_info.get_token_counter()
if (
token_counter_instance is not None
and token_counter_instance.should_use_token_counting_api(
custom_llm_provider
)
):
result = await token_counter_instance.count_tokens(
model_to_use=resolved_model,
messages=messages,
contents=None,
deployment=deployment,
request_model=model,
tools=tools,
system=system,
)
if result is not None and not result.error:
return result
except Exception as e:
verbose_logger.debug(
f"Provider token counting failed for model={model}, falling back to local: {e}"
)
# Fallback to local tiktoken-based token counting
fallback_messages = messages or []
if system and fallback_messages:
fallback_messages = [{"role": "system", "content": system}] + fallback_messages
local_count = litellm.token_counter(
model=model,
messages=fallback_messages,
tools=tools,
)
return TokenCountResponse(
total_tokens=local_count,
request_model=model,
model_used=resolved_model,
tokenizer_type="local_tokenizer",
)
# Cache for encoding to avoid repeated __getattr__ calls
_encoding_cache: Optional[Any] = None
File diff suppressed because it is too large Load Diff
+97 -1
View File
@@ -104,6 +104,50 @@ def encrypt_credentials(
value=client_secret,
new_encryption_key=encryption_key,
)
# AWS SigV4 credential fields
aws_access_key_id = credentials.get("aws_access_key_id")
if aws_access_key_id is not None:
credentials["aws_access_key_id"] = encrypt_value_helper(
value=aws_access_key_id,
new_encryption_key=encryption_key,
)
aws_secret_access_key = credentials.get("aws_secret_access_key")
if aws_secret_access_key is not None:
credentials["aws_secret_access_key"] = encrypt_value_helper(
value=aws_secret_access_key,
new_encryption_key=encryption_key,
)
aws_session_token = credentials.get("aws_session_token")
if aws_session_token is not None:
credentials["aws_session_token"] = encrypt_value_helper(
value=aws_session_token,
new_encryption_key=encryption_key,
)
# aws_region_name and aws_service_name are NOT secrets — stored as-is
return credentials
def decrypt_credentials(
credentials: MCPCredentials,
) -> MCPCredentials:
"""Decrypt all secret fields in an MCPCredentials dict using the global salt key."""
secret_fields = [
"auth_value",
"client_id",
"client_secret",
"aws_access_key_id",
"aws_secret_access_key",
"aws_session_token",
]
for field in secret_fields:
value = credentials.get(field)
if value is not None:
credentials[field] = decrypt_value_helper(
value=value,
key=field,
exception_type="debug",
return_original_value=True,
)
return credentials
@@ -354,9 +398,57 @@ async def update_mcp_server(
"""
Update a new mcp server record in the db
"""
import json
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
# Use helper to prepare data with proper JSON serialization
data_dict = _prepare_mcp_server_data(data)
# Pre-fetch existing record once if we need it for auth_type or credential logic
existing = None
has_credentials = "credentials" in data_dict and data_dict["credentials"] is not None
if data.auth_type or has_credentials:
existing = await prisma_client.db.litellm_mcpservertable.find_unique(
where={"server_id": data.server_id}
)
# Clear stale credentials when auth_type changes but no new credentials provided
if (
data.auth_type
and "credentials" not in data_dict
and existing
and existing.auth_type is not None
and existing.auth_type != data.auth_type
):
data_dict["credentials"] = None
# Merge credentials: preserve existing fields not present in the update.
# Without this, a partial credential update (e.g. changing only region)
# would wipe encrypted secrets that the UI cannot display back.
if "credentials" in data_dict and data_dict["credentials"] is not None:
if existing and existing.credentials:
# Only merge when auth_type is unchanged. Switching auth types
# (e.g. oauth2 → api_key) should replace credentials entirely
# to avoid stale secrets from the previous auth type lingering.
auth_type_unchanged = (
data.auth_type is None or data.auth_type == existing.auth_type
)
if auth_type_unchanged:
existing_creds = (
json.loads(existing.credentials)
if isinstance(existing.credentials, str)
else dict(existing.credentials)
)
new_creds = (
json.loads(data_dict["credentials"])
if isinstance(data_dict["credentials"], str)
else dict(data_dict["credentials"])
)
# New values override existing; existing keys not in update are preserved
merged = {**existing_creds, **new_creds}
data_dict["credentials"] = safe_dumps(merged)
# Add audit fields
data_dict["updated_by"] = touched_by
@@ -378,8 +470,12 @@ async def rotate_mcp_server_credentials_master_key(
continue
credentials_copy = dict(credentials)
encrypted_credentials = encrypt_credentials(
# Decrypt with current key first, then re-encrypt with new key
decrypted_credentials = decrypt_credentials(
credentials=cast(MCPCredentials, credentials_copy),
)
encrypted_credentials = encrypt_credentials(
credentials=decrypted_credentials,
encryption_key=new_master_key,
)
@@ -597,9 +597,10 @@ class MCPServerManager:
else:
client_secret_value = encrypted_client_secret
# TODO: Add AWS SigV4 credential decryption here when DB-stored
# SigV4 MCP servers are supported. Requires corresponding changes
# to encrypt_credentials() in db.py and MCPCredentials TypedDict.
# AWS SigV4 credential fields
aws_creds = self._extract_aws_credentials(
credentials_dict, credentials_are_encrypted
)
scopes: Optional[List[str]] = None
if credentials_dict:
@@ -679,6 +680,12 @@ class MCPServerManager:
is_byok=bool(getattr(mcp_server, "is_byok", False)),
byok_description=getattr(mcp_server, "byok_description", None) or [],
byok_api_key_help_url=getattr(mcp_server, "byok_api_key_help_url", None),
# AWS SigV4 fields
aws_access_key_id=aws_creds.get("aws_access_key_id"),
aws_secret_access_key=aws_creds.get("aws_secret_access_key"),
aws_session_token=aws_creds.get("aws_session_token"),
aws_region_name=aws_creds.get("aws_region_name"),
aws_service_name=aws_creds.get("aws_service_name"),
)
return new_server
@@ -1520,6 +1527,52 @@ class MCPServerManager:
return None
@staticmethod
def _decrypt_credential_field(
encrypted_value: Optional[str],
key: str,
credentials_are_encrypted: bool,
) -> Optional[str]:
"""Decrypt a single credential field, or return as-is if not encrypted."""
if not encrypted_value:
return None
if credentials_are_encrypted:
return decrypt_value_helper(
value=encrypted_value,
key=key,
exception_type="debug",
return_original_value=True,
)
return encrypted_value
def _extract_aws_credentials(
self,
credentials_dict: Optional[Dict[str, str]],
credentials_are_encrypted: bool,
) -> Dict[str, Optional[str]]:
"""Extract and decrypt AWS SigV4 credential fields from credentials dict."""
if not credentials_dict:
return {}
return {
"aws_access_key_id": self._decrypt_credential_field(
credentials_dict.get("aws_access_key_id"),
"aws_access_key_id",
credentials_are_encrypted,
),
"aws_secret_access_key": self._decrypt_credential_field(
credentials_dict.get("aws_secret_access_key"),
"aws_secret_access_key",
credentials_are_encrypted,
),
"aws_session_token": self._decrypt_credential_field(
credentials_dict.get("aws_session_token"),
"aws_session_token",
credentials_are_encrypted,
),
"aws_region_name": credentials_dict.get("aws_region_name"),
"aws_service_name": credentials_dict.get("aws_service_name"),
}
def _extract_scopes(self, scopes_value: Any) -> Optional[List[str]]:
if isinstance(scopes_value, str):
scopes = [s.strip() for s in scopes_value.split() if s.strip()]
@@ -93,25 +93,7 @@ def get_base_url(spec: Dict[str, Any], spec_path: Optional[str] = None) -> str:
"""Extract base URL from OpenAPI spec."""
# OpenAPI 3.x
if "servers" in spec and spec["servers"]:
server_url = spec["servers"][0]["url"]
# If the server URL is relative (starts with /), derive base from spec_path
if server_url.startswith("/") and spec_path:
if spec_path.startswith("http://") or spec_path.startswith("https://"):
# Extract base URL from spec_path (e.g., https://petstore3.swagger.io/api/v3/openapi.json)
# Combine domain with the relative server URL
from urllib.parse import urlparse
parsed = urlparse(spec_path)
base_domain = f"{parsed.scheme}://{parsed.netloc}"
full_base_url = base_domain + server_url
verbose_logger.info(
f"OpenAPI spec has relative server URL '{server_url}'. "
f"Deriving base from spec_path: {full_base_url}"
)
return full_base_url
return server_url
return spec["servers"][0]["url"]
# OpenAPI 2.x (Swagger)
elif "host" in spec:
scheme = spec.get("schemes", ["https"])[0]
@@ -718,7 +718,6 @@ if MCP_AVAILABLE:
Checks both the full tool name and unprefixed version (without server prefix).
This allows users to configure simple tool names regardless of prefixing.
Comparison is case-insensitive to handle OpenAPI operationIds that may be in camelCase.
Args:
tool_name: The tool name to check (may be prefixed like "server-tool_name")
@@ -731,15 +730,13 @@ if MCP_AVAILABLE:
split_server_prefix_from_name,
)
# Normalize filter list to lowercase for case-insensitive comparison
filter_list_lower = [f.lower() for f in filter_list]
if tool_name.lower() in filter_list_lower:
# Check if the full name is in the list
if tool_name in filter_list:
return True
# Check if the unprefixed name is in the list (case-insensitive)
# Check if the unprefixed name is in the list
unprefixed_name, _ = split_server_prefix_from_name(tool_name)
return unprefixed_name.lower() in filter_list_lower
return unprefixed_name in filter_list
def filter_tools_by_allowed_tools(
tools: List[MCPTool],

Some files were not shown because too many files have changed in this diff Show More