mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 12:48:57 +00:00
merge: resolve conflicts between main and litellm_oss_staging_03_11_2026
This commit is contained in:
@@ -1,12 +1,19 @@
|
||||
name: "LiteLLM CodeQL config"
|
||||
|
||||
# Exclude queries that produce result sets > 2 GiB on this codebase,
|
||||
# causing 49+ minute runs that fail and block CI resources.
|
||||
# Use security-extended suite instead of security-and-quality to avoid
|
||||
# result sets > 2 GiB on this codebase that cause fatal OOM failures.
|
||||
queries:
|
||||
- uses: security-extended
|
||||
|
||||
# These two queries are security queries included in security-extended that
|
||||
# individually produce result sets > 2 GiB on this codebase, causing fatal
|
||||
# OOM failures. Exclude them as a safety net until CI confirms they no longer
|
||||
# OOM; drop these exclusions in a follow-up once verified.
|
||||
query-filters:
|
||||
- exclude:
|
||||
id: py/clear-text-logging-sensitive-data # CWE-312/CleartextLogging.ql — result set > 2 GiB
|
||||
id: py/clear-text-logging-sensitive-data # CWE-312 — > 2 GiB result set
|
||||
- exclude:
|
||||
id: py/polynomial-redos # CWE-730/PolynomialReDoS.ql — result set > 2 GiB
|
||||
id: py/polynomial-redos # CWE-730 — > 2 GiB result set
|
||||
|
||||
paths-ignore:
|
||||
- tests
|
||||
|
||||
@@ -110,6 +110,13 @@ LiteLLM is a unified interface for 100+ LLM providers with two main components:
|
||||
### Proxy database access
|
||||
- **Do not write raw SQL** for proxy DB operations. Use Prisma model methods instead of `execute_raw` / `query_raw`.
|
||||
- Use the generated client: `prisma_client.db.<model>` (e.g. `litellm_tooltable`, `litellm_usertable`) with `.upsert()`, `.find_many()`, `.find_unique()`, `.update()`, `.update_many()` as appropriate. This avoids schema/client drift, keeps code testable with simple mocks, and matches patterns used in spend logs and other proxy code.
|
||||
- **No N+1 queries.** Never query the DB inside a loop. Batch-fetch with `{"in": ids}` and distribute in-memory.
|
||||
- **Batch writes.** Use `create_many`/`update_many`/`delete_many` instead of individual calls (these return counts only; `update_many`/`delete_many` no-op silently on missing rows). When multiple separate writes target the same table (e.g. in `batch_()`), order by primary key to avoid deadlocks.
|
||||
- **Push work to the DB.** Filter, sort, group, and aggregate in SQL, not Python. Verify Prisma generates the expected SQL — e.g. prefer `group_by` over `find_many(distinct=...)` which does client-side processing.
|
||||
- **Bound large result sets.** Prisma materializes full results in memory. For results over ~10 MB, paginate with `take`/`skip` or `cursor`/`take`, always with an explicit `order`. Prefer cursor-based pagination (`skip` is O(n)). Don't paginate naturally small result sets.
|
||||
- **Limit fetched columns on wide tables.** Use `select` to fetch only needed fields — returns a partial object, so downstream code must not access unselected fields.
|
||||
- **Check index coverage.** For new or modified queries, check `schema.prisma` for a supporting index. Prefer extending an existing index (e.g. `@@index([a])` → `@@index([a, b])`) over adding a new one, unless it's a `@@unique`. Only add indexes for large/frequent queries.
|
||||
- **Keep schema files in sync.** Apply schema changes to all `schema.prisma` copies (`schema.prisma`, `litellm/proxy/`, `litellm-proxy-extras/`, `litellm-js/spend-logs/` for SpendLogs) with a migration under `litellm-proxy-extras/litellm_proxy_extras/migrations/`.
|
||||
|
||||
### Enterprise Features
|
||||
- Enterprise-specific code in `enterprise/` directory
|
||||
|
||||
@@ -61,6 +61,20 @@ Create the name of the service account to use
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
|
||||
{{/*
|
||||
Create the service account name used by migration jobs.
|
||||
When Helm hooks are enabled, pre-install/pre-upgrade hooks run before normal resources.
|
||||
If this chart is creating the ServiceAccount, it is not yet available for the hook job,
|
||||
so fall back to "default" (or an explicit override) to avoid a cyclic dependency.
|
||||
*/}}
|
||||
{{- define "litellm.migrationServiceAccountName" -}}
|
||||
{{- if and .Values.migrationJob.hooks.helm.enabled .Values.serviceAccount.create }}
|
||||
{{- default "default" .Values.migrationJob.serviceAccountName }}
|
||||
{{- else }}
|
||||
{{- include "litellm.serviceAccountName" . }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
|
||||
{{/*
|
||||
Get redis service name
|
||||
*/}}
|
||||
|
||||
@@ -34,7 +34,7 @@ spec:
|
||||
imagePullSecrets:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
serviceAccountName: {{ include "litellm.serviceAccountName" . }}
|
||||
serviceAccountName: {{ include "litellm.migrationServiceAccountName" . }}
|
||||
{{- with .Values.migrationJob.extraInitContainers }}
|
||||
initContainers:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
|
||||
@@ -124,4 +124,67 @@ tests:
|
||||
- notContains:
|
||||
path: spec.template.spec.containers[0].env
|
||||
content:
|
||||
name: DATABASE_URL
|
||||
name: DATABASE_URL
|
||||
|
||||
- it: should use default service account for helm hooks when serviceAccount.create is true
|
||||
template: migrations-job.yaml
|
||||
set:
|
||||
migrationJob:
|
||||
enabled: true
|
||||
hooks:
|
||||
helm:
|
||||
enabled: true
|
||||
serviceAccount:
|
||||
create: true
|
||||
asserts:
|
||||
- equal:
|
||||
path: spec.template.spec.serviceAccountName
|
||||
value: default
|
||||
|
||||
- it: should use migrationJob.serviceAccountName override for helm hooks when serviceAccount.create is true
|
||||
template: migrations-job.yaml
|
||||
set:
|
||||
migrationJob:
|
||||
enabled: true
|
||||
serviceAccountName: migration-sa
|
||||
hooks:
|
||||
helm:
|
||||
enabled: true
|
||||
serviceAccount:
|
||||
create: true
|
||||
asserts:
|
||||
- equal:
|
||||
path: spec.template.spec.serviceAccountName
|
||||
value: migration-sa
|
||||
|
||||
- it: should use chart service account when helm hooks are disabled
|
||||
template: migrations-job.yaml
|
||||
set:
|
||||
migrationJob:
|
||||
enabled: true
|
||||
hooks:
|
||||
helm:
|
||||
enabled: false
|
||||
serviceAccount:
|
||||
create: true
|
||||
name: my-custom-sa
|
||||
asserts:
|
||||
- equal:
|
||||
path: spec.template.spec.serviceAccountName
|
||||
value: my-custom-sa
|
||||
|
||||
- it: should use pre-existing service account when helm hooks are enabled but serviceAccount.create is false
|
||||
template: migrations-job.yaml
|
||||
set:
|
||||
migrationJob:
|
||||
enabled: true
|
||||
hooks:
|
||||
helm:
|
||||
enabled: true
|
||||
serviceAccount:
|
||||
create: false
|
||||
name: pre-existing-sa
|
||||
asserts:
|
||||
- equal:
|
||||
path: spec.template.spec.serviceAccountName
|
||||
value: pre-existing-sa
|
||||
|
||||
@@ -309,6 +309,10 @@ migrationJob:
|
||||
retries: 3 # Number of retries for the Job in case of failure
|
||||
backoffLimit: 4 # Backoff limit for Job restarts
|
||||
disableSchemaUpdate: false # Skip schema migrations for specific environments. When True, the job will exit with code 0.
|
||||
# Optional service account for the migration job.
|
||||
# Only used when migrationJob.hooks.helm.enabled=true and serviceAccount.create=true.
|
||||
# In that case, pre-install/pre-upgrade hooks run before normal resources, so this defaults to "default".
|
||||
serviceAccountName: ""
|
||||
annotations: {}
|
||||
ttlSecondsAfterFinished: 120
|
||||
resources: {}
|
||||
|
||||
@@ -0,0 +1,169 @@
|
||||
---
|
||||
slug: gemini_embedding_2_multimodal
|
||||
title: "Gemini Embedding 2 Preview: Multimodal Embeddings on LiteLLM"
|
||||
date: 2025-03-11T10:00:00
|
||||
authors:
|
||||
- name: Sameer Kankute
|
||||
title: SWE @ LiteLLM (LLM Translation)
|
||||
url: https://www.linkedin.com/in/sameer-kankute/
|
||||
image_url: https://pbs.twimg.com/profile_images/2001352686994907136/ONgNuSk5_400x400.jpg
|
||||
description: "Generate embeddings from text, images, audio, video, and PDFs with gemini-embedding-2-preview on LiteLLM via Gemini API and Vertex AI."
|
||||
tags: [gemini, embeddings, multimodal, vertex ai]
|
||||
hide_table_of_contents: false
|
||||
---
|
||||
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Gemini Embedding 2 Preview: Multimodal Embeddings
|
||||
|
||||
LiteLLM now supports **multimodal embeddings** with `gemini-embedding-2-preview`—generating a single embedding from a mix of text, images, audio, video, and PDF content. Available via both the **Gemini API** (API key) and **Vertex AI** (GCP credentials).
|
||||
|
||||
## Supported Input Types
|
||||
|
||||
| Modality | Supported Formats |
|
||||
|----------|-------------------|
|
||||
| **Text** | Plain text |
|
||||
| **Image** | PNG, JPEG |
|
||||
| **Audio** | MP3, WAV |
|
||||
| **Video** | MP4, MOV |
|
||||
| **Documents** | PDF |
|
||||
|
||||
## Input Formats
|
||||
|
||||
LiteLLM accepts three input formats for multimodal content:
|
||||
|
||||
1. **Data URIs** – Base64-encoded inline: `data:image/png;base64,<encoded_data>`
|
||||
2. **GCS URLs** – Cloud Storage paths (Vertex AI): `gs://bucket/path/to/file.png`
|
||||
3. **Gemini File References** – Pre-uploaded files (Gemini API): `files/abc123`
|
||||
|
||||
## Quick Start
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="gemini" label="Gemini API">
|
||||
|
||||
```python
|
||||
from litellm import embedding
|
||||
import os
|
||||
|
||||
os.environ["GEMINI_API_KEY"] = "your-api-key"
|
||||
|
||||
# Text + Image (base64)
|
||||
response = embedding(
|
||||
model="gemini/gemini-embedding-2-preview",
|
||||
input=[
|
||||
"The food was delicious and the waiter...",
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAIAQMAAAD+wSzIAAAABlBMVEX///+/v7+jQ3Y5AAAADklEQVQI12P4AIX8EAgALgAD/aNpbtEAAAAASUVORK5CYII"
|
||||
],
|
||||
)
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="vertex" label="Vertex AI">
|
||||
|
||||
```python
|
||||
import litellm
|
||||
from litellm import embedding
|
||||
|
||||
litellm.vertex_project = "your-project-id"
|
||||
litellm.vertex_location = "us-central1"
|
||||
|
||||
# Text + Image (GCS URL)
|
||||
response = embedding(
|
||||
model="vertex_ai/gemini-embedding-2-preview",
|
||||
input=[
|
||||
"Describe this image",
|
||||
"gs://my-bucket/images/photo.png"
|
||||
],
|
||||
)
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="proxy" label="LiteLLM Proxy">
|
||||
|
||||
**1. Config (config.yaml)**
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gemini-embedding-2-preview
|
||||
litellm_params:
|
||||
model: gemini/gemini-embedding-2-preview
|
||||
api_key: os.environ/GEMINI_API_KEY
|
||||
- model_name: vertex-gemini-embedding-2-preview
|
||||
litellm_params:
|
||||
model: vertex_ai/gemini-embedding-2-preview
|
||||
vertex_project: os.environ/VERTEXAI_PROJECT
|
||||
vertex_location: os.environ/VERTEXAI_LOCATION
|
||||
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
```
|
||||
|
||||
**2. Start proxy**
|
||||
|
||||
```bash
|
||||
litellm --config config.yaml
|
||||
```
|
||||
|
||||
**3. Call embeddings**
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:4000/embeddings \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gemini-embedding-2-preview",
|
||||
"input": [
|
||||
"The food was delicious and the waiter...",
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAIAQMAAAD+wSzIAAAABlBMVEX///+/v7+jQ3Y5AAAADklEQVQI12P4AIX8EAgALgAD/aNpbtEAAAAASUVORK5CYII"
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## Input Format Examples
|
||||
|
||||
| Format | Example | Provider |
|
||||
|--------|---------|----------|
|
||||
| **Data URI** | `data:image/png;base64,...` | Gemini, Vertex AI |
|
||||
| **GCS URL** | `gs://bucket/path/image.png` | Vertex AI |
|
||||
| **File reference** | `files/abc123` | Gemini API only |
|
||||
|
||||
### Supported MIME Types for Data URIs
|
||||
|
||||
- **Images:** `image/png`, `image/jpeg`
|
||||
- **Audio:** `audio/mpeg`, `audio/wav`
|
||||
- **Video:** `video/mp4`, `video/quicktime`
|
||||
- **Documents:** `application/pdf`
|
||||
|
||||
### GCS URL MIME Inference
|
||||
|
||||
For Vertex AI, MIME types are inferred from file extensions:
|
||||
|
||||
- `.png` → `image/png`
|
||||
- `.jpg` / `.jpeg` → `image/jpeg`
|
||||
- `.mp3` → `audio/mpeg`
|
||||
- `.wav` → `audio/wav`
|
||||
- `.mp4` → `video/mp4`
|
||||
- `.mov` → `video/quicktime`
|
||||
- `.pdf` → `application/pdf`
|
||||
|
||||
## Optional Parameters
|
||||
|
||||
| Parameter | Description | Maps to |
|
||||
|-----------|-------------|---------|
|
||||
| `dimensions` | Output embedding size | `outputDimensionality` |
|
||||
|
||||
```python
|
||||
response = embedding(
|
||||
model="gemini/gemini-embedding-2-preview",
|
||||
input=["text to embed"],
|
||||
dimensions=768, # Optional: control output vector size
|
||||
)
|
||||
```
|
||||
@@ -138,6 +138,7 @@ The `/v1/messages/count_tokens` endpoint automatically routes to the appropriate
|
||||
| Provider | Token Counting Method |
|
||||
|----------|----------------------|
|
||||
| Anthropic | [Anthropic Token Counting API](https://docs.anthropic.com/en/docs/build-with-claude/token-counting) |
|
||||
| OpenAI | [OpenAI Responses API `/input_tokens`](https://platform.openai.com/docs/api-reference/responses/input-tokens) — see [Token Counting](./count_tokens.md) |
|
||||
| Vertex AI (Claude) | Vertex AI Partner Models Token Counter |
|
||||
| Bedrock (Claude) | AWS Bedrock CountTokens API |
|
||||
| Gemini | Google AI Studio countTokens API |
|
||||
|
||||
@@ -11,6 +11,7 @@ This endpoint supports various guardrail types including:
|
||||
- **Presidio** - PII detection and masking
|
||||
- **Bedrock** - AWS Bedrock guardrails for content moderation
|
||||
- **Lakera** - AI safety guardrails
|
||||
- **PANW Prisma AIRS** - Threat detection, DLP, and policy enforcement
|
||||
- **Custom guardrails** - User-defined guardrails
|
||||
|
||||
## Configuration
|
||||
|
||||
@@ -13,7 +13,7 @@ import TabItem from '@theme/TabItem';
|
||||
| Fallbacks | ✅ | Works between supported models |
|
||||
| Loadbalancing | ✅ | Works between supported models |
|
||||
| Guardrails | ✅ | Applies to output transcribed text (non-streaming only) |
|
||||
| Supported Providers | `openai`, `azure`, `vertex_ai`, `gemini`, `deepgram`, `groq`, `fireworks_ai`, `ovhcloud` | |
|
||||
| Supported Providers | `openai`, `azure`, `vertex_ai`, `gemini`, `deepgram`, `groq`, `fireworks_ai`, `ovhcloud`, `mistral` | |
|
||||
|
||||
## Quick Start
|
||||
|
||||
@@ -126,6 +126,7 @@ transcript = client.audio.transcriptions.create(
|
||||
- [Fireworks AI](./providers/fireworks_ai.md#audio-transcription)
|
||||
- [Groq](./providers/groq.md#speech-to-text---whisper)
|
||||
- [Deepgram](./providers/deepgram.md)
|
||||
- [Mistral (Voxtral)](./providers/mistral.md#audio-transcription)
|
||||
- [OVHcloud AI Endpoints](./providers/ovhcloud.md)
|
||||
|
||||
---
|
||||
|
||||
@@ -51,6 +51,28 @@ Here's what an example response looks like
|
||||
}
|
||||
```
|
||||
|
||||
## Native Finish Reason
|
||||
|
||||
LiteLLM maps all provider-specific `finish_reason` values to OpenAI-compatible values (`stop`, `length`, `tool_calls`, `function_call`, `content_filter`). When the original provider value differs from the mapped value, it is preserved in `provider_specific_fields["native_finish_reason"]`.
|
||||
|
||||
This is useful for agent loops that need to distinguish between different stop conditions (e.g., Gemini's `MALFORMED_FUNCTION_CALL` vs a normal `stop`).
|
||||
|
||||
```python
|
||||
response = completion(model="gemini/gemini-2.0-flash", messages=messages)
|
||||
|
||||
choice = response.choices[0]
|
||||
print(choice.finish_reason) # "stop" (OpenAI-compatible)
|
||||
|
||||
# Access the original provider value when it differs:
|
||||
if hasattr(choice, "provider_specific_fields") and choice.provider_specific_fields:
|
||||
native = choice.provider_specific_fields.get("native_finish_reason")
|
||||
if native == "MALFORMED_FUNCTION_CALL":
|
||||
# Handle malformed function call differently from a normal stop
|
||||
pass
|
||||
```
|
||||
|
||||
When the provider already returns an OpenAI-compatible value (e.g., `stop`), `native_finish_reason` is not set.
|
||||
|
||||
## Additional Attributes
|
||||
|
||||
You can also access information like latency.
|
||||
|
||||
@@ -115,6 +115,11 @@ print(response)
|
||||
|
||||
Web fetch is available on the following Anthropic API models:
|
||||
|
||||
- `claude-opus-4-6` (Claude Opus 4.6)
|
||||
- `claude-sonnet-4-6` (Claude Sonnet 4.6)
|
||||
- `claude-opus-4-5` (Claude Opus 4.5)
|
||||
- `claude-sonnet-4-5` (Claude Sonnet 4.5)
|
||||
- `claude-haiku-4-5` (Claude Haiku 4.5)
|
||||
- `claude-opus-4-1-20250805` (Claude Opus 4.1)
|
||||
- `claude-opus-4-20250514` (Claude Opus 4)
|
||||
- `claude-sonnet-4-20250514` (Claude Sonnet 4)
|
||||
|
||||
@@ -80,6 +80,36 @@ That's it! The provider is now available.
|
||||
}
|
||||
```
|
||||
|
||||
## Responses API Support
|
||||
|
||||
If your provider also supports the OpenAI Responses API (`/v1/responses`), add `supported_endpoints`:
|
||||
|
||||
```json
|
||||
{
|
||||
"your_provider": {
|
||||
"base_url": "https://api.yourprovider.com/v1",
|
||||
"api_key_env": "YOUR_PROVIDER_API_KEY",
|
||||
"supported_endpoints": ["/v1/chat/completions", "/v1/responses"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This enables `litellm.responses()` with zero additional code:
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
response = litellm.responses(
|
||||
model="your_provider/model-name",
|
||||
input="Hello, what can you do?",
|
||||
)
|
||||
print(response.output)
|
||||
```
|
||||
|
||||
If `supported_endpoints` is omitted, it defaults to `[]`. Chat completions is always enabled for JSON providers regardless of this field.
|
||||
|
||||
The provider inherits all request/response handling from OpenAI's Responses API — streaming, tools, and all standard parameters work out of the box.
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
@@ -89,11 +119,17 @@ import os
|
||||
# Set your API key
|
||||
os.environ["YOUR_PROVIDER_API_KEY"] = "your-key-here"
|
||||
|
||||
# Use the provider
|
||||
# Chat completions
|
||||
response = litellm.completion(
|
||||
model="your_provider/model-name",
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
)
|
||||
|
||||
# Responses API (if supported_endpoints includes "/v1/responses")
|
||||
response = litellm.responses(
|
||||
model="your_provider/model-name",
|
||||
input="Hello",
|
||||
)
|
||||
```
|
||||
|
||||
## When to Use Python Instead
|
||||
@@ -105,7 +141,9 @@ Use a Python config class if you need:
|
||||
- Provider-specific streaming logic
|
||||
- Advanced tool calling modifications
|
||||
|
||||
For these cases, create a config class in `litellm/llms/your_provider/chat/transformation.py` that inherits from `OpenAIGPTConfig` or `OpenAILikeChatConfig`.
|
||||
For chat completions, create a config class in `litellm/llms/your_provider/chat/transformation.py` that inherits from `OpenAIGPTConfig` or `OpenAILikeChatConfig`.
|
||||
|
||||
For responses API with small overrides, inherit from `OpenAIResponsesAPIConfig` and override only what's needed. See `litellm/llms/perplexity/responses/transformation.py` for a minimal example (~40 lines vs 400+).
|
||||
|
||||
## Testing
|
||||
|
||||
|
||||
@@ -0,0 +1,189 @@
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Token Counting
|
||||
|
||||
## Overview
|
||||
|
||||
LiteLLM provides exact token counting by calling provider-specific token counting APIs. This gives you accurate token counts before sending requests, helping with cost estimation and context window management.
|
||||
|
||||
| Feature | Details |
|
||||
|---------|---------|
|
||||
| SDK Method | `litellm.acount_tokens()` |
|
||||
| Proxy Endpoints | `/v1/messages/count_tokens` (Anthropic format), `/v1/responses/input_tokens` (OpenAI format) |
|
||||
| Fallback | Local tiktoken-based counting for unsupported providers |
|
||||
|
||||
## Supported Providers
|
||||
|
||||
| Provider | Token Counting API | Format |
|
||||
|----------|-------------------|--------|
|
||||
| OpenAI | [Responses API `/input_tokens`](https://platform.openai.com/docs/api-reference/responses/input-tokens) | OpenAI Responses |
|
||||
| Anthropic | [Messages `/count_tokens`](https://docs.anthropic.com/en/docs/build-with-claude/token-counting) | Anthropic Messages |
|
||||
| Vertex AI (Claude) | Vertex AI Partner Models Token Counter | Anthropic Messages |
|
||||
| Bedrock (Claude) | AWS Bedrock CountTokens API | Anthropic Messages |
|
||||
| Gemini | Google AI Studio countTokens API | Anthropic Messages |
|
||||
| Vertex AI (Gemini) | Vertex AI countTokens API | Anthropic Messages |
|
||||
| Other providers | Local tiktoken fallback | N/A |
|
||||
|
||||
## SDK Usage
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
import litellm
|
||||
|
||||
async def main():
|
||||
# OpenAI
|
||||
result = await litellm.acount_tokens(
|
||||
model="openai/gpt-4o",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
)
|
||||
print(f"Token count: {result.total_tokens}")
|
||||
print(f"Tokenizer: {result.tokenizer_type}") # "openai_api"
|
||||
|
||||
# Anthropic
|
||||
result = await litellm.acount_tokens(
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
)
|
||||
print(f"Token count: {result.total_tokens}")
|
||||
print(f"Tokenizer: {result.tokenizer_type}") # "anthropic_api"
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
### With Tools and System Message
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
import litellm
|
||||
|
||||
async def main():
|
||||
result = await litellm.acount_tokens(
|
||||
model="openai/gpt-4o",
|
||||
messages=[{"role": "user", "content": "What's the weather in Paris?"}],
|
||||
tools=[{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather for a city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
}],
|
||||
system="You are a helpful weather assistant.",
|
||||
)
|
||||
print(f"Token count (with tools): {result.total_tokens}")
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
### Response Format
|
||||
|
||||
`litellm.acount_tokens()` returns a `TokenCountResponse`:
|
||||
|
||||
```python
|
||||
TokenCountResponse(
|
||||
total_tokens=15, # Token count
|
||||
request_model="openai/gpt-4o", # Model requested
|
||||
model_used="gpt-4o", # Model used for counting
|
||||
tokenizer_type="openai_api", # "openai_api", "anthropic_api", "local_tokenizer"
|
||||
original_response={"input_tokens": 15}, # Raw API response
|
||||
error=False, # True if counting failed
|
||||
error_message=None, # Error details if failed
|
||||
)
|
||||
```
|
||||
|
||||
### Fallback Behavior
|
||||
|
||||
If a provider doesn't support a token counting API, or if the API key is missing, `acount_tokens()` automatically falls back to local tiktoken-based counting:
|
||||
|
||||
```python
|
||||
# Unsupported provider → automatic fallback
|
||||
result = await litellm.acount_tokens(
|
||||
model="together_ai/meta-llama/Llama-3-8b-chat-hf",
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
)
|
||||
print(result.tokenizer_type) # "local_tokenizer"
|
||||
```
|
||||
|
||||
## Proxy Usage
|
||||
|
||||
### OpenAI Format — `/v1/responses/input_tokens`
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="curl" label="curl">
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:4000/v1/responses/input_tokens" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{
|
||||
"model": "gpt-4o",
|
||||
"input": "Hello, how are you?"
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="python" label="Python (httpx)">
|
||||
|
||||
```python
|
||||
import httpx
|
||||
|
||||
response = httpx.post(
|
||||
"http://localhost:4000/v1/responses/input_tokens",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer sk-1234"
|
||||
},
|
||||
json={
|
||||
"model": "gpt-4o",
|
||||
"input": "Hello, how are you?"
|
||||
}
|
||||
)
|
||||
|
||||
print(response.json())
|
||||
# {"input_tokens": 7}
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{"input_tokens": 7}
|
||||
```
|
||||
|
||||
### Anthropic Format — `/v1/messages/count_tokens`
|
||||
|
||||
See [Anthropic Token Counting](./anthropic_count_tokens.md) for full documentation.
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:4000/v1/messages/count_tokens" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{
|
||||
"model": "claude-3-5-sonnet-20241022",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
## Proxy Configuration
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-4o
|
||||
litellm_params:
|
||||
model: openai/gpt-4o
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
- model_name: claude-3-5-sonnet
|
||||
litellm_params:
|
||||
model: anthropic/claude-3-5-sonnet-20241022
|
||||
api_key: os.environ/ANTHROPIC_API_KEY
|
||||
```
|
||||
@@ -514,6 +514,57 @@ All models listed [here](https://ai.google.dev/gemini-api/docs/models/gemini) ar
|
||||
| Model Name | Function Call |
|
||||
| :--- | :--- |
|
||||
| text-embedding-004 | `embedding(model="gemini/text-embedding-004", input)` |
|
||||
| gemini-embedding-2-preview | `embedding(model="gemini/gemini-embedding-2-preview", input)` | [Multimodal docs](#gemini-embedding-2-preview-multimodal) |
|
||||
|
||||
### Gemini Embedding 2 Preview (Multimodal)
|
||||
|
||||
`gemini-embedding-2-preview` supports **multimodal embeddings**—text, images, audio, video, and PDF in a single request. See [blog post](/blog/gemini_embedding_2_multimodal) for details.
|
||||
|
||||
**Input formats:**
|
||||
- **Data URIs:** `data:image/png;base64,<encoded_data>`
|
||||
- **Gemini file references:** `files/abc123` (pre-uploaded via Gemini Files API)
|
||||
|
||||
**Supported MIME types:** `image/png`, `image/jpeg`, `audio/mpeg`, `audio/wav`, `video/mp4`, `video/quicktime`, `application/pdf`
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
from litellm import embedding
|
||||
import os
|
||||
os.environ["GEMINI_API_KEY"] = ""
|
||||
|
||||
# Text + Image (base64)
|
||||
response = embedding(
|
||||
model="gemini/gemini-embedding-2-preview",
|
||||
input=[
|
||||
"The food was delicious and the waiter...",
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAIAQMAAAD+wSzIAAAABlBMVEX///+/v7+jQ3Y5AAAADklEQVQI12P4AIX8EAgALgAD/aNpbtEAAAAASUVORK5CYII"
|
||||
],
|
||||
)
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:4000/embeddings \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gemini-embedding-2-preview",
|
||||
"input": [
|
||||
"The food was delicious and the waiter...",
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAIAQMAAAD+wSzIAAAABlBMVEX///+/v7+jQ3Y5AAAADklEQVQI12P4AIX8EAgALgAD/aNpbtEAAAAASUVORK5CYII"
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
**Optional:** `dimensions` maps to Gemini's `outputDimensionality`.
|
||||
|
||||
|
||||
## Vertex AI Embedding Models
|
||||
|
||||
@@ -16,7 +16,7 @@ LiteLLM provides image editing functionality that maps to OpenAI's `/images/edit
|
||||
| Supported operations | Create image edits | Single and multiple images supported |
|
||||
| Supported LiteLLM SDK Versions | 1.63.8+ | Gemini support requires 1.79.3+ |
|
||||
| Supported LiteLLM Proxy Versions | 1.71.1+ | Gemini support requires 1.79.3+ |
|
||||
| Supported LLM providers | **OpenAI**, **Gemini (Google AI Studio)**, **Vertex AI**, **OpenRouter**, **Stability AI**, **AWS Bedrock (Stability)** | Gemini supports the new `gemini-2.5-flash-image` family. Vertex AI supports both Gemini and Imagen models. OpenRouter routes image edits through chat completions. Stability AI and Bedrock Stability support various image editing operations. |
|
||||
| Supported LLM providers | **OpenAI**, **Gemini (Google AI Studio)**, **Vertex AI**, **OpenRouter**, **Stability AI**, **AWS Bedrock (Stability)**, **Black Forest Labs** | Gemini supports the new `gemini-2.5-flash-image` family. Vertex AI supports both Gemini and Imagen models. OpenRouter routes image edits through chat completions. Stability AI and Bedrock Stability support various image editing operations. Black Forest Labs supports FLUX Kontext models. |
|
||||
|
||||
#### ⚡️See all supported models and providers at [models.litellm.ai](https://models.litellm.ai/)
|
||||
|
||||
@@ -199,6 +199,63 @@ for idx, image_obj in enumerate(response.data):
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="bfl" label="Black Forest Labs">
|
||||
|
||||
#### Basic Image Edit
|
||||
```python showLineNumbers title="Black Forest Labs Image Edit"
|
||||
import os
|
||||
import litellm
|
||||
|
||||
os.environ["BFL_API_KEY"] = "your-api-key"
|
||||
|
||||
response = litellm.image_edit(
|
||||
model="black_forest_labs/flux-kontext-pro",
|
||||
image=open("original_image.png", "rb"),
|
||||
prompt="Add a green leaf to the scene",
|
||||
)
|
||||
|
||||
print(response.data[0].url)
|
||||
```
|
||||
|
||||
#### Inpainting with Mask
|
||||
```python showLineNumbers title="Black Forest Labs Inpainting"
|
||||
import os
|
||||
import litellm
|
||||
|
||||
os.environ["BFL_API_KEY"] = "your-api-key"
|
||||
|
||||
# Use flux-pro-1.0-fill for inpainting
|
||||
response = litellm.image_edit(
|
||||
model="black_forest_labs/flux-pro-1.0-fill",
|
||||
image=open("original_image.png", "rb"),
|
||||
mask=open("mask_image.png", "rb"),
|
||||
prompt="Replace with a garden",
|
||||
)
|
||||
|
||||
print(response.data[0].url)
|
||||
```
|
||||
|
||||
#### Outpainting (Expand)
|
||||
```python showLineNumbers title="Black Forest Labs Outpainting"
|
||||
import os
|
||||
import litellm
|
||||
|
||||
os.environ["BFL_API_KEY"] = "your-api-key"
|
||||
|
||||
# Use flux-pro-1.0-expand to extend image borders
|
||||
response = litellm.image_edit(
|
||||
model="black_forest_labs/flux-pro-1.0-expand",
|
||||
image=open("original_image.png", "rb"),
|
||||
prompt="Continue the scene with mountains",
|
||||
top=256,
|
||||
bottom=256,
|
||||
)
|
||||
|
||||
print(response.data[0].url)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="vertex_ai" label="Vertex AI">
|
||||
|
||||
#### Basic Image Edit (Gemini)
|
||||
@@ -392,6 +449,35 @@ curl -X POST "http://0.0.0.0:4000/v1/images/edits" \
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="bfl" label="Black Forest Labs">
|
||||
|
||||
1. Add Black Forest Labs image edit models to your `config.yaml`:
|
||||
```yaml showLineNumbers title="Black Forest Labs Proxy Configuration"
|
||||
model_list:
|
||||
- model_name: bfl-kontext-pro
|
||||
litellm_params:
|
||||
model: black_forest_labs/flux-kontext-pro
|
||||
api_key: os.environ/BFL_API_KEY
|
||||
model_info:
|
||||
mode: image_edit
|
||||
```
|
||||
|
||||
2. Start the LiteLLM proxy server:
|
||||
```bash showLineNumbers title="Start LiteLLM Proxy Server"
|
||||
litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Make an image edit request:
|
||||
```bash showLineNumbers title="Black Forest Labs Proxy Image Edit"
|
||||
curl -X POST "http://0.0.0.0:4000/v1/images/edits" \
|
||||
-H "Authorization: Bearer <YOUR-LITELLM-KEY>" \
|
||||
-F "model=bfl-kontext-pro" \
|
||||
-F "image=@original_image.png" \
|
||||
-F "prompt=Add a sunset in the background"
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="vertex_ai" label="Vertex AI">
|
||||
|
||||
1. Add Vertex AI image edit models to your `config.yaml`:
|
||||
|
||||
@@ -15,7 +15,7 @@ import TabItem from '@theme/TabItem';
|
||||
| Fallbacks | ✅ | Works between supported models |
|
||||
| Loadbalancing | ✅ | Works between supported models |
|
||||
| Guardrails | ✅ | Applies to input prompts (non-streaming only) |
|
||||
| Supported Providers | OpenAI, Azure, Google AI Studio, Vertex AI, AWS Bedrock, Recraft, OpenRouter, Xinference, Nscale | |
|
||||
| Supported Providers | OpenAI, Azure, Google AI Studio, Vertex AI, AWS Bedrock, Black Forest Labs, Recraft, OpenRouter, Xinference, Nscale | |
|
||||
|
||||
## Quick Start
|
||||
|
||||
|
||||
@@ -133,6 +133,21 @@ LiteLLM attempts [OAuth 2.0 Authorization Server Discovery](https://datatracker.
|
||||
|
||||
<br/>
|
||||
|
||||
### AWS SigV4 Authentication
|
||||
|
||||
For MCP servers hosted on [AWS Bedrock AgentCore](https://docs.aws.amazon.com/bedrock/latest/userguide/agentcore.html), select **AWS SigV4** as the authentication type. LiteLLM will sign every outgoing MCP request with your AWS credentials using [Signature Version 4](https://docs.aws.amazon.com/general/latest/gr/signature-version-4.html).
|
||||
|
||||
<Image
|
||||
img={require('../img/mcp_aws_sigv4_ui.png')}
|
||||
style={{width: '80%', display: 'block', margin: '0'}}
|
||||
/>
|
||||
|
||||
Fill in your AWS region, service name (defaults to `bedrock-agentcore`), and optionally your AWS access key and secret. If credentials are omitted, LiteLLM falls back to the boto3 credential chain (IAM roles, environment variables, etc.).
|
||||
|
||||
[**See full SigV4 setup guide**](./mcp_aws_sigv4.md)
|
||||
|
||||
<br/>
|
||||
|
||||
### Static Headers
|
||||
|
||||
Sometimes your MCP server needs specific headers on every request. Maybe it's an API key, maybe it's a custom header the server expects. Instead of configuring auth, you can just set them directly.
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
import Image from '@theme/IdealImage';
|
||||
|
||||
# MCP - AWS SigV4 Auth
|
||||
|
||||
Use AWS SigV4 authentication to connect LiteLLM to MCP servers hosted on [AWS Bedrock AgentCore](https://docs.aws.amazon.com/bedrock/latest/userguide/agentcore.html).
|
||||
@@ -10,6 +14,36 @@ LiteLLM's `aws_sigv4` auth type handles this automatically: every outgoing MCP r
|
||||
|
||||
## Quick Start
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="ui" label="LiteLLM UI">
|
||||
|
||||
1. Navigate to **MCP Servers** and click **Add New MCP Server**
|
||||
2. Set the transport to **Streamable HTTP**
|
||||
3. Select **AWS SigV4** as the authentication type
|
||||
4. Fill in your AWS credentials:
|
||||
|
||||
<Image
|
||||
img={require('../img/mcp_aws_sigv4_ui.png')}
|
||||
style={{width: '80%', display: 'block', margin: '0'}}
|
||||
/>
|
||||
|
||||
<br/>
|
||||
|
||||
| Field | Required | Description |
|
||||
|-------|----------|-------------|
|
||||
| **AWS Region** | Yes | AWS region for SigV4 signing (e.g., `us-east-1`) |
|
||||
| **AWS Service Name** | No | Defaults to `bedrock-agentcore` |
|
||||
| **AWS Access Key ID** | No | Falls back to boto3 credential chain if blank |
|
||||
| **AWS Secret Access Key** | No | Required if Access Key ID is provided |
|
||||
| **AWS Session Token** | No | Only needed for temporary STS credentials |
|
||||
|
||||
Once created, LiteLLM will sign every outgoing MCP request with SigV4. The server's tools appear automatically in the MCP Tools list.
|
||||
|
||||
**Editing credentials:** When editing an existing SigV4 server, leave credential fields blank to keep the current values. Only fields you fill in will be updated.
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="config" label="config.yaml">
|
||||
|
||||
### 1. Set AWS credentials
|
||||
|
||||
```bash
|
||||
@@ -60,9 +94,12 @@ arn%3Aaws%3Abedrock-agentcore%3Aus-east-1%3A123456789012%3Aruntime%2Fmy-mcp-serv
|
||||
litellm --config config.yaml
|
||||
```
|
||||
|
||||
### 4. Use the MCP tools
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
Once started, your AgentCore MCP tools are available through LiteLLM like any other MCP server:
|
||||
## Use the MCP tools
|
||||
|
||||
Once configured, your AgentCore MCP tools are available through LiteLLM like any other MCP server:
|
||||
|
||||
```bash title="List available tools"
|
||||
curl http://localhost:4000/mcp-rest/tools/list \
|
||||
|
||||
@@ -86,4 +86,5 @@ MCP guardrails work with all LiteLLM-supported guardrail providers:
|
||||
- **Lakera**: Content moderation
|
||||
- **Aporia**: Custom guardrails
|
||||
- **Noma**: Noma Security
|
||||
- **PANW Prisma AIRS**: Prisma AIRS guardrails
|
||||
- **Custom**: Your own guardrail implementations
|
||||
@@ -13,6 +13,7 @@ Here's the full specification with all available fields:
|
||||
```json
|
||||
{
|
||||
"sample_spec": {
|
||||
"aliases": ["optional list of alternate names for this model, e.g. dated versions like sample_spec-20250101"],
|
||||
"code_interpreter_cost_per_session": 0.0,
|
||||
"computer_use_input_cost_per_1k_tokens": 0.0,
|
||||
"computer_use_output_cost_per_1k_tokens": 0.0,
|
||||
@@ -121,4 +122,28 @@ Here's the full specification with all available fields:
|
||||
}
|
||||
```
|
||||
|
||||
That's it! Your PR will be reviewed and merged.
|
||||
### Using Aliases
|
||||
|
||||
Many providers release the same model under multiple names — for example, a `latest` tag and a dated version like `claude-sonnet-4-5-20250929`. Instead of duplicating the entire entry, you can use the `aliases` field:
|
||||
|
||||
```json
|
||||
{
|
||||
"claude-sonnet-4-5": {
|
||||
"aliases": ["claude-sonnet-4-5-20250929"],
|
||||
"input_cost_per_token": 3e-06,
|
||||
"output_cost_per_token": 1.5e-05,
|
||||
"litellm_provider": "anthropic",
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 64000,
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_tool_choice": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
At load time, each alias is expanded into a top-level entry sharing the same data as the canonical entry. The example above makes both `claude-sonnet-4-5` and `claude-sonnet-4-5-20250929` resolve with the same pricing and capabilities.
|
||||
|
||||
:::info
|
||||
This is different from [`model_alias_map`](../completion/model_alias.md), which is a runtime SDK/proxy feature for mapping user-facing model names to LiteLLM model identifiers. The `aliases` field here is for the model cost JSON only — it avoids duplicate entries for models that share identical pricing and capabilities.
|
||||
:::
|
||||
|
||||
@@ -0,0 +1,291 @@
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Black Forest Labs Image Generation
|
||||
|
||||
Black Forest Labs provides state-of-the-art text-to-image generation using their FLUX models.
|
||||
|
||||
## Overview
|
||||
|
||||
| Property | Details |
|
||||
|----------|---------|
|
||||
| Description | Black Forest Labs FLUX models for high-quality text-to-image generation |
|
||||
| Provider Route on LiteLLM | `black_forest_labs/` |
|
||||
| Provider Doc | [Black Forest Labs API ↗](https://docs.bfl.ai/) |
|
||||
| Supported Operations | [`/images/generations`](#image-generation) |
|
||||
|
||||
## Setup
|
||||
|
||||
### API Key
|
||||
|
||||
```python showLineNumbers
|
||||
import os
|
||||
|
||||
# Set your Black Forest Labs API key
|
||||
os.environ["BFL_API_KEY"] = "your-api-key-here"
|
||||
```
|
||||
|
||||
Get your API key from [Black Forest Labs](https://blackforestlabs.ai/).
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model Name | Description | Price |
|
||||
|------------|-------------|-------|
|
||||
| `black_forest_labs/flux-pro-1.1` | Fast & reliable standard generation | $0.04/image |
|
||||
| `black_forest_labs/flux-pro-1.1-ultra` | Ultra high-resolution (up to 4MP) | $0.06/image |
|
||||
| `black_forest_labs/flux-dev` | Development/open-source variant | $0.025/image |
|
||||
| `black_forest_labs/flux-pro` | Original pro model | $0.05/image |
|
||||
|
||||
## Image Generation
|
||||
|
||||
### Usage - LiteLLM Python SDK
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="basic" label="Basic Usage">
|
||||
|
||||
```python showLineNumbers title="Basic Image Generation"
|
||||
import os
|
||||
import litellm
|
||||
|
||||
# Set your API key
|
||||
os.environ["BFL_API_KEY"] = "your-api-key-here"
|
||||
|
||||
# Generate an image
|
||||
response = litellm.image_generation(
|
||||
model="black_forest_labs/flux-pro-1.1",
|
||||
prompt="A beautiful sunset over the ocean with sailing boats",
|
||||
)
|
||||
|
||||
# BFL returns URLs
|
||||
print(response.data[0].url)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="async" label="Async Usage">
|
||||
|
||||
```python showLineNumbers title="Async Image Generation"
|
||||
import os
|
||||
import asyncio
|
||||
import litellm
|
||||
|
||||
# Set your API key
|
||||
os.environ["BFL_API_KEY"] = "your-api-key-here"
|
||||
|
||||
async def generate_image():
|
||||
response = await litellm.aimage_generation(
|
||||
model="black_forest_labs/flux-pro-1.1",
|
||||
prompt="A futuristic city skyline at night",
|
||||
)
|
||||
print(response.data[0].url)
|
||||
|
||||
# Run the async function
|
||||
asyncio.run(generate_image())
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="size" label="Custom Size">
|
||||
|
||||
```python showLineNumbers title="Image Generation with Custom Size"
|
||||
import os
|
||||
import litellm
|
||||
|
||||
# Set your API key
|
||||
os.environ["BFL_API_KEY"] = "your-api-key-here"
|
||||
|
||||
# Generate with specific dimensions
|
||||
response = litellm.image_generation(
|
||||
model="black_forest_labs/flux-pro-1.1",
|
||||
prompt="A majestic mountain landscape",
|
||||
size="1792x1024", # Maps to width/height
|
||||
)
|
||||
|
||||
print(response.data[0].url)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="ultra" label="Ultra High-Res">
|
||||
|
||||
```python showLineNumbers title="Ultra High Resolution with flux-pro-1.1-ultra"
|
||||
import os
|
||||
import litellm
|
||||
|
||||
# Set your API key
|
||||
os.environ["BFL_API_KEY"] = "your-api-key-here"
|
||||
|
||||
# Generate ultra high-resolution image
|
||||
response = litellm.image_generation(
|
||||
model="black_forest_labs/flux-pro-1.1-ultra",
|
||||
prompt="Detailed portrait of a fantasy character",
|
||||
size="2048x2048", # Up to 4MP supported
|
||||
quality="hd", # Maps to raw=True for natural look
|
||||
)
|
||||
|
||||
print(response.data[0].url)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="advanced" label="Advanced Parameters">
|
||||
|
||||
```python showLineNumbers title="Advanced Image Generation with BFL Parameters"
|
||||
import os
|
||||
import litellm
|
||||
|
||||
# Set your API key
|
||||
os.environ["BFL_API_KEY"] = "your-api-key-here"
|
||||
|
||||
# Generate with BFL-specific parameters
|
||||
response = litellm.image_generation(
|
||||
model="black_forest_labs/flux-pro-1.1",
|
||||
prompt="A cute orange cat sitting on a windowsill",
|
||||
seed=42, # For reproducible results
|
||||
output_format="png", # png or jpeg
|
||||
safety_tolerance=2, # 0-6, higher = more permissive
|
||||
prompt_upsampling=True, # Enhance prompt for better results
|
||||
)
|
||||
|
||||
print(response.data[0].url)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
### Usage - LiteLLM Proxy Server
|
||||
|
||||
#### 1. Configure your config.yaml
|
||||
|
||||
```yaml showLineNumbers title="Black Forest Labs Image Generation Configuration"
|
||||
model_list:
|
||||
- model_name: flux-pro
|
||||
litellm_params:
|
||||
model: black_forest_labs/flux-pro-1.1
|
||||
api_key: os.environ/BFL_API_KEY
|
||||
model_info:
|
||||
mode: image_generation
|
||||
|
||||
- model_name: flux-ultra
|
||||
litellm_params:
|
||||
model: black_forest_labs/flux-pro-1.1-ultra
|
||||
api_key: os.environ/BFL_API_KEY
|
||||
model_info:
|
||||
mode: image_generation
|
||||
|
||||
- model_name: flux-dev
|
||||
litellm_params:
|
||||
model: black_forest_labs/flux-dev
|
||||
api_key: os.environ/BFL_API_KEY
|
||||
model_info:
|
||||
mode: image_generation
|
||||
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
```
|
||||
|
||||
#### 2. Start LiteLLM Proxy Server
|
||||
|
||||
```bash showLineNumbers title="Start LiteLLM Proxy Server"
|
||||
litellm --config /path/to/config.yaml
|
||||
|
||||
# RUNNING on http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
#### 3. Make image generation requests
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="openai-sdk" label="OpenAI SDK">
|
||||
|
||||
```python showLineNumbers title="Black Forest Labs via Proxy - OpenAI SDK"
|
||||
from openai import OpenAI
|
||||
|
||||
# Initialize client with your proxy URL
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:4000",
|
||||
api_key="sk-1234"
|
||||
)
|
||||
|
||||
# Generate image with FLUX Pro
|
||||
response = client.images.generate(
|
||||
model="flux-pro",
|
||||
prompt="A beautiful garden with colorful flowers",
|
||||
size="1024x1024",
|
||||
)
|
||||
|
||||
print(response.data[0].url)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="curl" label="cURL">
|
||||
|
||||
```bash showLineNumbers title="Black Forest Labs via Proxy - cURL"
|
||||
curl -X POST 'http://localhost:4000/v1/images/generations' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-d '{
|
||||
"model": "flux-pro",
|
||||
"prompt": "A beautiful garden with colorful flowers",
|
||||
"size": "1024x1024"
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## Supported Parameters
|
||||
|
||||
### OpenAI-Compatible Parameters
|
||||
|
||||
| Parameter | Type | Description | Mapping |
|
||||
|-----------|------|-------------|---------|
|
||||
| `prompt` | string | Text description of the image to generate | Direct |
|
||||
| `model` | string | The FLUX model to use | Direct |
|
||||
| `size` | string | Image dimensions (e.g., `1024x1024`) | Maps to `width` and `height` |
|
||||
| `n` | integer | Number of images (ultra model only, up to 4) | Maps to `num_images` |
|
||||
| `quality` | string | `hd` for natural look | Maps to `raw=True` for ultra |
|
||||
| `response_format` | string | `url` or `b64_json` | Direct |
|
||||
|
||||
### Black Forest Labs Specific Parameters
|
||||
|
||||
| Parameter | Type | Description | Default |
|
||||
|-----------|------|-------------|---------|
|
||||
| `width` | integer | Image width (256-1920, multiples of 16) | 1024 |
|
||||
| `height` | integer | Image height (256-1920, multiples of 16) | 1024 |
|
||||
| `aspect_ratio` | string | Alternative to width/height (e.g., `16:9`, `1:1`) | - |
|
||||
| `seed` | integer | Seed for reproducible results | Random |
|
||||
| `output_format` | string | Output format: `png` or `jpeg` | `png` |
|
||||
| `safety_tolerance` | integer | Safety filter tolerance (0-6, higher = more permissive) | 2 |
|
||||
| `prompt_upsampling` | boolean | Enhance prompt for better results | `false` |
|
||||
|
||||
### Ultra Model Specific Parameters
|
||||
|
||||
| Parameter | Type | Description | Default |
|
||||
|-----------|------|-------------|---------|
|
||||
| `raw` | boolean | Raw mode for more natural, less synthetic look | `false` |
|
||||
| `num_images` | integer | Number of images to generate (1-4) | 1 |
|
||||
|
||||
## How It Works
|
||||
|
||||
Black Forest Labs uses a polling-based API:
|
||||
|
||||
1. **Submit Request**: LiteLLM sends your prompt to BFL
|
||||
2. **Get Task ID**: BFL returns a task ID and polling URL
|
||||
3. **Poll for Result**: LiteLLM automatically polls until the image is ready
|
||||
4. **Return Result**: The generated image URL is returned
|
||||
|
||||
This polling is handled automatically by LiteLLM - you just call `image_generation()` and get the result.
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. Create an account at [Black Forest Labs](https://blackforestlabs.ai/)
|
||||
2. Get your API key from the dashboard
|
||||
3. Set your `BFL_API_KEY` environment variable
|
||||
4. Use `litellm.image_generation()` with any supported model
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [Black Forest Labs Documentation](https://docs.bfl.ai/)
|
||||
- [Black Forest Labs Image Editing](./black_forest_labs_img_edit.md) - For editing existing images
|
||||
- [FLUX Model Information](https://blackforestlabs.ai/)
|
||||
@@ -0,0 +1,301 @@
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Black Forest Labs Image Editing
|
||||
|
||||
Black Forest Labs provides powerful image editing capabilities using their FLUX models to modify existing images based on text descriptions.
|
||||
|
||||
## Overview
|
||||
|
||||
| Property | Details |
|
||||
|----------|---------|
|
||||
| Description | Black Forest Labs Image Editing uses FLUX Kontext and other models to modify, inpaint, and expand images based on text prompts. |
|
||||
| Provider Route on LiteLLM | `black_forest_labs/` |
|
||||
| Provider Doc | [Black Forest Labs API ↗](https://docs.bfl.ai/) |
|
||||
| Supported Operations | [`/images/edits`](#image-editing) |
|
||||
|
||||
## Setup
|
||||
|
||||
### API Key
|
||||
|
||||
```python showLineNumbers
|
||||
import os
|
||||
|
||||
# Set your Black Forest Labs API key
|
||||
os.environ["BFL_API_KEY"] = "your-api-key-here"
|
||||
```
|
||||
|
||||
Get your API key from [Black Forest Labs](https://blackforestlabs.ai/).
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model Name | Description | Use Case |
|
||||
|------------|-------------|----------|
|
||||
| `black_forest_labs/flux-kontext-pro` | FLUX Kontext Pro - General image editing with prompts | General editing, style transfer |
|
||||
| `black_forest_labs/flux-kontext-max` | FLUX Kontext Max - Premium quality editing | High-quality edits |
|
||||
| `black_forest_labs/flux-pro-1.0-fill` | FLUX Pro Fill - Inpainting with mask | Remove/replace objects |
|
||||
| `black_forest_labs/flux-pro-1.0-expand` | FLUX Pro Expand - Outpainting | Expand image borders |
|
||||
|
||||
## Image Editing
|
||||
|
||||
### Usage - LiteLLM Python SDK
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="basic-edit" label="Basic Usage">
|
||||
|
||||
```python showLineNumbers title="Basic Image Editing"
|
||||
import os
|
||||
import litellm
|
||||
|
||||
# Set your API key
|
||||
os.environ["BFL_API_KEY"] = "your-api-key-here"
|
||||
|
||||
# Edit an image with a prompt
|
||||
response = litellm.image_edit(
|
||||
model="black_forest_labs/flux-kontext-pro",
|
||||
image=open("path/to/your/image.png", "rb"),
|
||||
prompt="Add a green leaf to the scene",
|
||||
)
|
||||
|
||||
# BFL returns URLs
|
||||
print(response.data[0].url)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="async-edit" label="Async Usage">
|
||||
|
||||
```python showLineNumbers title="Async Image Editing"
|
||||
import os
|
||||
import asyncio
|
||||
import litellm
|
||||
|
||||
# Set your API key
|
||||
os.environ["BFL_API_KEY"] = "your-api-key-here"
|
||||
|
||||
async def edit_image():
|
||||
response = await litellm.aimage_edit(
|
||||
model="black_forest_labs/flux-kontext-pro",
|
||||
image=open("path/to/your/image.png", "rb"),
|
||||
prompt="Make this image look like a watercolor painting",
|
||||
)
|
||||
print(response.data[0].url)
|
||||
|
||||
# Run the async function
|
||||
asyncio.run(edit_image())
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="inpainting" label="Inpainting (Fill)">
|
||||
|
||||
```python showLineNumbers title="Inpainting with Mask"
|
||||
import os
|
||||
import litellm
|
||||
|
||||
# Set your API key
|
||||
os.environ["BFL_API_KEY"] = "your-api-key-here"
|
||||
|
||||
# Use flux-pro-1.0-fill for inpainting
|
||||
response = litellm.image_edit(
|
||||
model="black_forest_labs/flux-pro-1.0-fill",
|
||||
image=open("path/to/your/image.png", "rb"),
|
||||
mask=open("path/to/mask.png", "rb"), # White areas will be edited
|
||||
prompt="Replace with a beautiful garden",
|
||||
steps=50, # BFL-specific parameter
|
||||
guidance=30, # BFL-specific parameter
|
||||
)
|
||||
|
||||
print(response.data[0].url)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="outpainting" label="Outpainting (Expand)">
|
||||
|
||||
```python showLineNumbers title="Outpainting - Expand Image Borders"
|
||||
import os
|
||||
import litellm
|
||||
|
||||
# Set your API key
|
||||
os.environ["BFL_API_KEY"] = "your-api-key-here"
|
||||
|
||||
# Use flux-pro-1.0-expand to extend image borders
|
||||
response = litellm.image_edit(
|
||||
model="black_forest_labs/flux-pro-1.0-expand",
|
||||
image=open("path/to/your/image.png", "rb"),
|
||||
prompt="Continue the scene with a mountain landscape",
|
||||
top=256, # Expand 256 pixels at top
|
||||
bottom=256, # Expand 256 pixels at bottom
|
||||
left=128, # Expand 128 pixels at left
|
||||
right=128, # Expand 128 pixels at right
|
||||
)
|
||||
|
||||
print(response.data[0].url)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="advanced" label="Advanced Parameters">
|
||||
|
||||
```python showLineNumbers title="Advanced Image Editing with BFL Parameters"
|
||||
import os
|
||||
import litellm
|
||||
|
||||
# Set your API key
|
||||
os.environ["BFL_API_KEY"] = "your-api-key-here"
|
||||
|
||||
# Edit image with BFL-specific parameters
|
||||
response = litellm.image_edit(
|
||||
model="black_forest_labs/flux-kontext-pro",
|
||||
image=open("path/to/your/image.png", "rb"),
|
||||
prompt="Transform into cyberpunk style with neon lights",
|
||||
seed=42, # For reproducible results
|
||||
output_format="png", # png or jpeg
|
||||
safety_tolerance=2, # 0-6, higher = more permissive
|
||||
aspect_ratio="16:9", # Output aspect ratio
|
||||
)
|
||||
|
||||
print(response.data[0].url)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
### Usage - LiteLLM Proxy Server
|
||||
|
||||
#### 1. Configure your config.yaml
|
||||
|
||||
```yaml showLineNumbers title="Black Forest Labs Image Editing Configuration"
|
||||
model_list:
|
||||
- model_name: bfl-kontext-pro
|
||||
litellm_params:
|
||||
model: black_forest_labs/flux-kontext-pro
|
||||
api_key: os.environ/BFL_API_KEY
|
||||
model_info:
|
||||
mode: image_edit
|
||||
|
||||
- model_name: bfl-kontext-max
|
||||
litellm_params:
|
||||
model: black_forest_labs/flux-kontext-max
|
||||
api_key: os.environ/BFL_API_KEY
|
||||
model_info:
|
||||
mode: image_edit
|
||||
|
||||
- model_name: bfl-fill
|
||||
litellm_params:
|
||||
model: black_forest_labs/flux-pro-1.0-fill
|
||||
api_key: os.environ/BFL_API_KEY
|
||||
model_info:
|
||||
mode: image_edit
|
||||
|
||||
- model_name: bfl-expand
|
||||
litellm_params:
|
||||
model: black_forest_labs/flux-pro-1.0-expand
|
||||
api_key: os.environ/BFL_API_KEY
|
||||
model_info:
|
||||
mode: image_edit
|
||||
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
```
|
||||
|
||||
#### 2. Start LiteLLM Proxy Server
|
||||
|
||||
```bash showLineNumbers title="Start LiteLLM Proxy Server"
|
||||
litellm --config /path/to/config.yaml
|
||||
|
||||
# RUNNING on http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
#### 3. Make image editing requests
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="openai-sdk" label="OpenAI SDK">
|
||||
|
||||
```python showLineNumbers title="Black Forest Labs via Proxy - OpenAI SDK"
|
||||
from openai import OpenAI
|
||||
|
||||
# Initialize client with your proxy URL
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:4000",
|
||||
api_key="sk-1234"
|
||||
)
|
||||
|
||||
# Edit image with FLUX Kontext Pro
|
||||
response = client.images.edit(
|
||||
model="bfl-kontext-pro",
|
||||
image=open("path/to/your/image.png", "rb"),
|
||||
prompt="Add magical sparkles and fairy dust",
|
||||
)
|
||||
|
||||
print(response.data[0].url)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="curl" label="cURL">
|
||||
|
||||
```bash showLineNumbers title="Black Forest Labs via Proxy - cURL"
|
||||
curl --location 'http://localhost:4000/v1/images/edits' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--form 'model="bfl-kontext-pro"' \
|
||||
--form 'prompt="Add a sunset in the background"' \
|
||||
--form 'image=@"path/to/your/image.png"'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## Supported Parameters
|
||||
|
||||
### OpenAI-Compatible Parameters
|
||||
|
||||
| Parameter | Type | Description | Default |
|
||||
|-----------|------|-------------|---------|
|
||||
| `image` | file | The image file to edit | Required |
|
||||
| `prompt` | string | Text description of the desired changes | Required |
|
||||
| `model` | string | The FLUX model to use | Required |
|
||||
| `mask` | file | Mask image for inpainting (flux-pro-1.0-fill) | Optional |
|
||||
| `n` | integer | Number of images (BFL returns 1 per request) | `1` |
|
||||
| `size` | string | Maps to aspect_ratio | Optional |
|
||||
| `response_format` | string | `url` or `b64_json` | `url` |
|
||||
|
||||
### Black Forest Labs Specific Parameters
|
||||
|
||||
| Parameter | Type | Description | Default | Models |
|
||||
|-----------|------|-------------|---------|--------|
|
||||
| `seed` | integer | Seed for reproducible results | Random | All |
|
||||
| `output_format` | string | Output format: `png` or `jpeg` | `png` | All |
|
||||
| `safety_tolerance` | integer | Safety filter tolerance (0-6) | 2 | All |
|
||||
| `aspect_ratio` | string | Output aspect ratio (e.g., `16:9`, `1:1`) | Original | Kontext models |
|
||||
| `steps` | integer | Number of inference steps | Model default | Fill |
|
||||
| `guidance` | float | Guidance scale | Model default | Fill |
|
||||
| `grow_mask` | integer | Pixels to grow mask | 0 | Fill |
|
||||
| `top` | integer | Pixels to expand at top | 0 | Expand |
|
||||
| `bottom` | integer | Pixels to expand at bottom | 0 | Expand |
|
||||
| `left` | integer | Pixels to expand at left | 0 | Expand |
|
||||
| `right` | integer | Pixels to expand at right | 0 | Expand |
|
||||
|
||||
## How It Works
|
||||
|
||||
Black Forest Labs uses a polling-based API:
|
||||
|
||||
1. **Submit Request**: LiteLLM sends your image and prompt to BFL
|
||||
2. **Get Task ID**: BFL returns a task ID and polling URL
|
||||
3. **Poll for Result**: LiteLLM automatically polls until the image is ready
|
||||
4. **Return Result**: The generated image URL is returned
|
||||
|
||||
This polling is handled automatically by LiteLLM - you just call `image_edit()` and get the result.
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. Create an account at [Black Forest Labs](https://blackforestlabs.ai/)
|
||||
2. Get your API key from the dashboard
|
||||
3. Set your `BFL_API_KEY` environment variable
|
||||
4. Use `litellm.image_edit()` with any supported model
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [Black Forest Labs Documentation](https://docs.bfl.ai/)
|
||||
- [FLUX Model Information](https://blackforestlabs.ai/)
|
||||
@@ -1562,13 +1562,18 @@ LiteLLM Supports the following image types passed in `url`
|
||||
|
||||
## Media Resolution Control (Images & Videos)
|
||||
|
||||
For Gemini 3+ models, LiteLLM supports per-part media resolution control using OpenAI's `detail` parameter. This allows you to specify different resolution levels for individual images and videos in your request, whether using `image_url` or `file` content types.
|
||||
LiteLLM supports OpenAI's `detail` parameter for specifying the image resolution when using Gemini models. The behavior differs between Gemini versions:
|
||||
|
||||
| Gemini Version | Resolution Control | Behavior |
|
||||
|----------------|-------------------|----------|
|
||||
| Gemini 3+ | Per-part | Each image/video can have its own `detail` setting |
|
||||
| Gemini 2.x (2.0, 2.5) | Global | The highest `detail` from all images is applied globally via `mediaResolution` in `generationConfig` |
|
||||
|
||||
**Supported `detail` values:**
|
||||
- `"low"` - Maps to `media_resolution: "low"` (280 tokens for images, 70 tokens per frame for videos)
|
||||
- `"medium"` - Maps to `media_resolution: "medium"`
|
||||
- `"high"` - Maps to `media_resolution: "high"` (1120 tokens for images)
|
||||
- `"ultra_high"` - Maps to `media_resolution: "ultra_high"`
|
||||
- `"low"` - Maps to `MEDIA_RESOLUTION_LOW` (280 tokens for images, 70 tokens per frame for videos)
|
||||
- `"medium"` - Maps to `MEDIA_RESOLUTION_MEDIUM`
|
||||
- `"high"` - Maps to `MEDIA_RESOLUTION_HIGH` (1120 tokens for images)
|
||||
- `"ultra_high"` - Maps to `MEDIA_RESOLUTION_ULTRA_HIGH`
|
||||
- `"auto"` or `None` - Model decides optimal resolution (no `media_resolution` set)
|
||||
|
||||
**Usage Examples:**
|
||||
@@ -1605,8 +1610,9 @@ messages = [
|
||||
}
|
||||
]
|
||||
|
||||
# Works with both Gemini 2.x and 3+
|
||||
response = completion(
|
||||
model="gemini/gemini-3-pro-preview",
|
||||
model="gemini/gemini-2.5-flash", # or gemini-3-pro-preview
|
||||
messages=messages,
|
||||
)
|
||||
```
|
||||
@@ -1647,7 +1653,9 @@ response = completion(
|
||||
</Tabs>
|
||||
|
||||
:::info
|
||||
**Per-Part Resolution:** Each image or video in your request can have its own `detail` setting, allowing mixed-resolution requests (e.g., a high-res chart alongside a low-res icon). This feature works with both `image_url` and `file` content types, and is only available for Gemini 3+ models.
|
||||
**Gemini 3+ Per-Part Resolution:** Each image or video can have its own `detail` setting, allowing mixed-resolution requests (e.g., a high-res chart alongside a low-res icon). This works with both `image_url` and `file` content types.
|
||||
|
||||
**Gemini 2.x Global Resolution:** When multiple images have different `detail` values, LiteLLM uses the highest resolution found and applies it globally via `mediaResolution` in `generationConfig` (e.g., if one image has `"low"` and another has `"high"`, all images will use `"high"`).
|
||||
:::
|
||||
|
||||
## Video Metadata Control
|
||||
|
||||
@@ -311,6 +311,79 @@ print(response)
|
||||
- **Model Compatibility**: Reasoning parameters only work with magistral models
|
||||
- **Backward Compatibility**: Non-magistral models will ignore reasoning parameters and work normally
|
||||
|
||||
## Audio Transcription
|
||||
|
||||
Use Mistral's Voxtral models for audio transcription via `litellm.transcription()`.
|
||||
|
||||
### SDK Usage
|
||||
|
||||
```python
|
||||
from litellm import transcription
|
||||
import os
|
||||
|
||||
os.environ["MISTRAL_API_KEY"] = ""
|
||||
|
||||
audio_file = open("path/to/audio.wav", "rb")
|
||||
|
||||
response = transcription(
|
||||
model="mistral/voxtral-mini-latest",
|
||||
file=audio_file,
|
||||
)
|
||||
|
||||
print(response.text)
|
||||
```
|
||||
|
||||
### With Optional Parameters
|
||||
|
||||
```python
|
||||
response = transcription(
|
||||
model="mistral/voxtral-mini-latest",
|
||||
file=audio_file,
|
||||
language="en",
|
||||
temperature=0.0,
|
||||
response_format="json",
|
||||
)
|
||||
```
|
||||
|
||||
### Mistral-Specific Parameters
|
||||
|
||||
Mistral supports additional parameters beyond the OpenAI-compatible ones:
|
||||
|
||||
| Parameter | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `diarize` | `bool` | Enable speaker diarization |
|
||||
|
||||
```python
|
||||
response = transcription(
|
||||
model="mistral/voxtral-mini-latest",
|
||||
file=audio_file,
|
||||
diarize=True,
|
||||
)
|
||||
```
|
||||
|
||||
### Usage with LiteLLM Proxy
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: voxtral
|
||||
litellm_params:
|
||||
model: mistral/voxtral-mini-latest
|
||||
api_key: os.environ/MISTRAL_API_KEY
|
||||
model_info:
|
||||
mode: audio_transcription
|
||||
```
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:4000/v1/audio/transcriptions' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--form 'file=@"audio.wav"' \
|
||||
--form 'model="voxtral"'
|
||||
```
|
||||
|
||||
## Sample Usage - Embedding
|
||||
```python
|
||||
from litellm import embedding
|
||||
|
||||
@@ -632,7 +632,9 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||
|
||||
## OpenAI Chat Completion to Responses API Bridge
|
||||
|
||||
Call any Responses API model from OpenAI's `/chat/completions` endpoint.
|
||||
LiteLLM offers a chat completion to Responses API bridge. This lets you use the completion interface while calling the Responses API under the hood.
|
||||
|
||||
This is useful when you want to use [Responses API](https://platform.openai.com/docs/api-reference/responses) specific features (like built-in tools, web search preview, or code interpreter).
|
||||
|
||||
:::tip gpt-5.4 + reasoning_effort + function tools
|
||||
|
||||
@@ -649,12 +651,54 @@ response = litellm.completion(
|
||||
|
||||
:::
|
||||
|
||||
### When to use the `openai/responses/` prefix
|
||||
|
||||
Each model has a `mode` property defined in [`model_prices_and_context_window.json`](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json) that determines which API endpoint it uses by default:
|
||||
|
||||
- **`mode: responses`** - Model automatically uses the Responses API
|
||||
- **`mode: chat`** - Model defaults to the Chat Completions API
|
||||
|
||||
**Models with `mode: responses`** (automatic Responses API):
|
||||
- `o3-deep-research`, `o4-mini-deep-research`
|
||||
- `o1-pro`, `o3-pro`
|
||||
- `gpt-5.1-codex`, `gpt-5.1-codex-mini`, `gpt-5.1-codex-max`
|
||||
- `codex-mini-latest`
|
||||
|
||||
**Models with `mode: chat`** (require `openai/responses/` prefix for built-in tools):
|
||||
- `gpt-4o`, `gpt-4o-mini`, `gpt-4.1`, `gpt-4.1-mini`
|
||||
- `gpt-5`, `gpt-5-mini`
|
||||
- `o3`, `o4-mini`
|
||||
|
||||
To use built-in tools like `web_search_preview` with `mode: chat` models, add the `openai/responses/` prefix:
|
||||
|
||||
```python
|
||||
# This will FAIL - gpt-4o has mode: chat, uses Chat Completions API
|
||||
response = litellm.completion(
|
||||
model="gpt-4o",
|
||||
messages=[{"role": "user", "content": "What is the weather in Paris today?"}],
|
||||
tools=[{"type": "web_search_preview"}], # Not supported in Chat Completions
|
||||
# ... other kwargs
|
||||
)
|
||||
|
||||
# This will WORK - prefix forces Responses API
|
||||
response = litellm.completion(
|
||||
model="openai/responses/gpt-4o",
|
||||
messages=[{"role": "user", "content": "What is the weather in Paris today?"}],
|
||||
tools=[{"type": "web_search_preview"}], # Supported in Responses API
|
||||
# ... other kwargs
|
||||
)
|
||||
```
|
||||
|
||||
### Examples
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
**Using a model with `mode: responses` (automatic):**
|
||||
|
||||
```python
|
||||
import litellm
|
||||
import os
|
||||
import os
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "sk-1234"
|
||||
|
||||
@@ -668,6 +712,26 @@ response = litellm.completion(
|
||||
)
|
||||
print(response)
|
||||
```
|
||||
|
||||
**Using a model with `mode: chat` (requires prefix):**
|
||||
|
||||
```python
|
||||
import litellm
|
||||
import os
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "sk-1234"
|
||||
|
||||
# Use the openai/responses/ prefix to enable built-in tools
|
||||
response = litellm.completion(
|
||||
model="openai/responses/gpt-4o",
|
||||
messages=[{"role": "user", "content": "What is the weather in Paris today?"}],
|
||||
tools=[
|
||||
{"type": "web_search_preview"},
|
||||
],
|
||||
)
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
@@ -675,10 +739,17 @@ print(response)
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: openai-model
|
||||
# Model with mode: responses (automatic)
|
||||
- model_name: o3-deep-research
|
||||
litellm_params:
|
||||
model: o3-deep-research-2025-06-26
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
# Model with mode: chat (use prefix for built-in tools)
|
||||
- model_name: gpt-4o-with-tools
|
||||
litellm_params:
|
||||
model: openai/responses/gpt-4o
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
```
|
||||
|
||||
2. Start the proxy
|
||||
@@ -693,15 +764,14 @@ litellm --config config.yaml
|
||||
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-d '{
|
||||
"model": "openai-model",
|
||||
-d '{
|
||||
"model": "gpt-4o-with-tools",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the capital of France?"}
|
||||
{"role": "user", "content": "What is the weather in Paris today?"}
|
||||
],
|
||||
"tools": [
|
||||
{"type": "web_search_preview"},
|
||||
{"type": "code_interpreter", "container": {"type": "auto"}},
|
||||
],
|
||||
{"type": "web_search_preview"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
|
||||
@@ -79,6 +79,7 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
|
||||
| textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` |
|
||||
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
|
||||
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
|
||||
| gemini-embedding-2-preview | `embedding(model="vertex_ai/gemini-embedding-2-preview", input)` | [Multimodal docs](#gemini-embedding-2-preview-multimodal) |
|
||||
| Fine-tuned OR Custom Embedding models | `embedding(model="vertex_ai/<your-model-id>", input)` |
|
||||
|
||||
### Supported OpenAI (Unified) Params
|
||||
@@ -257,6 +258,71 @@ model_list:
|
||||
|
||||
## **Multi-Modal Embeddings**
|
||||
|
||||
### Gemini Embedding 2 Preview (Multimodal)
|
||||
|
||||
`gemini-embedding-2-preview` supports **unified multimodal embeddings**—text, images, audio, video, and PDF in a single request. See [blog post](/blog/gemini_embedding_2_multimodal) for details.
|
||||
|
||||
**Input formats:**
|
||||
- **Data URIs:** `data:image/png;base64,<encoded_data>`
|
||||
- **GCS URLs:** `gs://bucket/path/to/file.png` (MIME type inferred from extension)
|
||||
|
||||
**Supported MIME types:** `image/png`, `image/jpeg`, `audio/mpeg`, `audio/wav`, `video/mp4`, `video/quicktime`, `application/pdf`
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
import litellm
|
||||
from litellm import embedding
|
||||
|
||||
litellm.vertex_project = "your-project-id"
|
||||
litellm.vertex_location = "us-central1"
|
||||
|
||||
# Text + Image (GCS URL)
|
||||
response = embedding(
|
||||
model="vertex_ai/gemini-embedding-2-preview",
|
||||
input=[
|
||||
"Describe this image",
|
||||
"gs://my-bucket/images/photo.png"
|
||||
],
|
||||
)
|
||||
|
||||
# Text + Image (base64)
|
||||
response = embedding(
|
||||
model="vertex_ai/gemini-embedding-2-preview",
|
||||
input=[
|
||||
"The food was delicious",
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAIAQMAAAD+wSzIAAAABlBMVEX///+/v7+jQ3Y5AAAADklEQVQI12P4AIX8EAgALgAD/aNpbtEAAAAASUVORK5CYII"
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="LiteLLM PROXY">
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: vertex-gemini-embedding-2-preview
|
||||
litellm_params:
|
||||
model: vertex_ai/gemini-embedding-2-preview
|
||||
vertex_project: "your-project-id"
|
||||
vertex_location: "us-central1"
|
||||
```
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:4000/embeddings \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "vertex-gemini-embedding-2-preview",
|
||||
"input": ["Describe this", "gs://bucket/image.png"]
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
### multimodalembedding@001 (Legacy)
|
||||
|
||||
Known Limitations:
|
||||
- Only supports 1 image / video / image per request
|
||||
|
||||
@@ -1,24 +1,15 @@
|
||||
import Image from '@theme/IdealImage';
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# PANW Prisma AIRS
|
||||
|
||||
LiteLLM supports PANW Prisma AIRS (AI Runtime Security) guardrails via the [Prisma AIRS Scan API](https://pan.dev/prisma-airs/api/airuntimesecurity/airuntimesecurityapi//). This integration provides **Security-as-Code** for AI applications using Palo Alto Networks' AI security platform.
|
||||
LiteLLM supports PANW Prisma AIRS (AI Runtime Security) guardrails via the [Prisma AIRS Scan API](https://pan.dev/prisma-airs/api/airuntimesecurity/airuntimesecurityapi/). This integration provides Security-as-Code for AI applications using Palo Alto Networks' AI security platform.
|
||||
|
||||
## Features
|
||||
- **Prompt injection and malicious URL detection** — real-time scanning before or after LLM calls
|
||||
- **Data loss prevention (DLP)** — detect and block sensitive data in prompts and responses
|
||||
- **Sensitive content masking** — automatically mask PII, credit cards, SSNs instead of blocking
|
||||
- **MCP tool call scanning** — scan tool name and arguments on direct MCP tool invocations
|
||||
- **Configurable fail-open / fail-closed** — choose between maximum security or high availability
|
||||
|
||||
- ✅ **Real-time prompt injection detection**
|
||||
- ✅ **Malicious URL detection**
|
||||
- ✅ **Data loss prevention (DLP)**
|
||||
- ✅ **Sensitive content masking** - Automatically mask PII, credit cards, SSNs instead of blocking
|
||||
- ✅ **Comprehensive threat detection** for AI models and datasets
|
||||
- ✅ **Model-agnostic protection** across public and private models
|
||||
- ✅ **Synchronous scanning** with immediate response
|
||||
- ✅ **Configurable security profiles**
|
||||
- ✅ **Streaming support** - Real-time masking for streaming responses
|
||||
- ✅ **Multi-turn conversation tracking** - Automatic session grouping in Prisma AIRS SCM logs
|
||||
- ✅ **Configurable fail-open/fail-closed** - Choose between maximum security (block on API errors) or high availability (allow on transient errors)
|
||||
|
||||
## Quick Start
|
||||
|
||||
@@ -32,7 +23,14 @@ For detailed setup instructions, see the [Prisma AIRS API Overview](https://docs
|
||||
|
||||
### 2. Define Guardrails on your LiteLLM config.yaml
|
||||
|
||||
Define your guardrails under the `guardrails` section:
|
||||
Set `api_base` to the regional endpoint for your Prisma AIRS deployment profile:
|
||||
|
||||
| Region | Endpoint |
|
||||
|--------|----------|
|
||||
| US | `https://service.api.aisecurity.paloaltonetworks.com` |
|
||||
| EU (Germany) | `https://service-de.api.aisecurity.paloaltonetworks.com` |
|
||||
| India | `https://service-in.api.aisecurity.paloaltonetworks.com` |
|
||||
| Singapore | `https://service-sg.api.aisecurity.paloaltonetworks.com` |
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
@@ -45,21 +43,15 @@ guardrails:
|
||||
- guardrail_name: "panw-prisma-airs-guardrail"
|
||||
litellm_params:
|
||||
guardrail: panw_prisma_airs
|
||||
mode: "pre_call" # Run before LLM call
|
||||
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY # Your Prisma AIRS API key
|
||||
profile_name: os.environ/PANW_PRISMA_AIRS_PROFILE_NAME # Security profile from Strata Cloud Manager
|
||||
api_base: "https://service.api.aisecurity.paloaltonetworks.com"
|
||||
mode: "pre_call"
|
||||
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
|
||||
profile_name: os.environ/PANW_PRISMA_AIRS_PROFILE_NAME
|
||||
api_base: "https://service.api.aisecurity.paloaltonetworks.com" # US — change to your region
|
||||
```
|
||||
|
||||
#### Supported values for `mode`
|
||||
|
||||
- `pre_call` Run **before** LLM call, on **input**
|
||||
- `post_call` Run **after** LLM call, on **input & output**
|
||||
- `during_call` Run **during** LLM call, on **input**. Same as `pre_call` but runs in parallel with LLM call
|
||||
|
||||
### 3. Start LiteLLM Gateway
|
||||
|
||||
```bash title="Set environment variables"
|
||||
```bash
|
||||
export PANW_PRISMA_AIRS_API_KEY="your-panw-api-key"
|
||||
export PANW_PRISMA_AIRS_PROFILE_NAME="your-security-profile"
|
||||
export OPENAI_API_KEY="sk-proj-..."
|
||||
@@ -69,15 +61,8 @@ export OPENAI_API_KEY="sk-proj-..."
|
||||
litellm --config config.yaml --detailed_debug
|
||||
```
|
||||
|
||||
|
||||
### 4. Test Request
|
||||
|
||||
**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)**
|
||||
|
||||
<Tabs>
|
||||
<TabItem label="Blocked request" value="blocked">
|
||||
|
||||
Expect this to fail due to prompt injection attempt:
|
||||
|
||||
```shell
|
||||
curl -i http://localhost:4000/v1/chat/completions \
|
||||
@@ -92,254 +77,57 @@ curl -i http://localhost:4000/v1/chat/completions \
|
||||
}'
|
||||
```
|
||||
|
||||
Expected response on failure:
|
||||
Expected response when the guardrail blocks:
|
||||
|
||||
```json
|
||||
{
|
||||
"error": {
|
||||
"message": {
|
||||
"error": "Violated PANW Prisma AIRS guardrail policy",
|
||||
"panw_response": {
|
||||
"action": "block",
|
||||
"category": "malicious",
|
||||
"profile_id": "03b32734-d06d-4bb7-a8df-ac5147630ce8",
|
||||
"profile_name": "dev-block-all-profile",
|
||||
"prompt_detected": {
|
||||
"dlp": false,
|
||||
"injection": true,
|
||||
"toxic_content": false,
|
||||
"url_cats": false
|
||||
},
|
||||
"report_id": "Rbd251eac-6e67-433b-b3ef-8eb42d2c7d2c",
|
||||
"response_detected": {
|
||||
"dlp": false,
|
||||
"toxic_content": false,
|
||||
"url_cats": false
|
||||
},
|
||||
"scan_id": "bd251eac-6e67-433b-b3ef-8eb42d2c7d2c",
|
||||
"tr_id": "string"
|
||||
}
|
||||
},
|
||||
"type": "None",
|
||||
"param": "None",
|
||||
"code": "400"
|
||||
"message": "Prompt blocked by PANW Prisma AI Security policy (Category: malicious)",
|
||||
"type": "guardrail_violation",
|
||||
"code": "panw_prisma_airs_blocked",
|
||||
"guardrail": "panw-prisma-airs-guardrail",
|
||||
"category": "malicious"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem label="Successful Call" value="allowed">
|
||||
LiteLLM wraps this detail in an endpoint-specific HTTP error envelope. Optional fields that may also appear: `scan_id`, `report_id`, `profile_name`, `profile_id`, `tr_id`, `prompt_detected`.
|
||||
|
||||
```shell
|
||||
curl -i http://localhost:4000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-your-api-key" \
|
||||
-d '{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the weather like today?"}
|
||||
],
|
||||
"guardrails": ["panw-prisma-airs-guardrail"]
|
||||
}'
|
||||
```
|
||||
On success, the guardrail name appears in the `x-litellm-applied-guardrails` response header.
|
||||
|
||||
Expected successful response:
|
||||
## Configuration
|
||||
|
||||
```json
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"message": {
|
||||
"content": "I don't have access to real-time weather data, but I can help you find weather information through various weather services or apps...",
|
||||
"role": "assistant",
|
||||
"tool_calls": null,
|
||||
"function_call": null,
|
||||
"annotations": []
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1736028456,
|
||||
"id": "chatcmpl-AqQj8example",
|
||||
"model": "gpt-4o",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"completion_tokens": 25,
|
||||
"prompt_tokens": 12,
|
||||
"total_tokens": 37
|
||||
},
|
||||
"x-litellm-panw-scan": {
|
||||
"action": "allow",
|
||||
"category": "benign",
|
||||
"profile_id": "03b32734-d06d-4bb7-a8df-ac5147630ce8",
|
||||
"profile_name": "dev-block-all-profile",
|
||||
"prompt_detected": {
|
||||
"dlp": false,
|
||||
"injection": false,
|
||||
"toxic_content": false,
|
||||
"url_cats": false
|
||||
},
|
||||
"report_id": "Rbd251eac-6e67-433b-b3ef-8eb42d2c7d2c",
|
||||
"response_detected": {
|
||||
"dlp": false,
|
||||
"toxic_content": false,
|
||||
"url_cats": false
|
||||
},
|
||||
"scan_id": "bd251eac-6e67-433b-b3ef-8eb42d2c7d2c",
|
||||
"tr_id": "string"
|
||||
}
|
||||
}
|
||||
```
|
||||
### Supported Modes
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
| Mode | Timing | What is scanned |
|
||||
|------|--------|-----------------|
|
||||
| `pre_call` | Before LLM call | Request input |
|
||||
| `during_call` | Parallel with LLM call | Request input |
|
||||
| `post_call` | After LLM call | Response output |
|
||||
| `pre_mcp_call` | Before MCP tool execution | MCP tool input |
|
||||
| `during_mcp_call` | Parallel with MCP tool execution | MCP tool input |
|
||||
|
||||
## Configuration Parameters
|
||||
|
||||
### Configuration Parameters
|
||||
|
||||
| Parameter | Required | Description | Default |
|
||||
|-----------|----------|-------------|---------|
|
||||
| `api_key` | Yes | Your PANW Prisma AIRS API key from Strata Cloud Manager | - |
|
||||
| `profile_name` | No | Security profile name configured in Strata Cloud Manager. Optional if API key has linked profile | - |
|
||||
| `app_name` | No | Application identifier for tracking in Prisma AIRS analytics (will be prefixed with "LiteLLM-") | `LiteLLM` |
|
||||
| `api_base` | No | Regional API endpoint (see [Regional Endpoints](#regional-endpoints) below) | `https://service.api.aisecurity.paloaltonetworks.com` (US) |
|
||||
| `mode` | No | When to run the guardrail | `pre_call` |
|
||||
| `fallback_on_error` | No | Action when PANW API is unavailable: `"block"` (fail-closed, default) or `"allow"` (fail-open). Config errors always block. | `block` |
|
||||
| `timeout` | No | PANW API call timeout in seconds (1-60) | `10.0` |
|
||||
| `violation_message_template` | No | Custom template for error message when request is blocked. Supports `{guardrail_name}`, `{category}`, `{action_type}`, `{default_message}` placeholders. | - |
|
||||
| `app_name` | No | Application identifier for tracking in Prisma AIRS analytics (prefixed with "LiteLLM-") | `LiteLLM` |
|
||||
| `api_base` | No | Regional API endpoint. US: `https://service.api.aisecurity.paloaltonetworks.com`, EU: `https://service-de.api.aisecurity.paloaltonetworks.com`, India: `https://service-in.api.aisecurity.paloaltonetworks.com`, Singapore: `https://service-sg.api.aisecurity.paloaltonetworks.com` | US |
|
||||
| `mode` | No | When to run the guardrail (see mode table above) | `pre_call` |
|
||||
| `fallback_on_error` | No | Action when PANW API is unavailable: `"block"` (fail-closed) or `"allow"` (fail-open). Config errors always block. | `block` |
|
||||
| `timeout` | No | PANW API call timeout in seconds (recommended: 1-60) | `10.0` |
|
||||
| `violation_message_template` | No | Custom template for blocked requests. Supports `{guardrail_name}`, `{category}`, `{action_type}`, `{default_message}` placeholders. | - |
|
||||
| `mask_request_content` | No | Mask sensitive data in prompts instead of blocking | `false` |
|
||||
| `mask_response_content` | No | Mask sensitive data in responses instead of blocking | `false` |
|
||||
| `mask_on_block` | No | Backwards-compatible flag that enables both request and response masking | `false` |
|
||||
| `experimental_use_latest_role_message_only` | No | Anthropic `/v1/messages` only. When unset: scans only latest user message on request side. Set `false` to scan all user/system/developer messages. Non-Anthropic unaffected. | Unset (true for Anthropic) |
|
||||
|
||||
### Regional Endpoints
|
||||
Use the regional `api_base` that matches your Prisma AIRS deployment profile region for lower latency and data residency compliance.
|
||||
|
||||
PANW Prisma AIRS supports multiple regional endpoints based on your deployment profile region:
|
||||
|
||||
| Region | API Base URL |
|
||||
|--------|--------------|
|
||||
| **US** (default) | `https://service.api.aisecurity.paloaltonetworks.com` |
|
||||
| **EU (Germany)** | `https://service-de.api.aisecurity.paloaltonetworks.com` |
|
||||
| **India** | `https://service-in.api.aisecurity.paloaltonetworks.com` |
|
||||
|
||||
**Example configuration for EU region:**
|
||||
|
||||
```yaml
|
||||
guardrails:
|
||||
- guardrail_name: "panw-eu"
|
||||
litellm_params:
|
||||
guardrail: panw_prisma_airs
|
||||
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
|
||||
api_base: "https://service-de.api.aisecurity.paloaltonetworks.com"
|
||||
profile_name: "production"
|
||||
```
|
||||
|
||||
:::tip Region Selection
|
||||
Use the regional endpoint that matches your Prisma AIRS deployment profile region configured in Strata Cloud Manager. Using the correct region ensures:
|
||||
- Lower latency (requests stay in-region)
|
||||
- Compliance with data residency requirements
|
||||
- Optimal performance
|
||||
:::
|
||||
|
||||
## Per-Request Metadata Overrides
|
||||
|
||||
You can override guardrail settings on a per-request basis using the `metadata` field:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"messages": [...],
|
||||
"metadata": {
|
||||
"profile_name": "dev-allow-all", // Override profile name
|
||||
"profile_id": "uuid-here", // Override profile ID (takes precedence)
|
||||
"user_ip": "192.168.1.100", // Track user IP
|
||||
"app_name": "MyApp" // Custom app name (becomes "LiteLLM-MyApp")
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Supported Metadata Fields:**
|
||||
|
||||
| Field | Description | Priority |
|
||||
|-------|-------------|----------|
|
||||
| `profile_name` | PANW AI security profile name | Per-request > config |
|
||||
| `profile_id` | PANW AI security profile ID (takes precedence over profile_name) | Per-request only |
|
||||
| `user_ip` | User IP address for tracking in Prisma AIRS | Per-request only |
|
||||
| `app_name` | Application identifier (prefixed with "LiteLLM-") | Per-request > config > "LiteLLM" |
|
||||
| `app_user` | Custom user identifier for tracking in Prisma AIRS | `app_user` > `user` > "litellm_user" |
|
||||
|
||||
:::info Profile Resolution
|
||||
- If both `profile_id` and `profile_name` are provided, PANW API uses `profile_id` (it takes precedence)
|
||||
- If no profile is specified in metadata, uses the config `profile_name`
|
||||
- If no profile is specified at all, PANW API will use the profile linked to your API key in Strata Cloud Manager
|
||||
- **Note:** If your API key is not linked to a profile, you must provide `profile_name` or `profile_id`
|
||||
:::
|
||||
|
||||
## Multi-Turn Conversation Tracking
|
||||
|
||||
PANW Prisma AIRS automatically tracks multi-turn conversations using LiteLLM's `litellm_trace_id`. This enables you to:
|
||||
|
||||
- **Group related requests** - All requests in a conversation share the same AI Session ID in Prisma AIRS SCM logs
|
||||
- **Track conversation context** - See the full history of prompts and responses for a user session
|
||||
- **Analyze attack patterns** - Identify sophisticated multi-turn attacks across conversation history
|
||||
|
||||
### How It Works
|
||||
|
||||
LiteLLM automatically generates a unique `litellm_trace_id` for each conversation session. The PANW guardrail uses this as the PANW transaction ID (which maps to "AI Session ID" in Strata Cloud Manager):
|
||||
|
||||
```
|
||||
Conversation Session: litellm_trace_id = "abc-123-def-456"
|
||||
|
||||
Turn 1 (User): "What's the capital of France?"
|
||||
→ Scan ID: scan_001 | Prisma AIRS AI Session ID: abc-123-def-456
|
||||
|
||||
Turn 2 (Assistant): "Paris is the capital of France."
|
||||
→ Scan ID: scan_002 | Prisma AIRS AI Session ID: abc-123-def-456
|
||||
|
||||
Turn 3 (User): "What's the population?"
|
||||
→ Scan ID: scan_003 | Prisma AIRS AI Session ID: abc-123-def-456
|
||||
|
||||
Turn 4 (Assistant): "Paris has approximately 2.1 million residents."
|
||||
→ Scan ID: scan_004 | Prisma AIRS AI Session ID: abc-123-def-456
|
||||
```
|
||||
|
||||
All scans appear under the same AI Session ID in Prisma AIRS logs, making it easy to:
|
||||
- Review complete conversation history (all 4 turns grouped together)
|
||||
- Identify patterns across multiple turns
|
||||
- Correlate security events within a session
|
||||
- Track the flow of user prompts and AI responses
|
||||
|
||||
### Session Tracking
|
||||
|
||||
LiteLLM automatically generates a unique `litellm_trace_id` for each request, which the PANW guardrail uses as the AI Session ID in Strata Cloud Manager. All prompt and response scans for a request are automatically grouped under the same session.
|
||||
|
||||
#### Custom Session IDs (Per-App Tracking)
|
||||
|
||||
You can provide your own `litellm_trace_id` to track sessions on a per-app or per-conversation basis:
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:4000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "capital of France"}],
|
||||
"litellm_trace_id": "my-app-session-123", # Custom AI Session ID
|
||||
"metadata": {
|
||||
"profile_name": "dev-allow-all-profile", # Override security profile
|
||||
"user_ip": "192.168.1.1", # Track user IP
|
||||
"app_name": "eng" # Custom app identifier
|
||||
},
|
||||
"guardrails": ["panw-prisma-airs-pre-guard", "panw-prisma-airs-post-guard"]
|
||||
}'
|
||||
```
|
||||
|
||||
**Result in PANW SCM:**
|
||||
- AI Session ID: `my-app-session-123`
|
||||
- All prompt and response scans will be grouped under this custom session ID
|
||||
- Perfect for tracking multi-turn conversations or per-application sessions
|
||||
|
||||
:::tip Viewing Sessions in Prisma AIRS SCM Logs
|
||||
In Strata Cloud Manager, navigate to **AI Runtime > Sessions** to view all AI Session IDs and their associated scans. Click on a session to see the complete conversation history with security analysis.
|
||||
:::
|
||||
|
||||
## Environment Variables
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
export PANW_PRISMA_AIRS_API_KEY="your-panw-api-key"
|
||||
@@ -348,12 +136,31 @@ export PANW_PRISMA_AIRS_PROFILE_NAME="your-security-profile"
|
||||
export PANW_PRISMA_AIRS_API_BASE="https://custom-endpoint.com"
|
||||
```
|
||||
|
||||
## Advanced Configuration
|
||||
### Per-Request Metadata Overrides
|
||||
|
||||
| Field | Description | Priority |
|
||||
|-------|-------------|----------|
|
||||
| `profile_name` | PANW AI security profile name | Per-request > config |
|
||||
| `profile_id` | PANW AI security profile ID (takes precedence over `profile_name`) | Per-request only |
|
||||
| `user_ip` | User IP address for tracking in Prisma AIRS | Per-request only |
|
||||
| `app_name` | Application identifier (prefixed with "LiteLLM-") | Per-request > config > "LiteLLM" |
|
||||
| `app_user` | Custom user identifier for tracking in Prisma AIRS | `app_user` > `user` > "litellm_user" |
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"messages": [...],
|
||||
"metadata": {
|
||||
"profile_name": "dev-allow-all",
|
||||
"profile_id": "uuid-here",
|
||||
"user_ip": "192.168.1.100",
|
||||
"app_name": "MyApp"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Multiple Security Profiles
|
||||
|
||||
You can configure different security profiles for different use cases:
|
||||
|
||||
```yaml
|
||||
guardrails:
|
||||
- guardrail_name: "panw-strict-security"
|
||||
@@ -361,126 +168,40 @@ guardrails:
|
||||
guardrail: panw_prisma_airs
|
||||
mode: "pre_call"
|
||||
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
|
||||
profile_name: "strict-policy" # High security profile
|
||||
|
||||
- guardrail_name: "panw-permissive-security"
|
||||
profile_name: "strict-policy"
|
||||
|
||||
- guardrail_name: "panw-permissive-security"
|
||||
litellm_params:
|
||||
guardrail: panw_prisma_airs
|
||||
mode: "post_call"
|
||||
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
|
||||
profile_name: "permissive-policy" # Lower security profile
|
||||
profile_name: "permissive-policy"
|
||||
```
|
||||
|
||||
### Multiple API Keys (Multi-Tenant)
|
||||
|
||||
For multi-tenant deployments where different customers need different PANW API keys, create separate guardrail instances:
|
||||
|
||||
```yaml
|
||||
guardrails:
|
||||
- guardrail_name: "panw-customer-a"
|
||||
litellm_params:
|
||||
guardrail: panw_prisma_airs
|
||||
mode: "pre_call"
|
||||
api_key: os.environ/PANW_CUSTOMER_A_KEY # Linked to Customer A profile in SCM
|
||||
|
||||
- guardrail_name: "panw-customer-b"
|
||||
litellm_params:
|
||||
guardrail: panw_prisma_airs
|
||||
mode: "pre_call"
|
||||
api_key: os.environ/PANW_CUSTOMER_B_KEY # Linked to Customer B profile in SCM
|
||||
```
|
||||
|
||||
Then route requests to the appropriate guardrail:
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:4000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"guardrails": ["panw-customer-a"]
|
||||
}'
|
||||
```
|
||||
|
||||
**Use Cases:**
|
||||
- **Multi-tenant deployments**: Different customers with different security policies
|
||||
- **Environment-specific policies**: Dev/staging/prod with different API keys and profiles
|
||||
- **A/B testing**: Compare different security profiles side-by-side
|
||||
|
||||
### Content Masking
|
||||
|
||||
PANW Prisma AIRS can automatically mask sensitive content (PII, credit cards, SSNs, etc.) instead of blocking requests. This allows your application to continue functioning while protecting sensitive data.
|
||||
|
||||
#### How It Works
|
||||
|
||||
1. **Detection**: PANW scans content and identifies sensitive data
|
||||
2. **Masking**: Sensitive data is replaced with placeholders (e.g., `XXXXXXXXXX` or `{PHONE}`)
|
||||
3. **Pass-through**: Masked content is sent to the LLM or returned to the user
|
||||
|
||||
#### Configuration Options
|
||||
:::warning Important: Masking is Controlled by PANW Security Profile
|
||||
The actual masking behavior (what content gets masked and how) is controlled by your PANW Prisma AIRS security profile in Strata Cloud Manager. The LiteLLM flags (`mask_request_content`, `mask_response_content`) only control whether to apply the masked content and allow the request to continue, or block entirely.
|
||||
:::
|
||||
|
||||
```yaml
|
||||
guardrails:
|
||||
- guardrail_name: "panw-with-masking"
|
||||
litellm_params:
|
||||
guardrail: panw_prisma_airs
|
||||
mode: "post_call" # Scan response output
|
||||
mode: "post_call"
|
||||
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
|
||||
profile_name: "default"
|
||||
mask_request_content: true # Mask sensitive data in prompts
|
||||
mask_response_content: true # Mask sensitive data in responses
|
||||
mask_request_content: true
|
||||
mask_response_content: true
|
||||
```
|
||||
|
||||
**Masking Parameters:**
|
||||
|
||||
- `mask_request_content: true` - When PANW detects sensitive data in prompts, mask it instead of blocking
|
||||
- `mask_response_content: true` - When PANW detects sensitive data in responses, mask it instead of blocking
|
||||
- `mask_on_block: true` - Backwards compatible flag that enables both request and response masking
|
||||
|
||||
:::warning Important: Masking is Controlled by PANW Security Profile
|
||||
The **actual masking behavior** (what content gets masked and how) is controlled by your **PANW Prisma AIRS security profile** configured in Strata Cloud Manager. The LiteLLM config settings (`mask_request_content`, `mask_response_content`) only control whether to:
|
||||
- **Apply the masked content** returned by PANW and allow the request to continue, OR
|
||||
- **Block the request** entirely when sensitive data is detected
|
||||
|
||||
LiteLLM does not alter or configure your PANW security profile. To change what content gets masked, update your profile settings in Strata Cloud Manager.
|
||||
:::
|
||||
|
||||
:::info Security Posture
|
||||
The guardrail is **fail-closed** by default - if the PANW API is unavailable, requests are blocked to ensure no unscanned content reaches your LLM. This provides maximum security.
|
||||
:::
|
||||
|
||||
### Custom Violation Messages
|
||||
|
||||
You can customize the error message returned to the user when a request is blocked by configuring the `violation_message_template` parameter. This is useful for providing user-friendly feedback instead of technical details.
|
||||
|
||||
```yaml
|
||||
guardrails:
|
||||
- guardrail_name: "panw-custom-message"
|
||||
litellm_params:
|
||||
guardrail: panw_prisma_airs
|
||||
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
|
||||
# Simple message
|
||||
violation_message_template: "Your request was blocked by our AI Security Policy."
|
||||
|
||||
- guardrail_name: "panw-detailed-message"
|
||||
litellm_params:
|
||||
guardrail: panw_prisma_airs
|
||||
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
|
||||
# Message with placeholders
|
||||
violation_message_template: "{action_type} blocked due to {category} violation. Please contact support."
|
||||
```
|
||||
|
||||
**Supported Placeholders:**
|
||||
- `{guardrail_name}`: Name of the guardrail (e.g. "panw-custom-message")
|
||||
- `{category}`: Violation category (e.g. "malicious", "injection", "dlp")
|
||||
- `{action_type}`: "Prompt" or "Response"
|
||||
- `{default_message}`: The original technical error message
|
||||
- `mask_request_content: true` — mask sensitive data in prompts instead of blocking
|
||||
- `mask_response_content: true` — mask sensitive data in responses instead of blocking
|
||||
- `mask_on_block: true` — backwards-compatible flag that enables both request and response masking
|
||||
|
||||
### Fail-Open Configuration
|
||||
|
||||
By default, the PANW guardrail operates in **fail-closed** mode for maximum security. If the PANW API is unavailable (timeout, rate limit, network error), requests are blocked. You can configure **fail-open** mode for high-availability scenarios where service continuity is critical.
|
||||
|
||||
```yaml
|
||||
guardrails:
|
||||
- guardrail_name: "panw-high-availability"
|
||||
@@ -488,135 +209,86 @@ guardrails:
|
||||
guardrail: panw_prisma_airs
|
||||
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
|
||||
profile_name: "production"
|
||||
fallback_on_error: "allow" # Enable fail-open mode
|
||||
timeout: 5.0 # Shorter timeout for fail-open
|
||||
fallback_on_error: "allow"
|
||||
timeout: 5.0
|
||||
```
|
||||
|
||||
**Configuration Options:**
|
||||
|
||||
| Parameter | Value | Behavior |
|
||||
|-----------|-------|----------|
|
||||
| `fallback_on_error` | `"block"` (default) | **Fail-closed**: Block requests when API unavailable (maximum security) |
|
||||
| `fallback_on_error` | `"allow"` | **Fail-open**: Allow requests when API unavailable (high availability) |
|
||||
| `timeout` | `1.0` - `60.0` | API call timeout in seconds (default: `10.0`) |
|
||||
|
||||
**Error Handling Matrix:**
|
||||
|
||||
| Error Type | `fallback_on_error="block"` | `fallback_on_error="allow"` |
|
||||
|------------|----------------------------|----------------------------|
|
||||
| 401 Unauthorized | Block (500) | Block (500) ⚠️ |
|
||||
| 403 Forbidden | Block (500) | Block (500) ⚠️ |
|
||||
| Profile Error | Block (500) | Block (500) ⚠️ |
|
||||
| 401 Unauthorized | Block (500) | Block (500) |
|
||||
| 403 Forbidden | Block (500) | Block (500) |
|
||||
| Profile Error | Block (500) | Block (500) |
|
||||
| 429 Rate Limit | Block (500) | Allow (`:unscanned`) |
|
||||
| Timeout | Block (500) | Allow (`:unscanned`) |
|
||||
| Network Error | Block (500) | Allow (`:unscanned`) |
|
||||
| 5xx Server Error | Block (500) | Allow (`:unscanned`) |
|
||||
| Content Blocked | Block (400) | Block (400) |
|
||||
|
||||
⚠️ = Always blocks regardless of fail-open setting
|
||||
Authentication and configuration errors (401, 403, invalid profile) always block. Only transient errors (429, timeout, network) trigger fail-open.
|
||||
|
||||
:::warning Security Trade-Off
|
||||
Enabling `fallback_on_error="allow"` reduces security in exchange for availability. Requests may proceed **without scanning** when the PANW API is unavailable. Use only when:
|
||||
- Service availability is more critical than security scanning
|
||||
- You have other security controls in place
|
||||
- You monitor the `:unscanned` header for audit trails
|
||||
When fail-open is triggered, the response includes a tracking header: `X-LiteLLM-Applied-Guardrails: panw-airs:unscanned`
|
||||
|
||||
**Authentication and configuration errors (401, 403, invalid profile) always block** - only transient errors (429, timeout, network) trigger fail-open behavior.
|
||||
:::
|
||||
|
||||
**Observability:**
|
||||
|
||||
When fail-open is triggered, the response includes a special header for tracking:
|
||||
|
||||
```
|
||||
X-LiteLLM-Applied-Guardrails: panw-airs:unscanned
|
||||
```
|
||||
|
||||
This allows you to:
|
||||
- Track which requests bypassed scanning
|
||||
- Alert on unscanned request volumes
|
||||
- Audit compliance requirements
|
||||
|
||||
#### Example: Masking Credit Card Numbers
|
||||
|
||||
<Tabs>
|
||||
<TabItem label="Without Masking" value="no-mask">
|
||||
|
||||
**Request:**
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "My credit card is 4929-3813-3266-4295"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Response:** ❌ **Blocked with 400 error**
|
||||
|
||||
</TabItem>
|
||||
<TabItem label="With Masking" value="with-mask">
|
||||
|
||||
**Request:**
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "My credit card is 4929-3813-3266-4295"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Masked prompt sent to LLM:**
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "My credit card is XXXXXXXXXXXXXXXXXX"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Response:** ✅ **Allowed with masked content**
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
#### Masking Capabilities
|
||||
|
||||
The guardrail masks sensitive content in:
|
||||
|
||||
- ✅ **Chat messages** - User prompts and assistant responses
|
||||
- ✅ **Streaming responses** - Real-time masking of streamed content
|
||||
- ✅ **Multi-choice responses** - All choices in the response
|
||||
- ✅ **Tool/function calls** - Arguments passed to tools and functions
|
||||
- ✅ **Content lists** - Mixed content types (text, images, etc.)
|
||||
|
||||
#### Complete Example
|
||||
### Custom Violation Messages
|
||||
|
||||
```yaml
|
||||
guardrails:
|
||||
- guardrail_name: "panw-production-security"
|
||||
- guardrail_name: "panw-custom-message"
|
||||
litellm_params:
|
||||
guardrail: panw_prisma_airs
|
||||
mode: "post_call" # Scan input and output
|
||||
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
|
||||
profile_name: "production-profile"
|
||||
mask_request_content: true # Mask sensitive prompts
|
||||
mask_response_content: true # Mask sensitive responses
|
||||
violation_message_template: "Your request was blocked by our AI Security Policy."
|
||||
|
||||
- guardrail_name: "panw-detailed-message"
|
||||
litellm_params:
|
||||
guardrail: panw_prisma_airs
|
||||
api_key: os.environ/PANW_PRISMA_AIRS_API_KEY
|
||||
violation_message_template: "{action_type} blocked due to {category} violation. Please contact support."
|
||||
```
|
||||
|
||||
## Use Cases
|
||||
**Supported Placeholders:** `{guardrail_name}`, `{category}`, `{action_type}`, `{default_message}`
|
||||
|
||||
From [official Prisma AIRS documentation](https://docs.paloaltonetworks.com/ai-runtime-security/activation-and-onboarding/ai-runtime-security-api-intercept-overview):
|
||||
## Behavior and Limitations
|
||||
|
||||
- **Secure AI models in production**: Validate prompt requests and responses to protect deployed AI models
|
||||
- **Detect data poisoning**: Identify contaminated training data before fine-tuning
|
||||
- **Protect against adversarial input**: Safeguard AI agents from malicious inputs and outputs
|
||||
- **Prevent sensitive data leakage**: Use API-based threat detection to block sensitive data leaks
|
||||
### Transaction Tracking
|
||||
|
||||
For standard request/response scans, `tr_id` maps to `litellm_call_id`. MCP tool scans use the parent `litellm_call_id` when available; if missing, PANW synthesizes a fallback MCP transaction ID. The real limitation is correlation loss — synthesized MCP `tr_id` values are not grouped with the parent request's prompt/response scans in AIRS dashboards.
|
||||
|
||||
By default, LiteLLM generates a UUID for `litellm_call_id`. To provide your own:
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:4000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-H "x-litellm-call-id: my-custom-call-id-789" \
|
||||
-d '{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "capital of France"}],
|
||||
"guardrails": ["panw-prisma-airs-guardrail"]
|
||||
}'
|
||||
```
|
||||
|
||||
The `x-litellm-call-id` is also returned in response headers. If you pass `litellm_trace_id` in request metadata (or via the `x-litellm-trace-id` header), it is included in the PANW API payload metadata but does not affect `tr_id` or appear in Prisma AIRS.
|
||||
|
||||
### Streaming
|
||||
|
||||
- Response masking works on OpenAI chat streaming (`mask_response_content: true`)
|
||||
- `/v1/messages` and `/v1/responses` raw streaming blocks instead of masking when violations are detected
|
||||
- Request-side masking (`mask_request_content`) is unaffected by endpoint type
|
||||
- When `fallback_on_error: "allow"` is set, streaming responses fail open on transient PANW API errors (timeout, 5xx, network) — original chunks are yielded unchanged
|
||||
|
||||
## MCP Tool Security
|
||||
|
||||
Tool invocations are sent to AIRS as structured `tool_event` payloads containing tool name, ecosystem, and serialized arguments. Tool-event scans always use request mode.
|
||||
|
||||
**What is scanned:** LLM-driven `tool_calls` (name + arguments) and MCP request-side invocations when `mcp_tool_name` (or fallback `name`) is present. Response-side OpenAI-compatible `tool_calls` are also scanned when surfaced into `apply_guardrail()`.
|
||||
|
||||
**What is not scanned:** Tool definitions in `inputs["tools"]` and post-MCP tool results (no `post_mcp_call` hook exists yet).
|
||||
|
||||
|
||||
## Next Steps
|
||||
### Current Limitations
|
||||
|
||||
- Configure your security policies in [Strata Cloud Manager](https://apps.paloaltonetworks.com/)
|
||||
- Review the [Prisma AIRS API documentation](https://pan.dev/airs/) for advanced features
|
||||
- Set up monitoring and alerting for threat detections in your PANW dashboard
|
||||
- Consider implementing both pre_call and post_call guardrails for comprehensive protection
|
||||
- Monitor detection events and tune your security profiles based on your application needs
|
||||
- **No post-MCP response scanning.** Actual post-MCP tool-result scanning is not supported because there is no `post_mcp_call` hook in the framework. Response-side MCP events are only scanned when they appear as regular `tool_calls` in the LLM response.
|
||||
- **Guardrail selection not inherited by MCP sub-calls.** With `default_on: false`, MCP request-side child-call scans can be skipped because the parent request's guardrail selection is not propagated to the synthetic MCP payload. Workaround: use a dedicated guardrail with `mode: pre_mcp_call` and `default_on: true`.
|
||||
- **MCP transaction correlation.** MCP tool scans use the parent `litellm_call_id` when available; otherwise a fallback ID is synthesized and will not be grouped with the parent request in AIRS dashboards.
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 72 KiB |
@@ -814,6 +814,8 @@ const sidebars = {
|
||||
"providers/anyscale",
|
||||
"providers/apertis",
|
||||
"providers/baseten",
|
||||
"providers/black_forest_labs",
|
||||
"providers/black_forest_labs_img_edit",
|
||||
"providers/bytez",
|
||||
"providers/cerebras",
|
||||
"providers/chutes",
|
||||
|
||||
+13
@@ -0,0 +1,13 @@
|
||||
-- SkipTransactionBlock
|
||||
|
||||
-- Drop invalid indexes left behind by failed CONCURRENTLY builds
|
||||
DROP INDEX CONCURRENTLY IF EXISTS "LiteLLM_VerificationToken_key_alias_idx";
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX CONCURRENTLY "LiteLLM_VerificationToken_key_alias_idx" ON "LiteLLM_VerificationToken"("key_alias");
|
||||
|
||||
-- Drop invalid indexes left behind by failed CONCURRENTLY builds
|
||||
DROP INDEX CONCURRENTLY IF EXISTS "LiteLLM_SpendLogs_user_startTime_idx";
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX CONCURRENTLY "LiteLLM_SpendLogs_user_startTime_idx" ON "LiteLLM_SpendLogs"("user", "startTime");
|
||||
@@ -388,6 +388,9 @@ model LiteLLM_VerificationToken {
|
||||
|
||||
// SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE (("public"."LiteLLM_VerificationToken"."expires" IS NULL OR "public"."LiteLLM_VerificationToken"."expires" > $1) AND "public"."LiteLLM_VerificationToken"."budget_reset_at" < $2) OFFSET $3
|
||||
@@index([budget_reset_at, expires])
|
||||
|
||||
// SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE (...) ORDER BY "public"."LiteLLM_VerificationToken"."key_alias" ASC
|
||||
@@index([key_alias])
|
||||
}
|
||||
|
||||
model LiteLLM_JWTKeyMapping {
|
||||
@@ -553,6 +556,9 @@ model LiteLLM_SpendLogs {
|
||||
@@index([startTime, request_id])
|
||||
@@index([end_user])
|
||||
@@index([session_id])
|
||||
|
||||
// SELECT ... FROM "LiteLLM_SpendLogs" WHERE ("startTime" >= $1 AND "startTime" <= $2 AND "user" = $3) GROUP BY ...
|
||||
@@index([user, startTime])
|
||||
}
|
||||
|
||||
// View spend, model, api_key per request
|
||||
|
||||
@@ -577,6 +577,7 @@ v0_models: Set = set()
|
||||
morph_models: Set = set()
|
||||
lambda_ai_models: Set = set()
|
||||
hyperbolic_models: Set = set()
|
||||
black_forest_labs_models: Set = set()
|
||||
recraft_models: Set = set()
|
||||
cometapi_models: Set = set()
|
||||
oci_models: Set = set()
|
||||
@@ -824,6 +825,8 @@ def add_known_models(model_cost_map: Optional[Dict] = None):
|
||||
lambda_ai_models.add(key)
|
||||
elif value.get("litellm_provider") == "hyperbolic":
|
||||
hyperbolic_models.add(key)
|
||||
elif value.get("litellm_provider") == "black_forest_labs":
|
||||
black_forest_labs_models.add(key)
|
||||
elif value.get("litellm_provider") == "recraft":
|
||||
recraft_models.add(key)
|
||||
elif value.get("litellm_provider") == "cometapi":
|
||||
@@ -957,6 +960,7 @@ model_list = list(
|
||||
| v0_models
|
||||
| morph_models
|
||||
| lambda_ai_models
|
||||
| black_forest_labs_models
|
||||
| recraft_models
|
||||
| cometapi_models
|
||||
| oci_models
|
||||
@@ -1055,6 +1059,7 @@ models_by_provider: dict = {
|
||||
"morph": morph_models,
|
||||
"lambda_ai": lambda_ai_models,
|
||||
"hyperbolic": hyperbolic_models,
|
||||
"black_forest_labs": black_forest_labs_models,
|
||||
"recraft": recraft_models,
|
||||
"cometapi": cometapi_models,
|
||||
"oci": oci_models,
|
||||
|
||||
@@ -348,8 +348,6 @@ class DualCache(BaseCache):
|
||||
)
|
||||
try:
|
||||
if self.in_memory_cache is not None:
|
||||
if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
|
||||
kwargs["ttl"] = self.default_in_memory_ttl
|
||||
await self.in_memory_cache.async_set_cache(key, value, **kwargs)
|
||||
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
@@ -371,8 +369,6 @@ class DualCache(BaseCache):
|
||||
)
|
||||
try:
|
||||
if self.in_memory_cache is not None:
|
||||
if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
|
||||
kwargs["ttl"] = self.default_in_memory_ttl
|
||||
await self.in_memory_cache.async_set_cache_pipeline(
|
||||
cache_list=cache_list, **kwargs
|
||||
)
|
||||
|
||||
@@ -398,9 +398,6 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge):
|
||||
ResponseOutputMessage,
|
||||
ResponseReasoningItem,
|
||||
)
|
||||
from openai.types.responses.response_output_item import (
|
||||
ResponseApplyPatchToolCall,
|
||||
)
|
||||
|
||||
from litellm.types.utils import Choices, Message
|
||||
|
||||
@@ -457,18 +454,6 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge):
|
||||
accumulated_tool_calls.append(tool_call_dict)
|
||||
tool_call_index += 1
|
||||
|
||||
elif isinstance(item, ResponseApplyPatchToolCall):
|
||||
from litellm.responses.litellm_completion_transformation.transformation import (
|
||||
LiteLLMCompletionResponsesConfig,
|
||||
)
|
||||
|
||||
tool_call_dict = LiteLLMCompletionResponsesConfig.convert_apply_patch_tool_call_to_chat_completion_tool_call(
|
||||
tool_call_item=item,
|
||||
index=tool_call_index,
|
||||
)
|
||||
accumulated_tool_calls.append(tool_call_dict)
|
||||
tool_call_index += 1
|
||||
|
||||
elif isinstance(item, dict) and handle_raw_dict_callback is not None:
|
||||
# Handle raw dict responses (e.g., from GPT-5 Codex)
|
||||
choice, index = handle_raw_dict_callback(item=item, index=index)
|
||||
|
||||
@@ -1212,12 +1212,8 @@ OPENAI_FINISH_REASONS = [
|
||||
"stop",
|
||||
"length",
|
||||
"function_call",
|
||||
"tool_calls",
|
||||
"content_filter",
|
||||
"null",
|
||||
"finish_reason_unspecified",
|
||||
"malformed_function_call",
|
||||
"guardrail_intervened",
|
||||
"eos",
|
||||
]
|
||||
HUMANLOOP_PROMPT_CACHE_TTL_SECONDS = int(
|
||||
os.getenv("HUMANLOOP_PROMPT_CACHE_TTL_SECONDS", 60)
|
||||
|
||||
@@ -770,8 +770,6 @@ class GoogleGenAIAdapter:
|
||||
"content_filter": "SAFETY",
|
||||
"tool_calls": "STOP",
|
||||
"function_call": "STOP",
|
||||
"finish_reason_unspecified": "FINISH_REASON_UNSPECIFIED",
|
||||
"malformed_function_call": "MALFORMED_FUNCTION_CALL",
|
||||
}
|
||||
|
||||
return mapping.get(finish_reason, "STOP")
|
||||
|
||||
+42
-5
@@ -50,6 +50,10 @@ from litellm.main import (
|
||||
openai_image_variations,
|
||||
)
|
||||
|
||||
# BFL handlers
|
||||
from litellm.llms.black_forest_labs.image_edit.handler import bfl_image_edit
|
||||
from litellm.llms.black_forest_labs.image_generation.handler import bfl_image_generation
|
||||
|
||||
###########################################
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.images.main import ImageEditOptionalRequestParams
|
||||
@@ -426,6 +430,22 @@ def image_generation( # noqa: PLR0915
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
)
|
||||
elif custom_llm_provider == "black_forest_labs":
|
||||
# Route to BFL-specific handler (polling required)
|
||||
if model is None:
|
||||
raise Exception("Model needs to be set for black_forest_labs")
|
||||
return bfl_image_generation.image_generation(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
model_response=model_response,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params_dict,
|
||||
logging_obj=litellm_logging_obj,
|
||||
timeout=timeout,
|
||||
extra_headers=extra_headers,
|
||||
client=client,
|
||||
aimg_generation=aimg_generation,
|
||||
)
|
||||
elif custom_llm_provider == "azure_ai":
|
||||
from litellm.llms.azure_ai.common_utils import AzureFoundryModelInfo
|
||||
|
||||
@@ -909,19 +929,36 @@ def image_edit( # noqa: PLR0915
|
||||
elif custom_llm_provider == "stability":
|
||||
image_edit_request_params.update(non_default_params)
|
||||
return base_llm_http_handler.image_edit_handler(
|
||||
model=model,
|
||||
image=images,
|
||||
prompt=prompt,
|
||||
image_edit_provider_config=image_edit_provider_config,
|
||||
image_edit_optional_request_params=image_edit_request_params,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
litellm_params=litellm_params,
|
||||
logging_obj=litellm_logging_obj,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout or DEFAULT_REQUEST_TIMEOUT,
|
||||
_is_async=_is_async,
|
||||
client=kwargs.get("client"),
|
||||
)
|
||||
elif custom_llm_provider == "black_forest_labs":
|
||||
# Route to BFL-specific handler (polling required)
|
||||
if model is None:
|
||||
raise Exception("Model needs to be set for black_forest_labs")
|
||||
image_edit_request_params.update(non_default_params)
|
||||
return bfl_image_edit.image_edit(
|
||||
model=model,
|
||||
image=images,
|
||||
prompt=prompt,
|
||||
image_edit_provider_config=image_edit_provider_config,
|
||||
image_edit_optional_request_params=image_edit_request_params,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
litellm_params=litellm_params,
|
||||
logging_obj=litellm_logging_obj,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout or DEFAULT_REQUEST_TIMEOUT,
|
||||
_is_async=_is_async,
|
||||
extra_headers=extra_headers,
|
||||
client=kwargs.get("client"),
|
||||
aimage_edit=_is_async,
|
||||
)
|
||||
# Call the handler with _is_async flag instead of directly calling the async handler
|
||||
return base_llm_http_handler.image_edit_handler(
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
# What is this?
|
||||
## Helper utilities
|
||||
from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Union, get_args
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.llms.openai import AllMessageValues, OpenAIChatCompletionFinishReason
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
@@ -58,45 +58,55 @@ def safe_divide(
|
||||
return numerator / denominator
|
||||
|
||||
|
||||
def map_finish_reason(
|
||||
finish_reason: str,
|
||||
): # openai supports 5 stop sequences - 'stop', 'length', 'function_call', 'content_filter', 'null'
|
||||
# anthropic mapping
|
||||
if finish_reason == "stop_sequence":
|
||||
_FINISH_REASON_MAP: dict[str, OpenAIChatCompletionFinishReason] = {
|
||||
# Anthropic
|
||||
"stop_sequence": "stop",
|
||||
"end_turn": "stop",
|
||||
"max_tokens": "length",
|
||||
"tool_use": "tool_calls",
|
||||
"compaction": "length",
|
||||
# Cohere
|
||||
"COMPLETE": "stop",
|
||||
"ERROR_TOXIC": "content_filter",
|
||||
"ERROR": "stop",
|
||||
# HuggingFace / Together AI
|
||||
"eos_token": "stop",
|
||||
"eos": "stop",
|
||||
# Gemini / Vertex AI
|
||||
"STOP": "stop",
|
||||
"MAX_TOKENS": "length",
|
||||
"SAFETY": "content_filter",
|
||||
"RECITATION": "content_filter",
|
||||
"FINISH_REASON_UNSPECIFIED": "stop",
|
||||
"MALFORMED_FUNCTION_CALL": "stop",
|
||||
"LANGUAGE": "content_filter",
|
||||
"OTHER": "content_filter",
|
||||
"BLOCKLIST": "content_filter",
|
||||
"PROHIBITED_CONTENT": "content_filter",
|
||||
"SPII": "content_filter",
|
||||
"IMAGE_SAFETY": "content_filter",
|
||||
"IMAGE_PROHIBITED_CONTENT": "content_filter",
|
||||
"TOO_MANY_TOOL_CALLS": "stop",
|
||||
"MALFORMED_RESPONSE": "stop",
|
||||
# Bedrock
|
||||
"guardrail_intervened": "content_filter",
|
||||
# OpenAI passthrough
|
||||
"stop": "stop",
|
||||
"length": "length",
|
||||
"tool_calls": "tool_calls",
|
||||
"function_call": "function_call",
|
||||
"content_filter": "content_filter",
|
||||
}
|
||||
|
||||
|
||||
def map_finish_reason(finish_reason: str) -> OpenAIChatCompletionFinishReason:
|
||||
mapped = _FINISH_REASON_MAP.get(finish_reason)
|
||||
if mapped is None:
|
||||
verbose_logger.warning(
|
||||
"Unmapped finish_reason '%s', defaulting to 'stop'", finish_reason
|
||||
)
|
||||
return "stop"
|
||||
# cohere mapping - https://docs.cohere.com/reference/generate
|
||||
elif finish_reason == "COMPLETE":
|
||||
return "stop"
|
||||
elif finish_reason == "MAX_TOKENS": # cohere + vertex ai
|
||||
return "length"
|
||||
elif finish_reason == "ERROR_TOXIC":
|
||||
return "content_filter"
|
||||
elif (
|
||||
finish_reason == "ERROR"
|
||||
): # openai currently doesn't support an 'error' finish reason
|
||||
return "stop"
|
||||
# huggingface mapping https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/generate_stream
|
||||
elif finish_reason == "eos_token" or finish_reason == "stop_sequence":
|
||||
return "stop"
|
||||
elif (
|
||||
finish_reason == "FINISH_REASON_UNSPECIFIED"
|
||||
): # vertex ai - got from running `print(dir(response_obj.candidates[0].finish_reason))`: ['FINISH_REASON_UNSPECIFIED', 'MAX_TOKENS', 'OTHER', 'RECITATION', 'SAFETY', 'STOP',]
|
||||
return "finish_reason_unspecified"
|
||||
elif finish_reason == "MALFORMED_FUNCTION_CALL":
|
||||
return "malformed_function_call"
|
||||
elif finish_reason == "SAFETY" or finish_reason == "RECITATION": # vertex ai
|
||||
return "content_filter"
|
||||
elif finish_reason == "STOP": # vertex ai
|
||||
return "stop"
|
||||
elif finish_reason == "end_turn" or finish_reason == "stop_sequence": # anthropic
|
||||
return "stop"
|
||||
elif finish_reason == "max_tokens": # anthropic
|
||||
return "length"
|
||||
elif finish_reason == "tool_use": # anthropic
|
||||
return "tool_calls"
|
||||
elif finish_reason == "compaction":
|
||||
return "length"
|
||||
return finish_reason
|
||||
return mapped
|
||||
|
||||
|
||||
def remove_index_from_tool_calls(
|
||||
|
||||
@@ -64,10 +64,12 @@ def duration_in_seconds(duration: str) -> int:
|
||||
now = time.time()
|
||||
current_time = datetime.fromtimestamp(now)
|
||||
|
||||
# Calculate target month and year, handling overflow past December
|
||||
total_months = current_time.month - 1 + value # 0-indexed months
|
||||
target_year = current_time.year + total_months // 12
|
||||
target_month = total_months % 12 + 1 # back to 1-indexed
|
||||
if current_time.month == 12:
|
||||
target_year = current_time.year + 1
|
||||
target_month = 1
|
||||
else:
|
||||
target_year = current_time.year
|
||||
target_month = current_time.month + value
|
||||
|
||||
# Determine the day to set for next month
|
||||
target_day = current_time.day
|
||||
|
||||
@@ -11,7 +11,7 @@ export LITELLM_LOCAL_MODEL_COST_MAP=True
|
||||
import json
|
||||
import os
|
||||
from importlib.resources import files
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -186,6 +186,61 @@ def get_model_cost_map_source_info() -> dict:
|
||||
}
|
||||
|
||||
|
||||
def _expand_model_aliases(model_cost: dict) -> dict:
|
||||
"""
|
||||
Expand ``aliases`` lists in model cost entries into top-level entries.
|
||||
|
||||
Each alias gets a reference to the **same** dict object as the canonical
|
||||
entry (zero memory overhead). The ``aliases`` key is removed from the
|
||||
entry so downstream code never sees it.
|
||||
|
||||
If an alias collides with an existing canonical entry the alias is
|
||||
skipped and a warning is logged.
|
||||
"""
|
||||
aliases_to_add: Dict[str, dict] = {}
|
||||
keys_with_aliases: List[str] = []
|
||||
|
||||
for model_name, model_info in model_cost.items():
|
||||
aliases: Optional[list] = model_info.get("aliases")
|
||||
if aliases is None:
|
||||
continue
|
||||
keys_with_aliases.append(model_name)
|
||||
if not isinstance(aliases, list):
|
||||
verbose_logger.warning(
|
||||
"LiteLLM model alias field for '%s' is not a list (got %s) — skipping.",
|
||||
model_name,
|
||||
type(aliases).__name__,
|
||||
)
|
||||
continue
|
||||
if not aliases:
|
||||
continue
|
||||
for alias in aliases:
|
||||
if alias in model_cost:
|
||||
verbose_logger.warning(
|
||||
"LiteLLM model alias conflict: alias '%s' (from '%s') "
|
||||
"already exists as a canonical entry — skipping.",
|
||||
alias,
|
||||
model_name,
|
||||
)
|
||||
continue
|
||||
if alias in aliases_to_add:
|
||||
verbose_logger.warning(
|
||||
"LiteLLM model alias conflict: alias '%s' (from '%s') "
|
||||
"was already claimed by another entry — skipping.",
|
||||
alias,
|
||||
model_name,
|
||||
)
|
||||
continue
|
||||
aliases_to_add[alias] = model_info # same dict reference
|
||||
|
||||
# Remove the ``aliases`` key from entries so it doesn't pollute model info
|
||||
for key in keys_with_aliases:
|
||||
model_cost[key].pop("aliases", None)
|
||||
|
||||
model_cost.update(aliases_to_add)
|
||||
return model_cost
|
||||
|
||||
|
||||
def get_model_cost_map(url: str) -> dict:
|
||||
"""
|
||||
Public entry point — returns the model cost map dict.
|
||||
@@ -205,7 +260,7 @@ def get_model_cost_map(url: str) -> dict:
|
||||
_cost_map_source_info.url = None
|
||||
_cost_map_source_info.is_env_forced = True
|
||||
_cost_map_source_info.fallback_reason = None
|
||||
return GetModelCostMap.load_local_model_cost_map()
|
||||
return _expand_model_aliases(GetModelCostMap.load_local_model_cost_map())
|
||||
|
||||
_cost_map_source_info.url = url
|
||||
_cost_map_source_info.is_env_forced = False
|
||||
@@ -221,7 +276,7 @@ def get_model_cost_map(url: str) -> dict:
|
||||
)
|
||||
_cost_map_source_info.source = "local"
|
||||
_cost_map_source_info.fallback_reason = f"Remote fetch failed: {str(e)}"
|
||||
return GetModelCostMap.load_local_model_cost_map()
|
||||
return _expand_model_aliases(GetModelCostMap.load_local_model_cost_map())
|
||||
|
||||
# Validate using cached count (cheap int comparison, no file I/O)
|
||||
if not GetModelCostMap.validate_model_cost_map(
|
||||
@@ -234,11 +289,9 @@ def get_model_cost_map(url: str) -> dict:
|
||||
url,
|
||||
)
|
||||
_cost_map_source_info.source = "local"
|
||||
_cost_map_source_info.fallback_reason = (
|
||||
"Remote data failed integrity validation"
|
||||
)
|
||||
return GetModelCostMap.load_local_model_cost_map()
|
||||
_cost_map_source_info.fallback_reason = "Remote data failed integrity validation"
|
||||
return _expand_model_aliases(GetModelCostMap.load_local_model_cost_map())
|
||||
|
||||
_cost_map_source_info.source = "remote"
|
||||
_cost_map_source_info.fallback_reason = None
|
||||
return content
|
||||
return _expand_model_aliases(content)
|
||||
|
||||
@@ -150,6 +150,14 @@ def get_supported_openai_params( # noqa: PLR0915
|
||||
return litellm.MistralConfig().get_supported_openai_params(model=model)
|
||||
elif request_type == "embeddings":
|
||||
return litellm.MistralEmbeddingConfig().get_supported_openai_params()
|
||||
elif request_type == "transcription":
|
||||
from litellm.llms.mistral.audio_transcription.transformation import (
|
||||
MistralAudioTranscriptionConfig,
|
||||
)
|
||||
|
||||
return MistralAudioTranscriptionConfig().get_supported_openai_params(
|
||||
model=model
|
||||
)
|
||||
elif custom_llm_provider == "text-completion-codestral":
|
||||
return litellm.CodestralTextCompletionConfig().get_supported_openai_params(
|
||||
model=model
|
||||
|
||||
@@ -2500,74 +2500,257 @@ def anthropic_messages_pt( # noqa: PLR0915
|
||||
assistant_content.extend(_compaction_blocks) # type: ignore
|
||||
|
||||
thinking_blocks = assistant_content_block.get("thinking_blocks", None)
|
||||
|
||||
# Check if tool_calls contain server tool calls (web search, etc.)
|
||||
# If so, we need to interleave thinking blocks with tool call groups
|
||||
# to preserve the original content block ordering.
|
||||
# Fixes: https://github.com/BerriAI/litellm/issues/23047
|
||||
assistant_tool_calls = assistant_content_block.get("tool_calls")
|
||||
_has_server_tool_calls = False
|
||||
if assistant_tool_calls is not None:
|
||||
for _tc in assistant_tool_calls:
|
||||
_tc_id = (
|
||||
_tc.get("id")
|
||||
if isinstance(_tc, dict)
|
||||
else getattr(_tc, "id", None)
|
||||
)
|
||||
if _tc_id and isinstance(_tc_id, str) and _tc_id.startswith("srvtoolu_"):
|
||||
_has_server_tool_calls = True
|
||||
break
|
||||
|
||||
if (
|
||||
thinking_blocks is not None
|
||||
): # IMPORTANT: ADD THIS FIRST, ELSE ANTHROPIC WILL RAISE AN ERROR
|
||||
assistant_content.extend(thinking_blocks)
|
||||
if "content" in assistant_content_block and isinstance(
|
||||
assistant_content_block["content"], list
|
||||
and _has_server_tool_calls
|
||||
and isinstance(
|
||||
assistant_content_block.get("content", None), (str, type(None))
|
||||
)
|
||||
):
|
||||
for m in assistant_content_block["content"]:
|
||||
# handle thinking blocks
|
||||
thinking_block = cast(str, m.get("thinking", ""))
|
||||
text_block = cast(str, m.get("text", ""))
|
||||
if (
|
||||
m.get("type", "") == "thinking" and len(thinking_block) > 0
|
||||
): # don't pass empty text blocks. anthropic api raises errors.
|
||||
anthropic_message: Union[
|
||||
ChatCompletionThinkingBlock,
|
||||
AnthropicMessagesTextParam,
|
||||
] = cast(ChatCompletionThinkingBlock, m)
|
||||
assistant_content.append(anthropic_message)
|
||||
# handle text
|
||||
elif (
|
||||
m.get("type", "") == "text" and len(text_block) > 0
|
||||
): # don't pass empty text blocks. anthropic api raises errors.
|
||||
anthropic_message = AnthropicMessagesTextParam(
|
||||
type="text", text=text_block
|
||||
)
|
||||
_cached_message = add_cache_control_to_content(
|
||||
anthropic_content_element=anthropic_message,
|
||||
original_content_element=dict(m),
|
||||
)
|
||||
# INTERLEAVED MODE: When we have both thinking blocks and server
|
||||
# tool calls (e.g. web search), Anthropic's original response
|
||||
# interleaves them: [thinking_1, server_tool_use_1, result_1,
|
||||
# thinking_2, text, server_tool_use_2, result_2, ...].
|
||||
# We must preserve this interleaved order because Anthropic
|
||||
# verifies thinking block signatures based on position.
|
||||
|
||||
assistant_content.append(
|
||||
cast(AnthropicMessagesTextParam, _cached_message)
|
||||
)
|
||||
# handle server_tool_use blocks (tool search, web search, etc.)
|
||||
# Pass through as-is since these are Anthropic-native content types
|
||||
elif m.get("type", "") == "server_tool_use":
|
||||
assistant_content.append(m) # type: ignore
|
||||
# handle all *_tool_result blocks (tool_search_tool_result,
|
||||
# web_search_tool_result, bash_code_execution_tool_result, etc.)
|
||||
# Pass through as-is since these are Anthropic-native content types
|
||||
elif m.get("type", "").endswith("_tool_result"):
|
||||
assistant_content.append(m) # type: ignore
|
||||
elif (
|
||||
"content" in assistant_content_block
|
||||
and isinstance(assistant_content_block["content"], str)
|
||||
and assistant_content_block[
|
||||
"content"
|
||||
] # don't pass empty text blocks. anthropic api raises errors.
|
||||
):
|
||||
_anthropic_text_content_element = AnthropicMessagesTextParam(
|
||||
type="text",
|
||||
text=assistant_content_block["content"],
|
||||
# Build the tool call groups (server_tool_use + its result)
|
||||
_provider_specific_fields_raw_tc = assistant_content_block.get(
|
||||
"provider_specific_fields"
|
||||
)
|
||||
_provider_specific_fields_tc: Dict[str, Any] = {}
|
||||
if isinstance(_provider_specific_fields_raw_tc, dict):
|
||||
_provider_specific_fields_tc = cast(
|
||||
Dict[str, Any], _provider_specific_fields_raw_tc
|
||||
)
|
||||
_web_search_results_tc = _provider_specific_fields_tc.get(
|
||||
"web_search_results"
|
||||
)
|
||||
_tool_results_tc = _provider_specific_fields_tc.get("tool_results")
|
||||
tool_invoke_results = convert_to_anthropic_tool_invoke(
|
||||
assistant_tool_calls, # type: ignore
|
||||
web_search_results=_web_search_results_tc,
|
||||
tool_results=_tool_results_tc,
|
||||
)
|
||||
|
||||
_content_element = add_cache_control_to_content(
|
||||
anthropic_content_element=_anthropic_text_content_element,
|
||||
original_content_element=dict(assistant_content_block),
|
||||
# Group tool invoke results into (server_tool_use, result) pairs
|
||||
# and separate regular tool_use blocks
|
||||
server_tool_groups: List[List[Any]] = []
|
||||
regular_tool_uses: List[Any] = []
|
||||
_current_group: List[Any] = []
|
||||
for item in tool_invoke_results:
|
||||
item_type = (
|
||||
item.get("type", "")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "type", "")
|
||||
)
|
||||
if item_type == "server_tool_use":
|
||||
if _current_group:
|
||||
server_tool_groups.append(_current_group)
|
||||
_current_group = [item]
|
||||
elif item_type.endswith("_tool_result"):
|
||||
_current_group.append(item)
|
||||
elif item_type == "tool_use":
|
||||
regular_tool_uses.append(item)
|
||||
else:
|
||||
_current_group.append(item)
|
||||
if _current_group:
|
||||
server_tool_groups.append(_current_group)
|
||||
|
||||
# Build the text block if content is a non-empty string
|
||||
text_element = None
|
||||
if (
|
||||
isinstance(assistant_content_block.get("content"), str)
|
||||
and assistant_content_block["content"]
|
||||
):
|
||||
_anthropic_text_content_element = AnthropicMessagesTextParam(
|
||||
type="text",
|
||||
text=assistant_content_block["content"],
|
||||
)
|
||||
_content_element = add_cache_control_to_content(
|
||||
anthropic_content_element=_anthropic_text_content_element,
|
||||
original_content_element=dict(assistant_content_block),
|
||||
)
|
||||
if "cache_control" in _content_element:
|
||||
_anthropic_text_content_element["cache_control"] = (
|
||||
_content_element["cache_control"]
|
||||
)
|
||||
text_element = _anthropic_text_content_element
|
||||
|
||||
# Interleave: each thinking block precedes its server tool group.
|
||||
# Pattern: thinking[0], group[0], thinking[1], group[1], ...
|
||||
# Any remaining thinking blocks (after all groups) go before text.
|
||||
# Any remaining groups (after all thinking blocks) go after.
|
||||
tb_idx = 0
|
||||
grp_idx = 0
|
||||
num_tb = len(thinking_blocks) if thinking_blocks else 0
|
||||
num_grp = len(server_tool_groups)
|
||||
|
||||
while tb_idx < num_tb or grp_idx < num_grp:
|
||||
if tb_idx < num_tb and grp_idx < num_grp:
|
||||
# Emit thinking block then its tool group
|
||||
assistant_content.append(thinking_blocks[tb_idx])
|
||||
tb_idx += 1
|
||||
for block in server_tool_groups[grp_idx]:
|
||||
item_id = (
|
||||
block.get("id")
|
||||
if isinstance(block, dict)
|
||||
else getattr(block, "id", None)
|
||||
)
|
||||
if item_id and item_id in unique_tool_ids:
|
||||
continue
|
||||
if item_id:
|
||||
unique_tool_ids.add(item_id)
|
||||
assistant_content.append(
|
||||
cast(AnthropicMessagesAssistantMessageValues, block)
|
||||
)
|
||||
grp_idx += 1
|
||||
elif tb_idx < num_tb:
|
||||
# More thinking blocks than tool groups - emit before text
|
||||
assistant_content.append(thinking_blocks[tb_idx])
|
||||
tb_idx += 1
|
||||
else:
|
||||
# More tool groups than thinking blocks - emit remaining
|
||||
for block in server_tool_groups[grp_idx]:
|
||||
item_id = (
|
||||
block.get("id")
|
||||
if isinstance(block, dict)
|
||||
else getattr(block, "id", None)
|
||||
)
|
||||
if item_id and item_id in unique_tool_ids:
|
||||
continue
|
||||
if item_id:
|
||||
unique_tool_ids.add(item_id)
|
||||
assistant_content.append(
|
||||
cast(AnthropicMessagesAssistantMessageValues, block)
|
||||
)
|
||||
grp_idx += 1
|
||||
|
||||
# Add text block (if any)
|
||||
if text_element is not None:
|
||||
assistant_content.append(text_element)
|
||||
|
||||
# Add regular (non-server) tool calls at the end
|
||||
for item in regular_tool_uses:
|
||||
item_id = (
|
||||
item.get("id")
|
||||
if isinstance(item, dict)
|
||||
else getattr(item, "id", None)
|
||||
)
|
||||
if item_id and item_id in unique_tool_ids:
|
||||
continue
|
||||
if item_id:
|
||||
unique_tool_ids.add(item_id)
|
||||
assistant_content.append(
|
||||
cast(AnthropicMessagesAssistantMessageValues, item)
|
||||
)
|
||||
|
||||
# Mark tool_calls as already processed so they are not added again
|
||||
assistant_tool_calls = None
|
||||
|
||||
else:
|
||||
# SEQUENTIAL MODE: No server tool calls, or no thinking blocks,
|
||||
# or content is a list. Use the original sequential approach.
|
||||
|
||||
# When content is a list, check if it already contains thinking
|
||||
# blocks inline. If so, skip prepending thinking_blocks to avoid
|
||||
# duplication and preserve the original interleaved order.
|
||||
# Fixes the gap where list-content messages bypass INTERLEAVED
|
||||
# MODE and still get thinking blocks prepended out of order.
|
||||
_content_is_list = "content" in assistant_content_block and isinstance(
|
||||
assistant_content_block["content"], list
|
||||
)
|
||||
_list_has_thinking = False
|
||||
if _content_is_list:
|
||||
for _item in assistant_content_block["content"]:
|
||||
if isinstance(_item, dict) and _item.get("type") in ("thinking", "redacted_thinking"):
|
||||
_list_has_thinking = True
|
||||
break
|
||||
|
||||
if "cache_control" in _content_element:
|
||||
_anthropic_text_content_element["cache_control"] = _content_element[
|
||||
"cache_control"
|
||||
]
|
||||
if (
|
||||
thinking_blocks is not None
|
||||
and not _list_has_thinking
|
||||
): # IMPORTANT: ADD THIS FIRST, ELSE ANTHROPIC WILL RAISE AN ERROR
|
||||
assistant_content.extend(thinking_blocks)
|
||||
if _content_is_list:
|
||||
for m in assistant_content_block["content"]:
|
||||
# handle thinking blocks
|
||||
thinking_block = cast(str, m.get("thinking", ""))
|
||||
text_block = cast(str, m.get("text", ""))
|
||||
if (
|
||||
m.get("type", "") == "thinking" and len(thinking_block) > 0
|
||||
): # don't pass empty text blocks. anthropic api raises errors.
|
||||
anthropic_message: Union[
|
||||
ChatCompletionThinkingBlock,
|
||||
AnthropicMessagesTextParam,
|
||||
] = cast(ChatCompletionThinkingBlock, m)
|
||||
assistant_content.append(anthropic_message)
|
||||
# handle text
|
||||
elif (
|
||||
m.get("type", "") == "text" and len(text_block) > 0
|
||||
): # don't pass empty text blocks. anthropic api raises errors.
|
||||
anthropic_message = AnthropicMessagesTextParam(
|
||||
type="text", text=text_block
|
||||
)
|
||||
_cached_message = add_cache_control_to_content(
|
||||
anthropic_content_element=anthropic_message,
|
||||
original_content_element=dict(m),
|
||||
)
|
||||
|
||||
assistant_content.append(_anthropic_text_content_element)
|
||||
assistant_content.append(
|
||||
cast(AnthropicMessagesTextParam, _cached_message)
|
||||
)
|
||||
# handle server_tool_use blocks (tool search, web search, etc.)
|
||||
# Pass through as-is since these are Anthropic-native content types
|
||||
elif m.get("type", "") == "server_tool_use":
|
||||
assistant_content.append(m) # type: ignore
|
||||
# handle all *_tool_result blocks (tool_search_tool_result,
|
||||
# web_search_tool_result, bash_code_execution_tool_result, etc.)
|
||||
# Pass through as-is since these are Anthropic-native content types
|
||||
elif m.get("type", "").endswith("_tool_result"):
|
||||
assistant_content.append(m) # type: ignore
|
||||
elif (
|
||||
"content" in assistant_content_block
|
||||
and isinstance(assistant_content_block["content"], str)
|
||||
and assistant_content_block[
|
||||
"content"
|
||||
] # don't pass empty text blocks. anthropic api raises errors.
|
||||
):
|
||||
_anthropic_text_content_element = AnthropicMessagesTextParam(
|
||||
type="text",
|
||||
text=assistant_content_block["content"],
|
||||
)
|
||||
|
||||
_content_element = add_cache_control_to_content(
|
||||
anthropic_content_element=_anthropic_text_content_element,
|
||||
original_content_element=dict(assistant_content_block),
|
||||
)
|
||||
|
||||
if "cache_control" in _content_element:
|
||||
_anthropic_text_content_element["cache_control"] = _content_element[
|
||||
"cache_control"
|
||||
]
|
||||
|
||||
assistant_content.append(_anthropic_text_content_element)
|
||||
|
||||
assistant_tool_calls = assistant_content_block.get("tool_calls")
|
||||
if (
|
||||
assistant_tool_calls is not None
|
||||
): # support assistant tool invoke conversion
|
||||
|
||||
@@ -75,53 +75,6 @@ def _redact_responses_api_output(output_items):
|
||||
summary_item.text = "redacted-by-litellm"
|
||||
|
||||
|
||||
def _redact_standard_logging_object(model_call_details: dict):
|
||||
"""Redact messages and response inside standard_logging_object if present."""
|
||||
standard_logging_object = model_call_details.get("standard_logging_object")
|
||||
if standard_logging_object is None:
|
||||
return
|
||||
|
||||
redacted_str = "redacted-by-litellm"
|
||||
|
||||
if standard_logging_object.get("messages") is not None:
|
||||
standard_logging_object["messages"] = [
|
||||
{"role": "user", "content": redacted_str}
|
||||
]
|
||||
|
||||
response = standard_logging_object.get("response")
|
||||
if response is not None:
|
||||
if isinstance(response, dict) and "output" in response:
|
||||
# ResponsesAPIResponse format - redact content in output items
|
||||
if isinstance(response.get("output"), list):
|
||||
for output_item in response["output"]:
|
||||
if isinstance(output_item, dict) and "content" in output_item:
|
||||
if isinstance(output_item["content"], list):
|
||||
for content_item in output_item["content"]:
|
||||
if (
|
||||
isinstance(content_item, dict)
|
||||
and "text" in content_item
|
||||
):
|
||||
content_item["text"] = redacted_str
|
||||
elif isinstance(response, dict) and "choices" in response:
|
||||
# ModelResponse dict format - redact content in choices
|
||||
if isinstance(response.get("choices"), list):
|
||||
for choice in response["choices"]:
|
||||
if isinstance(choice, dict):
|
||||
if "message" in choice and isinstance(choice["message"], dict):
|
||||
choice["message"]["content"] = redacted_str
|
||||
if "audio" in choice["message"]:
|
||||
choice["message"]["audio"] = None
|
||||
elif "delta" in choice and isinstance(choice["delta"], dict):
|
||||
choice["delta"]["content"] = redacted_str
|
||||
if "audio" in choice["delta"]:
|
||||
choice["delta"]["audio"] = None
|
||||
elif isinstance(response, str):
|
||||
standard_logging_object["response"] = redacted_str
|
||||
else:
|
||||
# For other formats (empty dict, None, etc.), use simple text format
|
||||
standard_logging_object["response"] = {"text": redacted_str}
|
||||
|
||||
|
||||
def perform_redaction(model_call_details: dict, result):
|
||||
"""
|
||||
Performs the actual redaction on the logging object and result.
|
||||
|
||||
@@ -4,10 +4,7 @@ from typing import List
|
||||
|
||||
import litellm
|
||||
from litellm.exceptions import UnsupportedParamsError
|
||||
from litellm.llms.openai.chat.gpt_5_transformation import (
|
||||
OpenAIGPT5Config,
|
||||
_get_effort_level,
|
||||
)
|
||||
from litellm.llms.openai.chat.gpt_5_transformation import OpenAIGPT5Config
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
from .gpt_transformation import AzureOpenAIConfig
|
||||
@@ -84,27 +81,24 @@ class AzureOpenAIGPT5Config(AzureOpenAIConfig, OpenAIGPT5Config):
|
||||
drop_params: bool,
|
||||
api_version: str = "",
|
||||
) -> dict:
|
||||
reasoning_effort_value = non_default_params.get(
|
||||
"reasoning_effort"
|
||||
) or optional_params.get("reasoning_effort")
|
||||
effective_effort = _get_effort_level(reasoning_effort_value)
|
||||
reasoning_effort_value = (
|
||||
non_default_params.get("reasoning_effort")
|
||||
or optional_params.get("reasoning_effort")
|
||||
)
|
||||
|
||||
# gpt-5.1/5.2/5.4 support reasoning_effort='none', but other gpt-5 models don't
|
||||
# See: https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/reasoning
|
||||
supports_none = self._supports_reasoning_effort_level(model, "none")
|
||||
|
||||
if effective_effort == "none" and not supports_none:
|
||||
if reasoning_effort_value == "none" and not supports_none:
|
||||
if litellm.drop_params is True or (
|
||||
drop_params is not None and drop_params is True
|
||||
):
|
||||
non_default_params = non_default_params.copy()
|
||||
optional_params = optional_params.copy()
|
||||
if (
|
||||
_get_effort_level(non_default_params.get("reasoning_effort"))
|
||||
== "none"
|
||||
):
|
||||
if non_default_params.get("reasoning_effort") == "none":
|
||||
non_default_params.pop("reasoning_effort")
|
||||
if _get_effort_level(optional_params.get("reasoning_effort")) == "none":
|
||||
if optional_params.get("reasoning_effort") == "none":
|
||||
optional_params.pop("reasoning_effort")
|
||||
else:
|
||||
raise UnsupportedParamsError(
|
||||
@@ -127,19 +121,9 @@ class AzureOpenAIGPT5Config(AzureOpenAIConfig, OpenAIGPT5Config):
|
||||
)
|
||||
|
||||
# Only drop reasoning_effort='none' for models that don't support it
|
||||
result_effort = _get_effort_level(result.get("reasoning_effort"))
|
||||
if result_effort == "none" and not supports_none:
|
||||
if result.get("reasoning_effort") == "none" and not supports_none:
|
||||
result.pop("reasoning_effort")
|
||||
|
||||
# Azure Chat Completions: gpt-5.4+ does not support tools + reasoning together.
|
||||
# Drop reasoning_effort when both are present (OpenAI routes to Responses API; Azure does not).
|
||||
if self.is_model_gpt_5_4_plus_model(model):
|
||||
has_tools = bool(
|
||||
non_default_params.get("tools") or optional_params.get("tools")
|
||||
)
|
||||
if has_tools and result_effort not in (None, "none"):
|
||||
result.pop("reasoning_effort", None)
|
||||
|
||||
return result
|
||||
|
||||
def transform_request(
|
||||
|
||||
@@ -51,7 +51,6 @@ from litellm.types.llms.openai import (
|
||||
)
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionMessageToolCall,
|
||||
CompletionTokensDetailsWrapper,
|
||||
Function,
|
||||
Message,
|
||||
ModelResponse,
|
||||
@@ -64,7 +63,6 @@ from litellm.utils import (
|
||||
has_tool_call_blocks,
|
||||
last_assistant_with_tool_calls_has_no_thinking_blocks,
|
||||
supports_reasoning,
|
||||
token_counter,
|
||||
)
|
||||
|
||||
from ..common_utils import (
|
||||
@@ -1641,11 +1639,7 @@ class AmazonConverseConfig(BaseConfig):
|
||||
thinking_blocks_list.append(_redacted_block)
|
||||
return thinking_blocks_list
|
||||
|
||||
def _transform_usage(
|
||||
self,
|
||||
usage: ConverseTokenUsageBlock,
|
||||
reasoning_content: Optional[str] = None,
|
||||
) -> Usage:
|
||||
def _transform_usage(self, usage: ConverseTokenUsageBlock) -> Usage:
|
||||
input_tokens = usage["inputTokens"]
|
||||
output_tokens = usage["outputTokens"]
|
||||
total_tokens = usage["totalTokens"]
|
||||
@@ -1662,19 +1656,6 @@ class AmazonConverseConfig(BaseConfig):
|
||||
prompt_tokens_details = PromptTokensDetailsWrapper(
|
||||
cached_tokens=cache_read_input_tokens
|
||||
)
|
||||
reasoning_tokens = (
|
||||
token_counter(text=reasoning_content, count_response_tokens=True)
|
||||
if reasoning_content
|
||||
else 0
|
||||
)
|
||||
completion_tokens_details = CompletionTokensDetailsWrapper(
|
||||
reasoning_tokens=reasoning_tokens,
|
||||
text_tokens=(
|
||||
output_tokens - reasoning_tokens
|
||||
if reasoning_tokens > 0
|
||||
else output_tokens
|
||||
),
|
||||
)
|
||||
openai_usage = Usage(
|
||||
prompt_tokens=input_tokens,
|
||||
completion_tokens=output_tokens,
|
||||
@@ -1682,7 +1663,6 @@ class AmazonConverseConfig(BaseConfig):
|
||||
prompt_tokens_details=prompt_tokens_details,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
completion_tokens_details=completion_tokens_details,
|
||||
)
|
||||
return openai_usage
|
||||
|
||||
@@ -2019,10 +1999,7 @@ class AmazonConverseConfig(BaseConfig):
|
||||
chat_completion_message["tool_calls"] = filtered_tools
|
||||
|
||||
## CALCULATING USAGE - bedrock returns usage in the headers
|
||||
usage = self._transform_usage(
|
||||
completion_response["usage"],
|
||||
reasoning_content=chat_completion_message.get("reasoning_content"),
|
||||
)
|
||||
usage = self._transform_usage(completion_response["usage"])
|
||||
|
||||
## HANDLE TOOL CALLS
|
||||
_message = Message(**chat_completion_message)
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
from .common_utils import (
|
||||
DEFAULT_API_BASE,
|
||||
DEFAULT_MAX_POLLING_TIME,
|
||||
DEFAULT_POLLING_INTERVAL,
|
||||
IMAGE_EDIT_MODELS,
|
||||
IMAGE_GENERATION_MODELS,
|
||||
BlackForestLabsError,
|
||||
)
|
||||
from .image_edit import BlackForestLabsImageEditConfig
|
||||
from .image_generation import BlackForestLabsImageGenerationConfig
|
||||
|
||||
__all__ = [
|
||||
"BlackForestLabsError",
|
||||
"BlackForestLabsImageEditConfig",
|
||||
"BlackForestLabsImageGenerationConfig",
|
||||
"DEFAULT_API_BASE",
|
||||
"DEFAULT_MAX_POLLING_TIME",
|
||||
"DEFAULT_POLLING_INTERVAL",
|
||||
"IMAGE_EDIT_MODELS",
|
||||
"IMAGE_GENERATION_MODELS",
|
||||
]
|
||||
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
Black Forest Labs Common Utilities
|
||||
|
||||
Common utilities, constants, and error handling for Black Forest Labs API.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
|
||||
|
||||
class BlackForestLabsError(BaseLLMException):
|
||||
"""Exception class for Black Forest Labs API errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# API Constants
|
||||
DEFAULT_API_BASE = "https://api.bfl.ai"
|
||||
|
||||
# Polling configuration
|
||||
DEFAULT_POLLING_INTERVAL = 1.5 # seconds
|
||||
DEFAULT_MAX_POLLING_TIME = 300 # 5 minutes
|
||||
|
||||
# Model to endpoint mapping for image edit
|
||||
IMAGE_EDIT_MODELS: Dict[str, str] = {
|
||||
"flux-kontext-pro": "/v1/flux-kontext-pro",
|
||||
"flux-kontext-max": "/v1/flux-kontext-max",
|
||||
"flux-pro-1.0-fill": "/v1/flux-pro-1.0-fill",
|
||||
"flux-pro-1.0-expand": "/v1/flux-pro-1.0-expand",
|
||||
}
|
||||
|
||||
# Model to endpoint mapping for image generation
|
||||
IMAGE_GENERATION_MODELS: Dict[str, str] = {
|
||||
"flux-pro-1.1": "/v1/flux-pro-1.1",
|
||||
"flux-pro-1.1-ultra": "/v1/flux-pro-1.1-ultra",
|
||||
"flux-dev": "/v1/flux-dev",
|
||||
"flux-pro": "/v1/flux-pro",
|
||||
# Kontext models support both text-to-image and image editing
|
||||
"flux-kontext-pro": "/v1/flux-kontext-pro",
|
||||
"flux-kontext-max": "/v1/flux-kontext-max",
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
from .handler import BlackForestLabsImageEdit, bfl_image_edit
|
||||
from .transformation import BlackForestLabsImageEditConfig
|
||||
|
||||
__all__ = [
|
||||
"BlackForestLabsImageEditConfig",
|
||||
"BlackForestLabsImageEdit",
|
||||
"bfl_image_edit",
|
||||
]
|
||||
@@ -0,0 +1,454 @@
|
||||
"""
|
||||
Black Forest Labs Image Edit Handler
|
||||
|
||||
Handles image edit requests for Black Forest Labs models.
|
||||
BFL uses an async polling pattern - the initial request returns a task ID,
|
||||
then we poll until the result is ready.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import FileTypes, ImageResponse
|
||||
|
||||
from ..common_utils import (
|
||||
DEFAULT_MAX_POLLING_TIME,
|
||||
DEFAULT_POLLING_INTERVAL,
|
||||
BlackForestLabsError,
|
||||
)
|
||||
from .transformation import BlackForestLabsImageEditConfig
|
||||
|
||||
|
||||
class BlackForestLabsImageEdit:
|
||||
"""
|
||||
Black Forest Labs Image Edit handler.
|
||||
|
||||
Handles the HTTP requests and polling logic, delegating data transformation
|
||||
to the BlackForestLabsImageEditConfig class.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.config = BlackForestLabsImageEditConfig()
|
||||
|
||||
def image_edit(
|
||||
self,
|
||||
model: str,
|
||||
image: Union[FileTypes, List[FileTypes]],
|
||||
prompt: Optional[str],
|
||||
image_edit_optional_request_params: Dict,
|
||||
litellm_params: Union[GenericLiteLLMParams, Dict],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
aimage_edit: bool = False,
|
||||
) -> Union[ImageResponse, Any]:
|
||||
"""
|
||||
Main entry point for image edit requests.
|
||||
|
||||
Args:
|
||||
model: The model to use (e.g., "black_forest_labs/flux-kontext-pro")
|
||||
image: The image(s) to edit
|
||||
prompt: The edit instruction
|
||||
image_edit_optional_request_params: Optional parameters for the request
|
||||
litellm_params: LiteLLM parameters including api_key, api_base
|
||||
logging_obj: Logging object
|
||||
timeout: Request timeout
|
||||
extra_headers: Additional headers
|
||||
client: HTTP client to use
|
||||
aimage_edit: If True, return async coroutine
|
||||
|
||||
Returns:
|
||||
ImageResponse or coroutine if aimage_edit=True
|
||||
"""
|
||||
# Handle litellm_params as dict or object
|
||||
if isinstance(litellm_params, dict):
|
||||
api_key = litellm_params.get("api_key")
|
||||
api_base = litellm_params.get("api_base")
|
||||
litellm_params_dict = litellm_params
|
||||
else:
|
||||
api_key = litellm_params.api_key
|
||||
api_base = litellm_params.api_base
|
||||
litellm_params_dict = dict(litellm_params)
|
||||
|
||||
if aimage_edit:
|
||||
return self.async_image_edit(
|
||||
model=model,
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
image_edit_optional_request_params=image_edit_optional_request_params,
|
||||
litellm_params=litellm_params,
|
||||
logging_obj=logging_obj,
|
||||
timeout=timeout,
|
||||
extra_headers=extra_headers,
|
||||
client=client if isinstance(client, AsyncHTTPHandler) else None,
|
||||
)
|
||||
|
||||
# Sync version
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
sync_client = _get_httpx_client()
|
||||
else:
|
||||
sync_client = client
|
||||
|
||||
# Validate environment and get headers
|
||||
headers = self.config.validate_environment(
|
||||
api_key=api_key,
|
||||
headers=image_edit_optional_request_params.get("extra_headers", {}) or {},
|
||||
model=model,
|
||||
)
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
# Get complete URL
|
||||
complete_url = self.config.get_complete_url(
|
||||
model=model,
|
||||
api_base=api_base,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
|
||||
# Transform request
|
||||
# Handle image list vs single image
|
||||
if isinstance(image, list):
|
||||
if not image:
|
||||
raise BlackForestLabsError(status_code=400, message="No image provided")
|
||||
image_input = image[0]
|
||||
else:
|
||||
image_input = image
|
||||
data, _ = self.config.transform_image_edit_request(
|
||||
model=model,
|
||||
prompt=prompt or "",
|
||||
image=image_input,
|
||||
image_edit_optional_request_params=image_edit_optional_request_params,
|
||||
litellm_params=litellm_params_dict,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# Logging
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": complete_url,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
# Make initial request
|
||||
try:
|
||||
response = sync_client.post(
|
||||
url=complete_url,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message=f"Request failed: {str(e)}",
|
||||
)
|
||||
|
||||
# Poll for result
|
||||
final_response = self._poll_for_result_sync(
|
||||
initial_response=response,
|
||||
headers=headers,
|
||||
sync_client=sync_client,
|
||||
)
|
||||
|
||||
# Transform response
|
||||
return self.config.transform_image_edit_response(
|
||||
model=model,
|
||||
raw_response=final_response,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
async def async_image_edit(
|
||||
self,
|
||||
model: str,
|
||||
image: Union[FileTypes, List[FileTypes]],
|
||||
prompt: Optional[str],
|
||||
image_edit_optional_request_params: Dict,
|
||||
litellm_params: Union[GenericLiteLLMParams, Dict],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Async version of image edit.
|
||||
"""
|
||||
# Handle litellm_params as dict or object
|
||||
if isinstance(litellm_params, dict):
|
||||
api_key = litellm_params.get("api_key")
|
||||
api_base = litellm_params.get("api_base")
|
||||
litellm_params_dict = litellm_params
|
||||
else:
|
||||
api_key = litellm_params.api_key
|
||||
api_base = litellm_params.api_base
|
||||
litellm_params_dict = dict(litellm_params)
|
||||
|
||||
if client is None:
|
||||
async_client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.BLACK_FOREST_LABS,
|
||||
)
|
||||
else:
|
||||
async_client = client
|
||||
|
||||
# Validate environment and get headers
|
||||
headers = self.config.validate_environment(
|
||||
api_key=api_key,
|
||||
headers=image_edit_optional_request_params.get("extra_headers", {}) or {},
|
||||
model=model,
|
||||
)
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
# Get complete URL
|
||||
complete_url = self.config.get_complete_url(
|
||||
model=model,
|
||||
api_base=api_base,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
|
||||
# Transform request
|
||||
if isinstance(image, list):
|
||||
if not image:
|
||||
raise BlackForestLabsError(status_code=400, message="No image provided")
|
||||
image_input = image[0]
|
||||
else:
|
||||
image_input = image
|
||||
data, _ = self.config.transform_image_edit_request(
|
||||
model=model,
|
||||
prompt=prompt or "",
|
||||
image=image_input,
|
||||
image_edit_optional_request_params=image_edit_optional_request_params,
|
||||
litellm_params=litellm_params_dict,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# Logging
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": complete_url,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
# Make initial request
|
||||
try:
|
||||
response = await async_client.post(
|
||||
url=complete_url,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message=f"Request failed: {str(e)}",
|
||||
)
|
||||
|
||||
# Poll for result
|
||||
final_response = await self._poll_for_result_async(
|
||||
initial_response=response,
|
||||
headers=headers,
|
||||
async_client=async_client,
|
||||
)
|
||||
|
||||
# Transform response
|
||||
return self.config.transform_image_edit_response(
|
||||
model=model,
|
||||
raw_response=final_response,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
def _poll_for_result_sync(
|
||||
self,
|
||||
initial_response: httpx.Response,
|
||||
headers: dict,
|
||||
sync_client: HTTPHandler,
|
||||
max_wait: float = DEFAULT_MAX_POLLING_TIME,
|
||||
interval: float = DEFAULT_POLLING_INTERVAL,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Poll BFL API until result is ready (sync version).
|
||||
|
||||
Args:
|
||||
initial_response: The initial response containing polling_url
|
||||
headers: Headers to use for polling (must include x-key)
|
||||
sync_client: HTTP client
|
||||
max_wait: Maximum time to wait in seconds
|
||||
interval: Polling interval in seconds
|
||||
timeout: Timeout for each individual polling request
|
||||
|
||||
Returns:
|
||||
Final response with completed result
|
||||
"""
|
||||
# Validate initial response status code
|
||||
if initial_response.status_code >= 400:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"BFL initial request failed: {initial_response.text}",
|
||||
)
|
||||
|
||||
# Parse initial response to get polling URL
|
||||
try:
|
||||
response_data = initial_response.json()
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"Error parsing initial response: {e}",
|
||||
)
|
||||
|
||||
# Check for immediate errors
|
||||
if "errors" in response_data:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"BFL error: {response_data['errors']}",
|
||||
)
|
||||
|
||||
polling_url = response_data.get("polling_url")
|
||||
if not polling_url:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message="No polling_url in BFL response",
|
||||
)
|
||||
|
||||
# Get just the auth header for polling
|
||||
polling_headers = {"x-key": headers.get("x-key", "")}
|
||||
|
||||
start_time = time.time()
|
||||
verbose_logger.debug(f"BFL starting sync polling at {polling_url}")
|
||||
|
||||
while time.time() - start_time < max_wait:
|
||||
response = sync_client.get(
|
||||
url=polling_url,
|
||||
headers=polling_headers,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BlackForestLabsError(
|
||||
status_code=response.status_code,
|
||||
message=f"Polling failed: {response.text}",
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
status = data.get("status")
|
||||
|
||||
verbose_logger.debug(f"BFL poll status: {status}")
|
||||
|
||||
if status == "Ready":
|
||||
return response
|
||||
elif status in ["Error", "Failed", "Content Moderated", "Request Moderated"]:
|
||||
raise BlackForestLabsError(
|
||||
status_code=400,
|
||||
message=f"Image generation failed: {status}",
|
||||
)
|
||||
|
||||
time.sleep(interval)
|
||||
|
||||
raise BlackForestLabsError(
|
||||
status_code=408,
|
||||
message=f"Polling timed out after {max_wait} seconds",
|
||||
)
|
||||
|
||||
async def _poll_for_result_async(
|
||||
self,
|
||||
initial_response: httpx.Response,
|
||||
headers: dict,
|
||||
async_client: AsyncHTTPHandler,
|
||||
max_wait: float = DEFAULT_MAX_POLLING_TIME,
|
||||
interval: float = DEFAULT_POLLING_INTERVAL,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Poll BFL API until result is ready (async version).
|
||||
"""
|
||||
# Validate initial response status code
|
||||
if initial_response.status_code >= 400:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"BFL initial request failed: {initial_response.text}",
|
||||
)
|
||||
|
||||
# Parse initial response to get polling URL
|
||||
try:
|
||||
response_data = initial_response.json()
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"Error parsing initial response: {e}",
|
||||
)
|
||||
|
||||
# Check for immediate errors
|
||||
if "errors" in response_data:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"BFL error: {response_data['errors']}",
|
||||
)
|
||||
|
||||
polling_url = response_data.get("polling_url")
|
||||
if not polling_url:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message="No polling_url in BFL response",
|
||||
)
|
||||
|
||||
# Get just the auth header for polling
|
||||
polling_headers = {"x-key": headers.get("x-key", "")}
|
||||
|
||||
start_time = time.time()
|
||||
verbose_logger.debug(f"BFL starting async polling at {polling_url}")
|
||||
|
||||
while time.time() - start_time < max_wait:
|
||||
response = await async_client.get(
|
||||
url=polling_url,
|
||||
headers=polling_headers,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BlackForestLabsError(
|
||||
status_code=response.status_code,
|
||||
message=f"Polling failed: {response.text}",
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
status = data.get("status")
|
||||
|
||||
verbose_logger.debug(f"BFL poll status: {status}")
|
||||
|
||||
if status == "Ready":
|
||||
return response
|
||||
elif status in ["Error", "Failed", "Content Moderated", "Request Moderated"]:
|
||||
raise BlackForestLabsError(
|
||||
status_code=400,
|
||||
message=f"Image generation failed: {status}",
|
||||
)
|
||||
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
raise BlackForestLabsError(
|
||||
status_code=408,
|
||||
message=f"Polling timed out after {max_wait} seconds",
|
||||
)
|
||||
|
||||
|
||||
# Singleton instance for use in images/main.py
|
||||
bfl_image_edit = BlackForestLabsImageEdit()
|
||||
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
Black Forest Labs Image Edit Configuration
|
||||
|
||||
Handles transformation between OpenAI-compatible format and Black Forest Labs API format
|
||||
for image editing endpoints (flux-kontext-pro, flux-kontext-max, etc.).
|
||||
|
||||
API Reference: https://docs.bfl.ai/
|
||||
"""
|
||||
|
||||
import base64
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
from httpx._types import RequestFiles
|
||||
|
||||
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.images.main import ImageEditOptionalRequestParams
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import FileTypes, ImageObject, ImageResponse
|
||||
|
||||
from ..common_utils import (
|
||||
DEFAULT_API_BASE,
|
||||
IMAGE_EDIT_MODELS,
|
||||
BlackForestLabsError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BlackForestLabsImageEditConfig(BaseImageEditConfig):
|
||||
"""
|
||||
Configuration for Black Forest Labs image editing.
|
||||
|
||||
Supports:
|
||||
- flux-kontext-pro: General image editing with prompts
|
||||
- flux-kontext-max: Premium quality editing
|
||||
- flux-pro-1.0-fill: Inpainting with mask
|
||||
- flux-pro-1.0-expand: Outpainting (expand image borders)
|
||||
|
||||
Note: HTTP requests and polling are handled by the handler (handler.py).
|
||||
This class only handles data transformation.
|
||||
"""
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
"""
|
||||
Return list of OpenAI params supported by Black Forest Labs.
|
||||
|
||||
Note: BFL uses different parameter names, these are mapped in map_openai_params.
|
||||
"""
|
||||
return [
|
||||
"mask",
|
||||
"seed",
|
||||
"output_format",
|
||||
"safety_tolerance",
|
||||
"prompt_upsampling",
|
||||
"aspect_ratio",
|
||||
"steps",
|
||||
"guidance",
|
||||
"grow_mask",
|
||||
"top",
|
||||
"bottom",
|
||||
"left",
|
||||
"right",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
image_edit_optional_params: ImageEditOptionalRequestParams,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> Dict:
|
||||
"""
|
||||
Map OpenAI parameters to Black Forest Labs parameters.
|
||||
|
||||
BFL-specific params are passed through directly.
|
||||
"""
|
||||
optional_params: Dict[str, Any] = {}
|
||||
|
||||
# Pass through BFL-specific params
|
||||
bfl_params = [
|
||||
"seed",
|
||||
"output_format",
|
||||
"safety_tolerance",
|
||||
"prompt_upsampling",
|
||||
# Kontext-specific
|
||||
"aspect_ratio",
|
||||
# Fill/Inpaint-specific
|
||||
"steps",
|
||||
"guidance",
|
||||
"grow_mask",
|
||||
# Expand-specific
|
||||
"top",
|
||||
"bottom",
|
||||
"left",
|
||||
"right",
|
||||
]
|
||||
|
||||
# Convert TypedDict to regular dict for access
|
||||
params_dict = dict(image_edit_optional_params)
|
||||
|
||||
for param in bfl_params:
|
||||
if param in params_dict:
|
||||
value = params_dict[param]
|
||||
if value is not None:
|
||||
optional_params[param] = value
|
||||
|
||||
# Set default output format
|
||||
if "output_format" not in optional_params:
|
||||
optional_params["output_format"] = "png"
|
||||
|
||||
return optional_params
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Validate environment and set up headers for Black Forest Labs.
|
||||
|
||||
BFL uses x-key header for authentication.
|
||||
"""
|
||||
final_api_key: Optional[str] = (
|
||||
api_key
|
||||
or get_secret_str("BFL_API_KEY")
|
||||
or get_secret_str("BLACK_FOREST_LABS_API_KEY")
|
||||
)
|
||||
|
||||
if not final_api_key:
|
||||
raise BlackForestLabsError(
|
||||
status_code=401,
|
||||
message="BFL_API_KEY is not set. Please set it via environment variable or pass api_key parameter.",
|
||||
)
|
||||
|
||||
headers["x-key"] = final_api_key
|
||||
headers["Content-Type"] = "application/json"
|
||||
headers["Accept"] = "application/json"
|
||||
|
||||
return headers
|
||||
|
||||
def use_multipart_form_data(self) -> bool:
|
||||
"""
|
||||
BFL uses JSON requests, not multipart/form-data.
|
||||
"""
|
||||
return False
|
||||
|
||||
def _get_model_endpoint(self, model: str) -> str:
|
||||
"""
|
||||
Get the API endpoint for a given model.
|
||||
"""
|
||||
# Remove provider prefix if present (e.g., "black_forest_labs/flux-kontext-pro")
|
||||
model_name = model.lower()
|
||||
if "/" in model_name:
|
||||
model_name = model_name.split("/")[-1]
|
||||
|
||||
# Check if model is in our mapping
|
||||
if model_name in IMAGE_EDIT_MODELS:
|
||||
return IMAGE_EDIT_MODELS[model_name]
|
||||
|
||||
raise ValueError(
|
||||
f"Unknown BFL image edit model: {model_name}. "
|
||||
f"Supported models: {list(IMAGE_EDIT_MODELS.keys())}"
|
||||
)
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
model: str,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
"""
|
||||
Get the complete URL for the Black Forest Labs API request.
|
||||
"""
|
||||
base_url: str = (
|
||||
api_base
|
||||
or get_secret_str("BFL_API_BASE")
|
||||
or DEFAULT_API_BASE
|
||||
)
|
||||
base_url = base_url.rstrip("/")
|
||||
|
||||
endpoint = self._get_model_endpoint(model)
|
||||
return f"{base_url}{endpoint}"
|
||||
|
||||
def _read_image_bytes(self, image: Any) -> bytes:
|
||||
"""Read image bytes from various input types."""
|
||||
if isinstance(image, bytes):
|
||||
return image
|
||||
elif isinstance(image, list):
|
||||
# If it's a list, take the first image
|
||||
return self._read_image_bytes(image[0])
|
||||
elif isinstance(image, str):
|
||||
if image.startswith(("http://", "https://")):
|
||||
# Download image from URL
|
||||
response = httpx.get(image, timeout=60.0)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
else:
|
||||
# Assume it's a file path
|
||||
with open(image, "rb") as f:
|
||||
return f.read()
|
||||
elif hasattr(image, "read"):
|
||||
# File-like object
|
||||
pos = getattr(image, "tell", lambda: 0)()
|
||||
if hasattr(image, "seek"):
|
||||
image.seek(0)
|
||||
data = image.read()
|
||||
if hasattr(image, "seek"):
|
||||
image.seek(pos)
|
||||
return data
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported image type: {type(image)}. "
|
||||
"Expected bytes, str (URL or file path), or file-like object."
|
||||
)
|
||||
|
||||
def transform_image_edit_request(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
image: FileTypes,
|
||||
image_edit_optional_request_params: Dict,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[Dict, RequestFiles]:
|
||||
"""
|
||||
Transform OpenAI-style request to Black Forest Labs request format.
|
||||
|
||||
BFL uses JSON body with base64-encoded images, not multipart/form-data.
|
||||
"""
|
||||
# Read and encode image
|
||||
image_bytes = self._read_image_bytes(image)
|
||||
b64_image = base64.b64encode(image_bytes).decode("utf-8")
|
||||
|
||||
# Build request body
|
||||
request_body: Dict[str, Any] = {
|
||||
"prompt": prompt,
|
||||
"input_image": b64_image,
|
||||
}
|
||||
|
||||
# Add optional params (only BFL-recognized parameters)
|
||||
bfl_request_params = [
|
||||
"seed", "output_format", "safety_tolerance", "prompt_upsampling",
|
||||
"aspect_ratio", "steps", "guidance", "grow_mask",
|
||||
"top", "bottom", "left", "right",
|
||||
]
|
||||
for key, value in image_edit_optional_request_params.items():
|
||||
if key in bfl_request_params and value is not None:
|
||||
request_body[key] = value
|
||||
|
||||
# Handle mask if provided (for inpainting)
|
||||
if "mask" in image_edit_optional_request_params:
|
||||
mask = image_edit_optional_request_params["mask"]
|
||||
mask_bytes = self._read_image_bytes(mask)
|
||||
request_body["mask"] = base64.b64encode(mask_bytes).decode("utf-8")
|
||||
|
||||
# BFL uses JSON, not multipart - return empty files
|
||||
return request_body, []
|
||||
|
||||
def transform_image_edit_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Transform Black Forest Labs response to OpenAI-compatible ImageResponse.
|
||||
|
||||
This is called with the FINAL polled response (after handler does polling).
|
||||
The response contains: {"status": "Ready", "result": {"sample": "https://..."}}
|
||||
"""
|
||||
try:
|
||||
response_data = raw_response.json()
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=raw_response.status_code,
|
||||
message=f"Error parsing BFL response: {e}",
|
||||
)
|
||||
|
||||
# Get image URL from result
|
||||
image_url = response_data.get("result", {}).get("sample")
|
||||
if not image_url:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message="No image URL in BFL result",
|
||||
)
|
||||
|
||||
# Build ImageResponse
|
||||
return ImageResponse(
|
||||
created=int(time.time()),
|
||||
data=[ImageObject(url=image_url)],
|
||||
)
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BlackForestLabsError:
|
||||
"""Return the appropriate error class for Black Forest Labs."""
|
||||
return BlackForestLabsError(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
)
|
||||
@@ -0,0 +1,12 @@
|
||||
from .handler import BlackForestLabsImageGeneration, bfl_image_generation
|
||||
from .transformation import (
|
||||
BlackForestLabsImageGenerationConfig,
|
||||
get_black_forest_labs_image_generation_config,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BlackForestLabsImageGenerationConfig",
|
||||
"get_black_forest_labs_image_generation_config",
|
||||
"BlackForestLabsImageGeneration",
|
||||
"bfl_image_generation",
|
||||
]
|
||||
@@ -0,0 +1,440 @@
|
||||
"""
|
||||
Black Forest Labs Image Generation Handler
|
||||
|
||||
Handles image generation requests for Black Forest Labs models.
|
||||
BFL uses an async polling pattern - the initial request returns a task ID,
|
||||
then we poll until the result is ready.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
from ..common_utils import (
|
||||
DEFAULT_MAX_POLLING_TIME,
|
||||
DEFAULT_POLLING_INTERVAL,
|
||||
BlackForestLabsError,
|
||||
)
|
||||
from .transformation import BlackForestLabsImageGenerationConfig
|
||||
|
||||
|
||||
class BlackForestLabsImageGeneration:
|
||||
"""
|
||||
Black Forest Labs Image Generation handler.
|
||||
|
||||
Handles the HTTP requests and polling logic, delegating data transformation
|
||||
to the BlackForestLabsImageGenerationConfig class.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.config = BlackForestLabsImageGenerationConfig()
|
||||
|
||||
def image_generation(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
model_response: ImageResponse,
|
||||
optional_params: Dict,
|
||||
litellm_params: Union[GenericLiteLLMParams, Dict],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
aimg_generation: bool = False,
|
||||
) -> Union[ImageResponse, Any]:
|
||||
"""
|
||||
Main entry point for image generation requests.
|
||||
|
||||
Args:
|
||||
model: The model to use (e.g., "black_forest_labs/flux-pro-1.1")
|
||||
prompt: The text prompt for image generation
|
||||
model_response: ImageResponse object to populate
|
||||
optional_params: Optional parameters for the request
|
||||
litellm_params: LiteLLM parameters including api_key, api_base
|
||||
logging_obj: Logging object
|
||||
timeout: Request timeout
|
||||
extra_headers: Additional headers
|
||||
client: HTTP client to use
|
||||
aimg_generation: If True, return async coroutine
|
||||
|
||||
Returns:
|
||||
ImageResponse or coroutine if aimg_generation=True
|
||||
"""
|
||||
# Handle litellm_params as dict or object
|
||||
if isinstance(litellm_params, dict):
|
||||
api_key = litellm_params.get("api_key")
|
||||
api_base = litellm_params.get("api_base")
|
||||
litellm_params_dict = litellm_params
|
||||
else:
|
||||
api_key = litellm_params.api_key
|
||||
api_base = litellm_params.api_base
|
||||
litellm_params_dict = dict(litellm_params)
|
||||
|
||||
if aimg_generation:
|
||||
return self.async_image_generation(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
model_response=model_response,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logging_obj=logging_obj,
|
||||
timeout=timeout,
|
||||
extra_headers=extra_headers,
|
||||
client=client if isinstance(client, AsyncHTTPHandler) else None,
|
||||
)
|
||||
|
||||
# Sync version
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
sync_client = _get_httpx_client()
|
||||
else:
|
||||
sync_client = client
|
||||
|
||||
# Validate environment and get headers
|
||||
headers = self.config.validate_environment(
|
||||
api_key=api_key,
|
||||
headers={},
|
||||
model=model,
|
||||
messages=[],
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
# Get complete URL
|
||||
complete_url = self.config.get_complete_url(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
|
||||
# Transform request
|
||||
data = self.config.transform_image_generation_request(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params_dict,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# Logging
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": complete_url,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
# Make initial request
|
||||
try:
|
||||
response = sync_client.post(
|
||||
url=complete_url,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message=f"Request failed: {str(e)}",
|
||||
)
|
||||
|
||||
# Poll for result
|
||||
final_response = self._poll_for_result_sync(
|
||||
initial_response=response,
|
||||
headers=headers,
|
||||
sync_client=sync_client,
|
||||
)
|
||||
|
||||
# Transform response
|
||||
return self.config.transform_image_generation_response(
|
||||
model=model,
|
||||
raw_response=final_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
async def async_image_generation(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
model_response: ImageResponse,
|
||||
optional_params: Dict,
|
||||
litellm_params: Union[GenericLiteLLMParams, Dict],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Async version of image generation.
|
||||
"""
|
||||
# Handle litellm_params as dict or object
|
||||
if isinstance(litellm_params, dict):
|
||||
api_key = litellm_params.get("api_key")
|
||||
api_base = litellm_params.get("api_base")
|
||||
litellm_params_dict = litellm_params
|
||||
else:
|
||||
api_key = litellm_params.api_key
|
||||
api_base = litellm_params.api_base
|
||||
litellm_params_dict = dict(litellm_params)
|
||||
|
||||
if client is None:
|
||||
async_client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.BLACK_FOREST_LABS,
|
||||
)
|
||||
else:
|
||||
async_client = client
|
||||
|
||||
# Validate environment and get headers
|
||||
headers = self.config.validate_environment(
|
||||
api_key=api_key,
|
||||
headers={},
|
||||
model=model,
|
||||
messages=[],
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
# Get complete URL
|
||||
complete_url = self.config.get_complete_url(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
|
||||
# Transform request
|
||||
data = self.config.transform_image_generation_request(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params_dict,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# Logging
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": complete_url,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
# Make initial request
|
||||
try:
|
||||
response = await async_client.post(
|
||||
url=complete_url,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message=f"Request failed: {str(e)}",
|
||||
)
|
||||
|
||||
# Poll for result
|
||||
final_response = await self._poll_for_result_async(
|
||||
initial_response=response,
|
||||
headers=headers,
|
||||
async_client=async_client,
|
||||
)
|
||||
|
||||
# Transform response
|
||||
return self.config.transform_image_generation_response(
|
||||
model=model,
|
||||
raw_response=final_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
def _poll_for_result_sync(
|
||||
self,
|
||||
initial_response: httpx.Response,
|
||||
headers: dict,
|
||||
sync_client: HTTPHandler,
|
||||
max_wait: float = DEFAULT_MAX_POLLING_TIME,
|
||||
interval: float = DEFAULT_POLLING_INTERVAL,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Poll BFL API until result is ready (sync version).
|
||||
"""
|
||||
# Validate initial response status code
|
||||
if initial_response.status_code >= 400:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"BFL initial request failed: {initial_response.text}",
|
||||
)
|
||||
|
||||
# Parse initial response to get polling URL
|
||||
try:
|
||||
response_data = initial_response.json()
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"Error parsing initial response: {e}",
|
||||
)
|
||||
|
||||
# Check for immediate errors
|
||||
if "errors" in response_data:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"BFL error: {response_data['errors']}",
|
||||
)
|
||||
|
||||
polling_url = response_data.get("polling_url")
|
||||
if not polling_url:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message="No polling_url in BFL response",
|
||||
)
|
||||
|
||||
# Get just the auth header for polling
|
||||
polling_headers = {"x-key": headers.get("x-key", "")}
|
||||
|
||||
start_time = time.time()
|
||||
verbose_logger.debug(f"BFL starting sync polling at {polling_url}")
|
||||
|
||||
while time.time() - start_time < max_wait:
|
||||
response = sync_client.get(
|
||||
url=polling_url,
|
||||
headers=polling_headers,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BlackForestLabsError(
|
||||
status_code=response.status_code,
|
||||
message=f"Polling failed: {response.text}",
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
status = data.get("status")
|
||||
|
||||
verbose_logger.debug(f"BFL poll status: {status}")
|
||||
|
||||
if status == "Ready":
|
||||
return response
|
||||
elif status in ["Error", "Failed", "Content Moderated", "Request Moderated"]:
|
||||
raise BlackForestLabsError(
|
||||
status_code=400,
|
||||
message=f"Image generation failed: {status}",
|
||||
)
|
||||
|
||||
time.sleep(interval)
|
||||
|
||||
raise BlackForestLabsError(
|
||||
status_code=408,
|
||||
message=f"Polling timed out after {max_wait} seconds",
|
||||
)
|
||||
|
||||
async def _poll_for_result_async(
|
||||
self,
|
||||
initial_response: httpx.Response,
|
||||
headers: dict,
|
||||
async_client: AsyncHTTPHandler,
|
||||
max_wait: float = DEFAULT_MAX_POLLING_TIME,
|
||||
interval: float = DEFAULT_POLLING_INTERVAL,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Poll BFL API until result is ready (async version).
|
||||
"""
|
||||
# Validate initial response status code
|
||||
if initial_response.status_code >= 400:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"BFL initial request failed: {initial_response.text}",
|
||||
)
|
||||
|
||||
# Parse initial response to get polling URL
|
||||
try:
|
||||
response_data = initial_response.json()
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"Error parsing initial response: {e}",
|
||||
)
|
||||
|
||||
# Check for immediate errors
|
||||
if "errors" in response_data:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"BFL error: {response_data['errors']}",
|
||||
)
|
||||
|
||||
polling_url = response_data.get("polling_url")
|
||||
if not polling_url:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message="No polling_url in BFL response",
|
||||
)
|
||||
|
||||
# Get just the auth header for polling
|
||||
polling_headers = {"x-key": headers.get("x-key", "")}
|
||||
|
||||
start_time = time.time()
|
||||
verbose_logger.debug(f"BFL starting async polling at {polling_url}")
|
||||
|
||||
while time.time() - start_time < max_wait:
|
||||
response = await async_client.get(
|
||||
url=polling_url,
|
||||
headers=polling_headers,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BlackForestLabsError(
|
||||
status_code=response.status_code,
|
||||
message=f"Polling failed: {response.text}",
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
status = data.get("status")
|
||||
|
||||
verbose_logger.debug(f"BFL poll status: {status}")
|
||||
|
||||
if status == "Ready":
|
||||
return response
|
||||
elif status in ["Error", "Failed", "Content Moderated", "Request Moderated"]:
|
||||
raise BlackForestLabsError(
|
||||
status_code=400,
|
||||
message=f"Image generation failed: {status}",
|
||||
)
|
||||
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
raise BlackForestLabsError(
|
||||
status_code=408,
|
||||
message=f"Polling timed out after {max_wait} seconds",
|
||||
)
|
||||
|
||||
|
||||
# Singleton instance for use in images/main.py
|
||||
bfl_image_generation = BlackForestLabsImageGeneration()
|
||||
@@ -0,0 +1,324 @@
|
||||
"""
|
||||
Black Forest Labs Image Generation Configuration
|
||||
|
||||
Handles transformation between OpenAI-compatible format and Black Forest Labs API format
|
||||
for image generation endpoints (flux-pro-1.1, flux-pro-1.1-ultra, flux-dev, flux-pro).
|
||||
|
||||
API Reference: https://docs.bfl.ai/
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
OpenAIImageGenerationOptionalParams,
|
||||
)
|
||||
from litellm.types.utils import ImageObject, ImageResponse
|
||||
|
||||
from ..common_utils import (
|
||||
DEFAULT_API_BASE,
|
||||
IMAGE_GENERATION_MODELS,
|
||||
BlackForestLabsError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BlackForestLabsImageGenerationConfig(BaseImageGenerationConfig):
|
||||
"""
|
||||
Configuration for Black Forest Labs image generation (text-to-image).
|
||||
|
||||
Supports:
|
||||
- flux-pro-1.1: Fast & reliable standard generation
|
||||
- flux-pro-1.1-ultra: Ultra high-resolution (up to 4MP)
|
||||
- flux-dev: Development/open-source variant
|
||||
- flux-pro: Original pro model
|
||||
|
||||
Note: HTTP requests and polling are handled by the handler (handler.py).
|
||||
This class only handles data transformation.
|
||||
"""
|
||||
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAIImageGenerationOptionalParams]:
|
||||
"""
|
||||
Return list of OpenAI params supported by Black Forest Labs.
|
||||
|
||||
Note: BFL uses different parameter names, these are mapped in map_openai_params.
|
||||
"""
|
||||
return [
|
||||
"n", # Number of images (BFL returns 1 per request, but ultra supports up to 4)
|
||||
"size", # Maps to width/height or aspect_ratio
|
||||
"quality", # Maps to raw mode for ultra
|
||||
"seed",
|
||||
"output_format",
|
||||
"safety_tolerance",
|
||||
"prompt_upsampling",
|
||||
"raw",
|
||||
"num_images",
|
||||
"image_url",
|
||||
"image_prompt_strength",
|
||||
"aspect_ratio",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Map OpenAI parameters to Black Forest Labs parameters.
|
||||
|
||||
BFL-specific params are passed through directly.
|
||||
"""
|
||||
supported_params = self.get_supported_openai_params(model)
|
||||
|
||||
for k, v in non_default_params.items():
|
||||
if k in optional_params:
|
||||
continue
|
||||
|
||||
if k in supported_params:
|
||||
# Map OpenAI 'size' to BFL width/height
|
||||
if k == "size" and v:
|
||||
self._map_size_param(v, optional_params)
|
||||
elif k == "n":
|
||||
if "ultra" in model.lower():
|
||||
optional_params["num_images"] = v
|
||||
# non-ultra: silently skip (n=1 is BFL default)
|
||||
elif k == "quality":
|
||||
if v == "hd" and "ultra" in model.lower():
|
||||
optional_params["raw"] = True
|
||||
# other quality values have no BFL mapping
|
||||
else:
|
||||
optional_params[k] = v
|
||||
elif not drop_params:
|
||||
raise ValueError(
|
||||
f"Parameter {k} is not supported for model {model}. "
|
||||
f"Supported parameters are {supported_params}. "
|
||||
f"Set drop_params=True to drop unsupported parameters."
|
||||
)
|
||||
|
||||
return optional_params
|
||||
|
||||
def _map_size_param(self, size: str, optional_params: dict) -> None:
|
||||
"""Map OpenAI size parameter to BFL width/height."""
|
||||
# Common size mappings
|
||||
size_mapping = {
|
||||
"1024x1024": (1024, 1024),
|
||||
"1792x1024": (1792, 1024),
|
||||
"1024x1792": (1024, 1792),
|
||||
"512x512": (512, 512),
|
||||
"256x256": (256, 256),
|
||||
}
|
||||
|
||||
if size in size_mapping:
|
||||
width, height = size_mapping[size]
|
||||
optional_params["width"] = width
|
||||
optional_params["height"] = height
|
||||
elif "x" in size:
|
||||
# Parse custom size
|
||||
try:
|
||||
width, height = map(int, size.lower().split("x"))
|
||||
optional_params["width"] = width
|
||||
optional_params["height"] = height
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid size format: '{size}'. Expected format 'WIDTHxHEIGHT' (e.g., '1024x1024')."
|
||||
)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Validate environment and set up headers for Black Forest Labs.
|
||||
|
||||
BFL uses x-key header for authentication.
|
||||
"""
|
||||
final_api_key: Optional[str] = (
|
||||
api_key
|
||||
or get_secret_str("BFL_API_KEY")
|
||||
or get_secret_str("BLACK_FOREST_LABS_API_KEY")
|
||||
)
|
||||
|
||||
if not final_api_key:
|
||||
raise BlackForestLabsError(
|
||||
status_code=401,
|
||||
message="BFL_API_KEY is not set. Please set it via environment variable or pass api_key parameter.",
|
||||
)
|
||||
|
||||
headers["x-key"] = final_api_key
|
||||
headers["Content-Type"] = "application/json"
|
||||
headers["Accept"] = "application/json"
|
||||
|
||||
return headers
|
||||
|
||||
def _get_model_endpoint(self, model: str) -> str:
|
||||
"""
|
||||
Get the API endpoint for a given model.
|
||||
"""
|
||||
# Remove provider prefix if present (e.g., "black_forest_labs/flux-pro-1.1")
|
||||
model_name = model.lower()
|
||||
if "/" in model_name:
|
||||
model_name = model_name.split("/")[-1]
|
||||
|
||||
# Check if model is in our mapping
|
||||
if model_name in IMAGE_GENERATION_MODELS:
|
||||
return IMAGE_GENERATION_MODELS[model_name]
|
||||
|
||||
raise ValueError(
|
||||
f"Unknown BFL image generation model: {model_name}. "
|
||||
f"Supported models: {list(IMAGE_GENERATION_MODELS.keys())}"
|
||||
)
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the complete URL for the Black Forest Labs API request.
|
||||
"""
|
||||
base_url: str = (
|
||||
api_base or get_secret_str("BFL_API_BASE") or DEFAULT_API_BASE
|
||||
)
|
||||
base_url = base_url.rstrip("/")
|
||||
|
||||
endpoint = self._get_model_endpoint(model)
|
||||
return f"{base_url}{endpoint}"
|
||||
|
||||
def transform_image_generation_request(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform OpenAI-style request to Black Forest Labs request format.
|
||||
|
||||
https://docs.bfl.ai/flux_models/flux_1_1_pro
|
||||
"""
|
||||
# Build request body with prompt
|
||||
request_body: Dict[str, Any] = {
|
||||
"prompt": prompt,
|
||||
}
|
||||
|
||||
# BFL-specific params that can be passed through
|
||||
bfl_params = [
|
||||
"width",
|
||||
"height",
|
||||
"aspect_ratio",
|
||||
"seed",
|
||||
"output_format",
|
||||
"safety_tolerance",
|
||||
"prompt_upsampling",
|
||||
# Ultra-specific
|
||||
"raw",
|
||||
"num_images",
|
||||
"image_url",
|
||||
"image_prompt_strength",
|
||||
]
|
||||
|
||||
for param in bfl_params:
|
||||
if param in optional_params and optional_params[param] is not None:
|
||||
request_body[param] = optional_params[param]
|
||||
|
||||
# Set default output format if not specified
|
||||
if "output_format" not in request_body:
|
||||
request_body["output_format"] = "png"
|
||||
|
||||
return request_body
|
||||
|
||||
def transform_image_generation_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ImageResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
**kwargs,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Transform Black Forest Labs response to OpenAI-compatible ImageResponse.
|
||||
|
||||
This is called with the FINAL polled response (after handler does polling).
|
||||
The response contains: {"status": "Ready", "result": {"sample": "https://..."}}
|
||||
"""
|
||||
try:
|
||||
response_data = raw_response.json()
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=raw_response.status_code,
|
||||
message=f"Error parsing BFL response: {e}",
|
||||
)
|
||||
|
||||
result = response_data.get("result", {})
|
||||
|
||||
if not model_response.data:
|
||||
model_response.data = []
|
||||
|
||||
# Handle single image (sample) or multiple images
|
||||
if isinstance(result, dict) and "sample" in result:
|
||||
model_response.data.append(ImageObject(url=result["sample"]))
|
||||
elif isinstance(result, list):
|
||||
# Multiple images returned
|
||||
for img in result:
|
||||
if isinstance(img, str):
|
||||
model_response.data.append(ImageObject(url=img))
|
||||
elif isinstance(img, dict) and "url" in img:
|
||||
model_response.data.append(ImageObject(url=img["url"]))
|
||||
|
||||
if not model_response.data:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message="No image URL in BFL result",
|
||||
)
|
||||
|
||||
model_response.created = int(time.time())
|
||||
return model_response
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BlackForestLabsError:
|
||||
"""Return the appropriate error class for Black Forest Labs."""
|
||||
return BlackForestLabsError(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
)
|
||||
|
||||
|
||||
def get_black_forest_labs_image_generation_config(
|
||||
model: str,
|
||||
) -> BlackForestLabsImageGenerationConfig:
|
||||
"""
|
||||
Get the appropriate image generation config for a Black Forest Labs model.
|
||||
|
||||
Currently returns a single config class, but can be extended
|
||||
for model-specific configurations if needed.
|
||||
"""
|
||||
return BlackForestLabsImageGenerationConfig()
|
||||
@@ -429,11 +429,8 @@ class FireworksAIConfig(OpenAIGPTConfig):
|
||||
"FIREWORKS_ACCOUNT_ID is not set. Please set the environment variable, to query Fireworks AI's `/models` endpoint."
|
||||
)
|
||||
|
||||
base = api_base.rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[: -len("/v1")]
|
||||
response = litellm.module_level_client.get(
|
||||
url=f"{base}/v1/accounts/{account_id}/models",
|
||||
url=f"{api_base}/v1/accounts/{account_id}/models",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
Support for Mistral Voxtral audio transcription via ``/v1/audio/transcriptions``.
|
||||
|
||||
API reference: https://docs.mistral.ai/api/#tag/audio/operation/audio_transcriptions_v1_audio_transcriptions_post
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.litellm_core_utils.audio_utils.utils import process_audio_file
|
||||
from litellm.llms.base_llm.audio_transcription.transformation import (
|
||||
AudioTranscriptionRequestData,
|
||||
BaseAudioTranscriptionConfig,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
OpenAIAudioTranscriptionOptionalParams,
|
||||
)
|
||||
from litellm.types.utils import FileTypes, TranscriptionResponse
|
||||
|
||||
|
||||
class MistralAudioTranscriptionException(BaseLLMException):
|
||||
pass
|
||||
|
||||
|
||||
class MistralAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAIAudioTranscriptionOptionalParams]:
|
||||
return [
|
||||
"language",
|
||||
"temperature",
|
||||
"timestamp_granularities",
|
||||
"response_format",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
supported_params = self.get_supported_openai_params(model)
|
||||
for k, v in non_default_params.items():
|
||||
if k in supported_params:
|
||||
optional_params[k] = v
|
||||
return optional_params
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
api_base = (
|
||||
"https://api.mistral.ai/v1"
|
||||
if api_base is None
|
||||
else api_base.rstrip("/")
|
||||
)
|
||||
return f"{api_base}/audio/transcriptions"
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return MistralAudioTranscriptionException(
|
||||
message=error_message,
|
||||
status_code=status_code,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
if api_key is None:
|
||||
api_key = get_secret_str("MISTRAL_API_KEY")
|
||||
|
||||
default_headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"accept": "application/json",
|
||||
}
|
||||
default_headers.update(headers or {})
|
||||
return default_headers
|
||||
|
||||
def transform_audio_transcription_request(
|
||||
self,
|
||||
model: str,
|
||||
audio_file: FileTypes,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> AudioTranscriptionRequestData:
|
||||
processed_audio = process_audio_file(audio_file)
|
||||
|
||||
form_fields: dict = {
|
||||
"model": model,
|
||||
}
|
||||
|
||||
# OpenAI-compatible params
|
||||
for key in self.get_supported_openai_params(model):
|
||||
value = optional_params.get(key)
|
||||
if value is not None:
|
||||
form_fields[key] = value
|
||||
|
||||
# Mistral-specific params (e.g. diarize)
|
||||
provider_specific_params = self.get_provider_specific_params(
|
||||
model=model,
|
||||
optional_params=optional_params,
|
||||
openai_params=self.get_supported_openai_params(model),
|
||||
)
|
||||
for key, value in provider_specific_params.items():
|
||||
form_fields[key] = str(value).lower() if isinstance(value, bool) else str(value)
|
||||
|
||||
files = {
|
||||
"file": (
|
||||
processed_audio.filename,
|
||||
processed_audio.file_content,
|
||||
processed_audio.content_type,
|
||||
)
|
||||
}
|
||||
|
||||
return AudioTranscriptionRequestData(data=form_fields, files=files)
|
||||
|
||||
def transform_audio_transcription_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
) -> TranscriptionResponse:
|
||||
try:
|
||||
response_json = raw_response.json()
|
||||
except Exception:
|
||||
raise MistralAudioTranscriptionException(
|
||||
message=raw_response.text,
|
||||
status_code=raw_response.status_code,
|
||||
headers=raw_response.headers,
|
||||
)
|
||||
|
||||
text = response_json.get("text") or ""
|
||||
response = TranscriptionResponse(text=text)
|
||||
response._hidden_params = response_json
|
||||
return response
|
||||
@@ -25,22 +25,6 @@ def _normalize_reasoning_effort_for_chat_completion(
|
||||
return None
|
||||
|
||||
|
||||
def _get_effort_level(value: Union[str, dict, None]) -> Optional[str]:
|
||||
"""Extract the effective effort level from reasoning_effort (string or dict).
|
||||
|
||||
Use this for guards that compare effort level (e.g. xhigh validation, "none" checks).
|
||||
Ensures dict inputs like {"effort": "none", "summary": "detailed"} are correctly
|
||||
treated as effort="none" for validation purposes.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, dict) and "effort" in value:
|
||||
return value["effort"]
|
||||
return None
|
||||
|
||||
|
||||
class OpenAIGPT5Config(OpenAIGPTConfig):
|
||||
"""Configuration for gpt-5 models including GPT-5-Codex variants.
|
||||
|
||||
@@ -86,19 +70,6 @@ class OpenAIGPT5Config(OpenAIGPTConfig):
|
||||
model_name = model.split("/")[-1]
|
||||
return model_name.startswith("gpt-5.4")
|
||||
|
||||
@classmethod
|
||||
def is_model_gpt_5_4_plus_model(cls, model: str) -> bool:
|
||||
"""Check if the model is gpt-5.4 or newer (5.4, 5.5, 5.6, etc., including pro)."""
|
||||
model_name = model.split("/")[-1]
|
||||
if not model_name.startswith("gpt-5."):
|
||||
return False
|
||||
try:
|
||||
version_str = model_name.replace("gpt-5.", "").split("-")[0]
|
||||
major = version_str.split(".")[0]
|
||||
return int(major) >= 4
|
||||
except (ValueError, IndexError):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _supports_reasoning_effort_level(cls, model: str, level: str) -> bool:
|
||||
"""Check if the model supports a specific reasoning_effort level.
|
||||
@@ -179,35 +150,21 @@ class OpenAIGPT5Config(OpenAIGPTConfig):
|
||||
drop_params=drop_params,
|
||||
)
|
||||
|
||||
# Get raw reasoning_effort and effective effort level for all guards.
|
||||
# Use effective_effort (extracted string) for xhigh validation, "none" checks, and
|
||||
# tool/sampling guards — dict inputs like {"effort": "none", "summary": "detailed"}
|
||||
# must be treated as effort="none" to avoid incorrect tool-drop or sampling errors.
|
||||
raw_reasoning_effort = non_default_params.get(
|
||||
"reasoning_effort"
|
||||
) or optional_params.get("reasoning_effort")
|
||||
effective_effort = _get_effort_level(raw_reasoning_effort)
|
||||
|
||||
# Normalize to string for Chat Completions API when dict has only "effort".
|
||||
# Preserve full dict (e.g. {"effort": "high", "summary": "detailed"}) for Responses API.
|
||||
if isinstance(raw_reasoning_effort, dict) and set(
|
||||
raw_reasoning_effort.keys()
|
||||
) <= {"effort"}:
|
||||
normalized = _normalize_reasoning_effort_for_chat_completion(
|
||||
raw_reasoning_effort
|
||||
)
|
||||
if normalized is not None:
|
||||
if "reasoning_effort" in non_default_params:
|
||||
non_default_params["reasoning_effort"] = normalized
|
||||
if "reasoning_effort" in optional_params:
|
||||
optional_params["reasoning_effort"] = normalized
|
||||
|
||||
reasoning_effort = (
|
||||
# Normalize reasoning_effort: chat completion API expects a string, not a dict
|
||||
# (e.g. {'effort': 'high', 'summary': 'detailed'} -> 'high')
|
||||
raw_reasoning_effort = (
|
||||
non_default_params.get("reasoning_effort")
|
||||
or optional_params.get("reasoning_effort")
|
||||
or raw_reasoning_effort
|
||||
)
|
||||
if effective_effort is not None and effective_effort == "xhigh":
|
||||
normalized = _normalize_reasoning_effort_for_chat_completion(raw_reasoning_effort)
|
||||
if raw_reasoning_effort is not None and normalized is not None:
|
||||
if "reasoning_effort" in non_default_params:
|
||||
non_default_params["reasoning_effort"] = normalized
|
||||
if "reasoning_effort" in optional_params:
|
||||
optional_params["reasoning_effort"] = normalized
|
||||
|
||||
reasoning_effort = normalized or raw_reasoning_effort
|
||||
if reasoning_effort is not None and reasoning_effort == "xhigh":
|
||||
if not self._supports_reasoning_effort_level(model, "xhigh"):
|
||||
if litellm.drop_params or drop_params:
|
||||
non_default_params.pop("reasoning_effort", None)
|
||||
@@ -234,20 +191,17 @@ class OpenAIGPT5Config(OpenAIGPTConfig):
|
||||
has_tools = bool(
|
||||
non_default_params.get("tools") or optional_params.get("tools")
|
||||
)
|
||||
if has_tools and effective_effort not in (None, "none"):
|
||||
# Check if this will be routed to Responses API
|
||||
# If so, keep reasoning_effort; otherwise drop it for chat completions API
|
||||
if not self.is_model_gpt_5_4_plus_model(model):
|
||||
non_default_params.pop("reasoning_effort", None)
|
||||
optional_params.pop("reasoning_effort", None)
|
||||
reasoning_effort = None
|
||||
if has_tools and reasoning_effort not in (None, "none"):
|
||||
non_default_params.pop("reasoning_effort", None)
|
||||
optional_params.pop("reasoning_effort", None)
|
||||
reasoning_effort = None
|
||||
|
||||
# gpt-5.1/5.2 support logprobs, top_p, top_logprobs only when reasoning_effort="none"
|
||||
supports_none = self._supports_reasoning_effort_level(model, "none")
|
||||
if supports_none:
|
||||
sampling_params = ["logprobs", "top_logprobs", "top_p"]
|
||||
has_sampling = any(p in non_default_params for p in sampling_params)
|
||||
if has_sampling and effective_effort not in (None, "none"):
|
||||
if has_sampling and reasoning_effort not in (None, "none"):
|
||||
if litellm.drop_params or drop_params:
|
||||
for p in sampling_params:
|
||||
non_default_params.pop(p, None)
|
||||
@@ -257,7 +211,7 @@ class OpenAIGPT5Config(OpenAIGPTConfig):
|
||||
"gpt-5.1/5.2/5.4 only support logprobs, top_p, top_logprobs when "
|
||||
"reasoning_effort='none'. Current reasoning_effort='{}'. "
|
||||
"To drop unsupported params set `litellm.drop_params = True`"
|
||||
).format(effective_effort),
|
||||
).format(reasoning_effort),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
@@ -265,9 +219,7 @@ class OpenAIGPT5Config(OpenAIGPTConfig):
|
||||
temperature_value: Optional[float] = non_default_params.pop("temperature")
|
||||
if temperature_value is not None:
|
||||
# models supporting reasoning_effort="none" also support flexible temperature
|
||||
if supports_none and (
|
||||
effective_effort == "none" or effective_effort is None
|
||||
):
|
||||
if supports_none and (reasoning_effort == "none" or reasoning_effort is None):
|
||||
optional_params["temperature"] = temperature_value
|
||||
elif temperature_value == 1:
|
||||
optional_params["temperature"] = temperature_value
|
||||
|
||||
@@ -58,6 +58,7 @@ from ..common_utils import OpenAIError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.base_utils import BaseTokenCounter
|
||||
from litellm.types.llms.openai import ChatCompletionToolParam
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
@@ -759,6 +760,13 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
|
||||
def get_base_model(model: Optional[str] = None) -> Optional[str]:
|
||||
return model
|
||||
|
||||
def get_token_counter(self) -> Optional["BaseTokenCounter"]:
|
||||
from litellm.llms.openai.responses.count_tokens.token_counter import (
|
||||
OpenAITokenCounter,
|
||||
)
|
||||
|
||||
return OpenAITokenCounter()
|
||||
|
||||
def get_model_response_iterator(
|
||||
self,
|
||||
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
|
||||
|
||||
@@ -40,6 +40,7 @@ class OpenAIImageEditConfig(BaseImageEditConfig):
|
||||
"image",
|
||||
"prompt",
|
||||
"background",
|
||||
"input_fidelity",
|
||||
"mask",
|
||||
"model",
|
||||
"n",
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
OpenAI Responses API token counting implementation.
|
||||
"""
|
||||
|
||||
from litellm.llms.openai.responses.count_tokens.handler import (
|
||||
OpenAICountTokensHandler,
|
||||
)
|
||||
from litellm.llms.openai.responses.count_tokens.token_counter import (
|
||||
OpenAITokenCounter,
|
||||
)
|
||||
from litellm.llms.openai.responses.count_tokens.transformation import (
|
||||
OpenAICountTokensConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OpenAICountTokensHandler",
|
||||
"OpenAICountTokensConfig",
|
||||
"OpenAITokenCounter",
|
||||
]
|
||||
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
OpenAI Responses API token counting handler.
|
||||
|
||||
Uses httpx for HTTP requests to OpenAI's /v1/responses/input_tokens endpoint.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
from litellm.llms.openai.common_utils import OpenAIError
|
||||
from litellm.llms.openai.responses.count_tokens.transformation import (
|
||||
OpenAICountTokensConfig,
|
||||
)
|
||||
|
||||
|
||||
class OpenAICountTokensHandler(OpenAICountTokensConfig):
|
||||
"""
|
||||
Handler for OpenAI Responses API token counting requests.
|
||||
"""
|
||||
|
||||
async def handle_count_tokens_request(
|
||||
self,
|
||||
model: str,
|
||||
input: Union[str, List[Any]],
|
||||
api_key: str,
|
||||
api_base: Optional[str] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
instructions: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Handle a token counting request to OpenAI's Responses API.
|
||||
|
||||
Returns:
|
||||
Dictionary containing {"input_tokens": <number>}
|
||||
|
||||
Raises:
|
||||
OpenAIError: If the API request fails
|
||||
"""
|
||||
try:
|
||||
self.validate_request(model, input)
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Processing OpenAI CountTokens request for model: {model}"
|
||||
)
|
||||
|
||||
request_body = self.transform_request_to_count_tokens(
|
||||
model=model,
|
||||
input=input,
|
||||
tools=tools,
|
||||
instructions=instructions,
|
||||
)
|
||||
|
||||
endpoint_url = self.get_openai_count_tokens_endpoint(api_base)
|
||||
|
||||
verbose_logger.debug(f"Making request to: {endpoint_url}")
|
||||
|
||||
headers = self.get_required_headers(api_key)
|
||||
|
||||
async_client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.OPENAI
|
||||
)
|
||||
|
||||
request_timeout = timeout if timeout is not None else litellm.request_timeout
|
||||
|
||||
response = await async_client.post(
|
||||
endpoint_url,
|
||||
headers=headers,
|
||||
json=request_body,
|
||||
timeout=request_timeout,
|
||||
)
|
||||
|
||||
verbose_logger.debug(f"Response status: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
error_text = response.text
|
||||
verbose_logger.error(f"OpenAI API error: {error_text}")
|
||||
raise OpenAIError(
|
||||
status_code=response.status_code,
|
||||
message=error_text,
|
||||
)
|
||||
|
||||
openai_response = response.json()
|
||||
verbose_logger.debug(f"OpenAI response: {openai_response}")
|
||||
return openai_response
|
||||
|
||||
except OpenAIError:
|
||||
raise
|
||||
except httpx.HTTPStatusError as e:
|
||||
verbose_logger.error(f"HTTP error in CountTokens handler: {str(e)}")
|
||||
raise OpenAIError(
|
||||
status_code=e.response.status_code,
|
||||
message=e.response.text,
|
||||
)
|
||||
except (httpx.RequestError, json.JSONDecodeError, ValueError) as e:
|
||||
verbose_logger.error(f"Error in CountTokens handler: {str(e)}")
|
||||
raise OpenAIError(
|
||||
status_code=500,
|
||||
message=f"CountTokens processing error: {str(e)}",
|
||||
)
|
||||
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
OpenAI Token Counter implementation using the Responses API /input_tokens endpoint.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.base_llm.base_utils import BaseTokenCounter
|
||||
from litellm.llms.openai.common_utils import OpenAIError
|
||||
from litellm.llms.openai.responses.count_tokens.handler import (
|
||||
OpenAICountTokensHandler,
|
||||
)
|
||||
from litellm.llms.openai.responses.count_tokens.transformation import (
|
||||
OpenAICountTokensConfig,
|
||||
)
|
||||
from litellm.types.utils import LlmProviders, TokenCountResponse
|
||||
|
||||
# Global handler instance - reuse across all token counting requests
|
||||
openai_count_tokens_handler = OpenAICountTokensHandler()
|
||||
|
||||
|
||||
class OpenAITokenCounter(BaseTokenCounter):
|
||||
"""Token counter implementation for OpenAI provider using the Responses API."""
|
||||
|
||||
def should_use_token_counting_api(
|
||||
self,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
) -> bool:
|
||||
return custom_llm_provider == LlmProviders.OPENAI.value
|
||||
|
||||
async def count_tokens(
|
||||
self,
|
||||
model_to_use: str,
|
||||
messages: Optional[List[Dict[str, Any]]],
|
||||
contents: Optional[List[Dict[str, Any]]],
|
||||
deployment: Optional[Dict[str, Any]] = None,
|
||||
request_model: str = "",
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
system: Optional[Any] = None,
|
||||
) -> Optional[TokenCountResponse]:
|
||||
"""
|
||||
Count tokens using OpenAI's Responses API /input_tokens endpoint.
|
||||
"""
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
deployment = deployment or {}
|
||||
litellm_params = deployment.get("litellm_params", {})
|
||||
|
||||
# Get OpenAI API key from deployment config or environment
|
||||
api_key = litellm_params.get("api_key")
|
||||
if not api_key:
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
if not api_key:
|
||||
verbose_logger.warning("No OpenAI API key found for token counting")
|
||||
return None
|
||||
|
||||
api_base = litellm_params.get("api_base")
|
||||
|
||||
# Convert chat messages to Responses API input format
|
||||
input_items, instructions = OpenAICountTokensConfig.messages_to_responses_input(
|
||||
messages
|
||||
)
|
||||
|
||||
# Use system param if instructions not extracted from messages
|
||||
if instructions is None and system is not None:
|
||||
instructions = system if isinstance(system, str) else str(system)
|
||||
|
||||
# If no input items were produced (e.g., system-only messages), fall back to local counting
|
||||
if not input_items:
|
||||
return None
|
||||
|
||||
try:
|
||||
result = await openai_count_tokens_handler.handle_count_tokens_request(
|
||||
model=model_to_use,
|
||||
input=input_items if input_items is not None else [],
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
tools=tools,
|
||||
instructions=instructions,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return TokenCountResponse(
|
||||
total_tokens=result.get("input_tokens", 0),
|
||||
request_model=request_model,
|
||||
model_used=model_to_use,
|
||||
tokenizer_type="openai_api",
|
||||
original_response=result,
|
||||
)
|
||||
except OpenAIError as e:
|
||||
verbose_logger.warning(
|
||||
f"OpenAI CountTokens API error: status={e.status_code}, message={e.message}"
|
||||
)
|
||||
return TokenCountResponse(
|
||||
total_tokens=0,
|
||||
request_model=request_model,
|
||||
model_used=model_to_use,
|
||||
tokenizer_type="openai_api",
|
||||
error=True,
|
||||
error_message=e.message,
|
||||
status_code=e.status_code,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.warning(f"Error calling OpenAI CountTokens API: {e}")
|
||||
return TokenCountResponse(
|
||||
total_tokens=0,
|
||||
request_model=request_model,
|
||||
model_used=model_to_use,
|
||||
tokenizer_type="openai_api",
|
||||
error=True,
|
||||
error_message=str(e),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
OpenAI Responses API token counting transformation logic.
|
||||
|
||||
This module handles the transformation of requests to OpenAI's /v1/responses/input_tokens endpoint.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
|
||||
class OpenAICountTokensConfig:
|
||||
"""
|
||||
Configuration and transformation logic for OpenAI Responses API token counting.
|
||||
|
||||
OpenAI Responses API Token Counting Specification:
|
||||
- Endpoint: POST https://api.openai.com/v1/responses/input_tokens
|
||||
- Response: {"input_tokens": <number>}
|
||||
"""
|
||||
|
||||
def get_openai_count_tokens_endpoint(self, api_base: Optional[str] = None) -> str:
|
||||
base = api_base or "https://api.openai.com/v1"
|
||||
base = base.rstrip("/")
|
||||
return f"{base}/responses/input_tokens"
|
||||
|
||||
def transform_request_to_count_tokens(
|
||||
self,
|
||||
model: str,
|
||||
input: Union[str, List[Any]],
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
instructions: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Transform request to OpenAI Responses API token counting format.
|
||||
|
||||
The Responses API uses `input` (not `messages`) and `instructions` (not `system`).
|
||||
"""
|
||||
request: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"input": input,
|
||||
}
|
||||
|
||||
if instructions is not None:
|
||||
request["instructions"] = instructions
|
||||
|
||||
if tools is not None:
|
||||
request["tools"] = self._transform_tools_for_responses_api(tools)
|
||||
|
||||
return request
|
||||
|
||||
def get_required_headers(self, api_key: str) -> Dict[str, str]:
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
|
||||
def validate_request(
|
||||
self, model: str, input: Union[str, List[Any]]
|
||||
) -> None:
|
||||
if not model:
|
||||
raise ValueError("model parameter is required")
|
||||
|
||||
if not input:
|
||||
raise ValueError("input parameter is required")
|
||||
|
||||
@staticmethod
|
||||
def _transform_tools_for_responses_api(
|
||||
tools: List[Dict[str, Any]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Transform OpenAI chat tools format to Responses API tools format.
|
||||
|
||||
Chat format: {"type": "function", "function": {"name": "...", "parameters": {...}}}
|
||||
Responses format: {"type": "function", "name": "...", "parameters": {...}}
|
||||
"""
|
||||
transformed = []
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function" and "function" in tool:
|
||||
func = tool["function"]
|
||||
item: Dict[str, Any] = {
|
||||
"type": "function",
|
||||
"name": func.get("name", ""),
|
||||
"description": func.get("description", ""),
|
||||
"parameters": func.get("parameters", {}),
|
||||
}
|
||||
if "strict" in func:
|
||||
item["strict"] = func["strict"]
|
||||
transformed.append(item)
|
||||
else:
|
||||
# Pass through non-function tools (e.g., web_search, file_search)
|
||||
transformed.append(tool)
|
||||
return transformed
|
||||
|
||||
@staticmethod
|
||||
def messages_to_responses_input(
|
||||
messages: List[Dict[str, Any]],
|
||||
) -> tuple:
|
||||
"""
|
||||
Convert standard chat messages format to OpenAI Responses API input format.
|
||||
|
||||
Returns:
|
||||
(input_items, instructions) tuple where instructions is extracted
|
||||
from system/developer messages.
|
||||
"""
|
||||
input_items: List[Dict[str, Any]] = []
|
||||
instructions_parts: List[str] = []
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content") or ""
|
||||
|
||||
if role in ("system", "developer"):
|
||||
# Extract system/developer messages as instructions
|
||||
if isinstance(content, str):
|
||||
instructions_parts.append(content)
|
||||
elif isinstance(content, list):
|
||||
# Handle content blocks - extract text
|
||||
text_parts = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text_parts.append(block.get("text", ""))
|
||||
elif isinstance(block, str):
|
||||
text_parts.append(block)
|
||||
instructions_parts.append("\n".join(text_parts))
|
||||
elif role == "user":
|
||||
if isinstance(content, list):
|
||||
# Extract text from content blocks for Responses API
|
||||
text_parts = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text_parts.append(block.get("text", ""))
|
||||
elif isinstance(block, str):
|
||||
text_parts.append(block)
|
||||
content = "\n".join(text_parts)
|
||||
input_items.append({"role": "user", "content": content})
|
||||
elif role == "assistant":
|
||||
# Map tool_calls to Responses API function_call items
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if content:
|
||||
input_items.append({"role": "assistant", "content": content})
|
||||
if tool_calls:
|
||||
for tc in tool_calls:
|
||||
func = tc.get("function", {})
|
||||
input_items.append({
|
||||
"type": "function_call",
|
||||
"call_id": tc.get("id", ""),
|
||||
"name": func.get("name", ""),
|
||||
"arguments": func.get("arguments", ""),
|
||||
})
|
||||
elif not content:
|
||||
input_items.append({"role": "assistant", "content": content})
|
||||
elif role == "tool":
|
||||
input_items.append({
|
||||
"type": "function_call_output",
|
||||
"call_id": msg.get("tool_call_id", ""),
|
||||
"output": content if isinstance(content, str) else str(content),
|
||||
})
|
||||
|
||||
instructions = "\n".join(instructions_parts) if instructions_parts else None
|
||||
return input_items, instructions
|
||||
@@ -10,8 +10,9 @@ Instead of creating a full Python module for simple OpenAI-compatible providers,
|
||||
|
||||
- `providers.json` - Configuration file for all JSON-based providers
|
||||
- `json_loader.py` - Loads and parses the JSON configuration
|
||||
- `dynamic_config.py` - Generates Python config classes from JSON
|
||||
- `chat/` - Existing OpenAI-like chat completion handlers
|
||||
- `dynamic_config.py` - Generates Python config classes from JSON (chat + responses)
|
||||
- `chat/` - OpenAI-like chat completion handlers
|
||||
- `responses/` - OpenAI-like Responses API handlers
|
||||
|
||||
## Adding a New Provider
|
||||
|
||||
@@ -96,6 +97,32 @@ response = litellm.completion(
|
||||
)
|
||||
```
|
||||
|
||||
## Responses API Support
|
||||
|
||||
Providers that support the OpenAI Responses API (`/v1/responses`) can declare it via `supported_endpoints`:
|
||||
|
||||
```json
|
||||
{
|
||||
"your_provider": {
|
||||
"base_url": "https://api.yourprovider.com/v1",
|
||||
"api_key_env": "YOUR_PROVIDER_API_KEY",
|
||||
"supported_endpoints": ["/v1/chat/completions", "/v1/responses"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This enables `litellm.responses(model="your_provider/model-name", ...)` with zero Python code.
|
||||
The provider inherits all request/response handling from OpenAI's Responses API config.
|
||||
|
||||
If `supported_endpoints` is omitted, it defaults to `[]` (only chat completions, which is always enabled for JSON providers).
|
||||
|
||||
### How It Works
|
||||
|
||||
1. `json_loader.py` checks `supported_endpoints` for `/v1/responses`
|
||||
2. `dynamic_config.py` generates a responses config class (inherits from `OpenAIResponsesAPIConfig`)
|
||||
3. `ProviderConfigManager.get_provider_responses_api_config()` returns the generated config
|
||||
4. Request/response transformation is inherited from OpenAI — no custom code needed
|
||||
|
||||
## Benefits
|
||||
|
||||
- **Simple**: 2-5 lines of JSON vs 100+ lines of Python
|
||||
@@ -112,6 +139,10 @@ Use a Python config class if you need:
|
||||
- Provider-specific streaming logic
|
||||
- Advanced tool calling transformations
|
||||
|
||||
For providers that are *mostly* OpenAI-compatible but need small overrides (e.g. preset model handling),
|
||||
you can inherit from `OpenAIResponsesAPIConfig` and override only what's needed — see
|
||||
`litellm/llms/perplexity/responses/transformation.py` for a minimal example (~40 lines).
|
||||
|
||||
## Implementation Details
|
||||
|
||||
### How It Works
|
||||
@@ -125,5 +156,6 @@ Use a Python config class if you need:
|
||||
|
||||
The JSON system is integrated at:
|
||||
- `litellm/litellm_core_utils/get_llm_provider_logic.py` - Provider resolution
|
||||
- `litellm/utils.py` - ProviderConfigManager
|
||||
- `litellm/utils.py` - ProviderConfigManager (chat + responses)
|
||||
- `litellm/responses/main.py` - Responses API routing
|
||||
- `litellm/constants.py` - openai_compatible_providers list
|
||||
|
||||
@@ -172,3 +172,63 @@ def create_config_class(provider: SimpleProviderConfig):
|
||||
return provider.slug
|
||||
|
||||
return JSONProviderConfig
|
||||
|
||||
|
||||
_responses_config_cache: dict = {}
|
||||
|
||||
|
||||
def create_responses_config_class(provider: SimpleProviderConfig):
|
||||
"""Generate a Responses API config class dynamically from JSON configuration.
|
||||
|
||||
Parallel to create_config_class() but for /v1/responses endpoints.
|
||||
Classes are cached per provider slug to avoid regeneration on every request.
|
||||
"""
|
||||
if provider.slug in _responses_config_cache:
|
||||
return _responses_config_cache[provider.slug]
|
||||
|
||||
from litellm.llms.openai_like.responses.transformation import (
|
||||
OpenAILikeResponsesConfig,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
class JSONProviderResponsesConfig(OpenAILikeResponsesConfig):
|
||||
@property
|
||||
def custom_llm_provider(self): # type: ignore[override]
|
||||
return provider.slug
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
litellm_params: Optional[GenericLiteLLMParams],
|
||||
) -> dict:
|
||||
litellm_params = litellm_params or GenericLiteLLMParams()
|
||||
api_key = (
|
||||
litellm_params.api_key
|
||||
or get_secret_str(provider.api_key_env)
|
||||
)
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
if not api_base:
|
||||
if provider.api_base_env:
|
||||
api_base = get_secret_str(provider.api_base_env)
|
||||
if not api_base:
|
||||
api_base = provider.base_url
|
||||
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
f"api_base is required for provider {provider.slug}"
|
||||
)
|
||||
|
||||
api_base = api_base.rstrip("/")
|
||||
return f"{api_base}/responses"
|
||||
|
||||
_responses_config_cache[provider.slug] = JSONProviderResponsesConfig
|
||||
return JSONProviderResponsesConfig
|
||||
|
||||
@@ -21,6 +21,7 @@ class SimpleProviderConfig:
|
||||
self.param_mappings = data.get("param_mappings", {})
|
||||
self.constraints = data.get("constraints", {})
|
||||
self.special_handling = data.get("special_handling", {})
|
||||
self.supported_endpoints = data.get("supported_endpoints", [])
|
||||
|
||||
|
||||
class JSONProviderRegistry:
|
||||
@@ -66,6 +67,14 @@ class JSONProviderRegistry:
|
||||
"""Check if a provider is defined via JSON"""
|
||||
return slug in cls._providers
|
||||
|
||||
@classmethod
|
||||
def supports_responses_api(cls, slug: str) -> bool:
|
||||
"""Check if a JSON provider supports the Responses API"""
|
||||
provider = cls._providers.get(slug)
|
||||
if provider is None:
|
||||
return False
|
||||
return "/v1/responses" in provider.supported_endpoints
|
||||
|
||||
@classmethod
|
||||
def list_providers(cls) -> list:
|
||||
"""List all registered provider slugs"""
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
from litellm.llms.openai_like.responses.transformation import (
|
||||
OpenAILikeResponsesConfig,
|
||||
)
|
||||
|
||||
__all__ = ["OpenAILikeResponsesConfig"]
|
||||
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
OpenAI-like Responses API transformation.
|
||||
|
||||
Base class for JSON-declared providers that support the /v1/responses endpoint.
|
||||
Inherits everything from OpenAIResponsesAPIConfig; subclasses only override
|
||||
provider-specific resolution (slug, API key env var, base URL).
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
|
||||
class OpenAILikeResponsesConfig(OpenAIResponsesAPIConfig):
|
||||
"""
|
||||
Responses API config for OpenAI-compatible providers declared via JSON.
|
||||
|
||||
Concrete per-provider classes are generated dynamically in dynamic_config.py.
|
||||
This base provides the three overridable hooks that the dynamic generator
|
||||
fills in: custom_llm_provider, validate_environment, get_complete_url.
|
||||
"""
|
||||
|
||||
@property
|
||||
def custom_llm_provider(self) -> Union[str, LlmProviders]: # type: ignore[override]
|
||||
return "openai_like"
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
litellm_params: Optional[GenericLiteLLMParams],
|
||||
) -> dict:
|
||||
litellm_params = litellm_params or GenericLiteLLMParams()
|
||||
api_key = litellm_params.api_key or get_secret_str("OPENAI_LIKE_API_KEY")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
api_base = api_base or get_secret_str("OPENAI_LIKE_API_BASE")
|
||||
if not api_base:
|
||||
raise ValueError("api_base is required for openai_like provider")
|
||||
api_base = api_base.rstrip("/")
|
||||
return f"{api_base}/responses"
|
||||
@@ -1,54 +1,31 @@
|
||||
"""
|
||||
Transformation logic for Perplexity Agent API (Responses API)
|
||||
Perplexity Responses API — OpenAI-compatible.
|
||||
|
||||
This module handles the translation between OpenAI's Responses API format
|
||||
and Perplexity's Responses API format, which supports:
|
||||
- Third-party model access (OpenAI, Anthropic, Google, xAI, etc.)
|
||||
- Presets for optimized configurations
|
||||
- Web search and URL fetching tools
|
||||
- Reasoning effort control
|
||||
- Instructions parameter for system-level guidance
|
||||
The only provider quirks:
|
||||
- cost returned as dict → handled by ResponseAPIUsage.parse_cost validator
|
||||
- preset models (preset/pro-search) → handled by transform_responses_api_request
|
||||
- HTTP 200 with status:"failed" → raised as exception in transform_response_api_response
|
||||
|
||||
Ref: https://docs.perplexity.ai/api-reference/responses-post
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import (
|
||||
ResponseAPIUsage,
|
||||
ResponseInputParam,
|
||||
ResponsesAPIOptionalRequestParams,
|
||||
ResponsesAPIResponse,
|
||||
ResponsesAPIStreamingResponse,
|
||||
)
|
||||
from litellm.types.llms.openai import ResponseInputParam, ResponsesAPIResponse
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
|
||||
class PerplexityResponsesConfig(OpenAIResponsesAPIConfig):
|
||||
"""
|
||||
Configuration for Perplexity Agent API (Responses API)
|
||||
|
||||
|
||||
Reference: https://docs.perplexity.ai/docs/agent-api/overview
|
||||
"""
|
||||
|
||||
@property
|
||||
def custom_llm_provider(self) -> LlmProviders:
|
||||
return LlmProviders.PERPLEXITY
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
"""
|
||||
Perplexity Responses API supports a different set of parameters
|
||||
|
||||
Ref: https://docs.perplexity.ai/api-reference/responses-post
|
||||
Params aligned with response-echo fields and Open Responses spec.
|
||||
"""
|
||||
"""Ref: https://docs.perplexity.ai/api-reference/responses-post"""
|
||||
return [
|
||||
"max_output_tokens",
|
||||
"stream",
|
||||
@@ -56,200 +33,45 @@ class PerplexityResponsesConfig(OpenAIResponsesAPIConfig):
|
||||
"top_p",
|
||||
"tools",
|
||||
"reasoning",
|
||||
"preset",
|
||||
"instructions",
|
||||
"models", # Model fallback support
|
||||
"tool_choice",
|
||||
"parallel_tool_calls",
|
||||
"max_tool_calls",
|
||||
"text",
|
||||
"previous_response_id",
|
||||
"store",
|
||||
"background",
|
||||
"truncation",
|
||||
"metadata",
|
||||
"safety_identifier",
|
||||
"user",
|
||||
"stream_options",
|
||||
"top_logprobs",
|
||||
"prompt_cache_key",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"service_tier",
|
||||
"models",
|
||||
]
|
||||
|
||||
@property
|
||||
def custom_llm_provider(self) -> LlmProviders:
|
||||
return LlmProviders.PERPLEXITY
|
||||
|
||||
def validate_environment(
|
||||
self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
|
||||
) -> dict:
|
||||
"""Validate environment and set up headers"""
|
||||
# Get API key from environment
|
||||
api_key = get_secret_str("PERPLEXITYAI_API_KEY") or get_secret_str(
|
||||
"PERPLEXITY_API_KEY"
|
||||
litellm_params = litellm_params or GenericLiteLLMParams()
|
||||
api_key = (
|
||||
litellm_params.api_key
|
||||
or get_secret_str("PERPLEXITYAI_API_KEY")
|
||||
or get_secret_str("PERPLEXITY_API_KEY")
|
||||
)
|
||||
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
"""Get the complete URL for the Perplexity Responses API"""
|
||||
if api_base is None:
|
||||
api_base = (
|
||||
get_secret_str("PERPLEXITY_API_BASE") or "https://api.perplexity.ai"
|
||||
)
|
||||
def get_complete_url(self, api_base: Optional[str], litellm_params: dict) -> str:
|
||||
api_base = api_base or get_secret_str("PERPLEXITY_API_BASE") or "https://api.perplexity.ai"
|
||||
return f"{api_base.rstrip('/')}/v1/responses"
|
||||
|
||||
# Ensure api_base doesn't end with a slash
|
||||
api_base = api_base.rstrip("/")
|
||||
|
||||
# Add the responses endpoint
|
||||
return f"{api_base}/v1/responses"
|
||||
|
||||
def map_openai_params( # noqa: PLR0915
|
||||
self,
|
||||
response_api_optional_params: ResponsesAPIOptionalRequestParams,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> Dict:
|
||||
"""
|
||||
Map OpenAI Responses API parameters to Perplexity format
|
||||
|
||||
Key differences:
|
||||
- Supports 'preset' parameter for predefined configurations
|
||||
- Supports 'instructions' parameter for system-level guidance
|
||||
- Tools are specified differently (web_search, fetch_url)
|
||||
"""
|
||||
mapped_params: Dict[str, Any] = {}
|
||||
|
||||
# Map standard parameters
|
||||
if response_api_optional_params.get("max_output_tokens"):
|
||||
mapped_params["max_output_tokens"] = response_api_optional_params[
|
||||
"max_output_tokens"
|
||||
]
|
||||
|
||||
if response_api_optional_params.get("temperature"):
|
||||
mapped_params["temperature"] = response_api_optional_params["temperature"]
|
||||
|
||||
if response_api_optional_params.get("top_p"):
|
||||
mapped_params["top_p"] = response_api_optional_params["top_p"]
|
||||
|
||||
if response_api_optional_params.get("stream"):
|
||||
mapped_params["stream"] = response_api_optional_params["stream"]
|
||||
|
||||
if response_api_optional_params.get("stream_options"):
|
||||
mapped_params["stream_options"] = response_api_optional_params[
|
||||
"stream_options"
|
||||
]
|
||||
|
||||
# Map Perplexity-specific parameters (using .get() with Any dict access)
|
||||
preset = response_api_optional_params.get("preset") # type: ignore
|
||||
if preset:
|
||||
mapped_params["preset"] = preset
|
||||
|
||||
instructions = response_api_optional_params.get("instructions") # type: ignore
|
||||
if instructions:
|
||||
mapped_params["instructions"] = instructions
|
||||
|
||||
if response_api_optional_params.get("reasoning"):
|
||||
mapped_params["reasoning"] = response_api_optional_params["reasoning"]
|
||||
|
||||
tools = response_api_optional_params.get("tools")
|
||||
if tools:
|
||||
# Convert tools to list of dicts for transformation
|
||||
tools_list = [dict(tool) if hasattr(tool, "__dict__") else tool for tool in tools] # type: ignore
|
||||
mapped_params["tools"] = self._transform_tools(tools_list) # type: ignore
|
||||
|
||||
# Tool control
|
||||
if response_api_optional_params.get("tool_choice"):
|
||||
mapped_params["tool_choice"] = response_api_optional_params["tool_choice"]
|
||||
if response_api_optional_params.get("parallel_tool_calls") is not None:
|
||||
mapped_params["parallel_tool_calls"] = response_api_optional_params[
|
||||
"parallel_tool_calls"
|
||||
]
|
||||
if response_api_optional_params.get("max_tool_calls"):
|
||||
mapped_params["max_tool_calls"] = response_api_optional_params[
|
||||
"max_tool_calls"
|
||||
]
|
||||
|
||||
# Structured outputs
|
||||
text_param = response_api_optional_params.get("text")
|
||||
if text_param:
|
||||
mapped_params["text"] = text_param
|
||||
|
||||
# Conversation continuity
|
||||
if response_api_optional_params.get("previous_response_id"):
|
||||
mapped_params["previous_response_id"] = response_api_optional_params[
|
||||
"previous_response_id"
|
||||
]
|
||||
|
||||
# Storage and lifecycle
|
||||
if response_api_optional_params.get("store") is not None:
|
||||
mapped_params["store"] = response_api_optional_params["store"]
|
||||
if response_api_optional_params.get("background") is not None:
|
||||
mapped_params["background"] = response_api_optional_params["background"]
|
||||
if response_api_optional_params.get("truncation"):
|
||||
mapped_params["truncation"] = response_api_optional_params["truncation"]
|
||||
|
||||
# Metadata
|
||||
if response_api_optional_params.get("metadata"):
|
||||
mapped_params["metadata"] = response_api_optional_params["metadata"]
|
||||
if response_api_optional_params.get("safety_identifier"):
|
||||
mapped_params["safety_identifier"] = response_api_optional_params[
|
||||
"safety_identifier"
|
||||
]
|
||||
if response_api_optional_params.get("user"):
|
||||
mapped_params["user"] = response_api_optional_params["user"]
|
||||
|
||||
# Additional
|
||||
if response_api_optional_params.get("top_logprobs") is not None:
|
||||
mapped_params["top_logprobs"] = response_api_optional_params["top_logprobs"]
|
||||
if response_api_optional_params.get("prompt_cache_key"):
|
||||
mapped_params["prompt_cache_key"] = response_api_optional_params[
|
||||
"prompt_cache_key"
|
||||
]
|
||||
if response_api_optional_params.get("frequency_penalty") is not None:
|
||||
mapped_params["frequency_penalty"] = response_api_optional_params[
|
||||
"frequency_penalty" # type: ignore[typeddict-item]
|
||||
]
|
||||
if response_api_optional_params.get("presence_penalty") is not None:
|
||||
mapped_params["presence_penalty"] = response_api_optional_params[
|
||||
"presence_penalty" # type: ignore[typeddict-item]
|
||||
]
|
||||
if response_api_optional_params.get("service_tier"):
|
||||
mapped_params["service_tier"] = response_api_optional_params["service_tier"]
|
||||
|
||||
return mapped_params
|
||||
|
||||
def _transform_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Transform tools to Perplexity format.
|
||||
|
||||
Perplexity supports (per public OpenAPI spec):
|
||||
- web_search: Performs web searches
|
||||
- fetch_url: Fetches content from URLs
|
||||
- function: Function Calling
|
||||
"""
|
||||
perplexity_tools = []
|
||||
|
||||
for tool in tools:
|
||||
if isinstance(tool, dict):
|
||||
tool_type = tool.get("type", "")
|
||||
|
||||
# Direct Perplexity tool format
|
||||
if tool_type in ["web_search", "fetch_url"]:
|
||||
perplexity_tools.append(tool)
|
||||
|
||||
# Function tools: Perplexity supports them natively
|
||||
elif tool_type == "function":
|
||||
perplexity_tools.append(tool)
|
||||
|
||||
return perplexity_tools
|
||||
def _ensure_message_type(
|
||||
self, input: Union[str, ResponseInputParam]
|
||||
) -> Union[str, List[Dict[str, Any]]]:
|
||||
"""Ensure list input items have type='message' (required by Perplexity)."""
|
||||
if isinstance(input, str):
|
||||
return input
|
||||
if isinstance(input, list):
|
||||
result = []
|
||||
for item in input:
|
||||
if isinstance(item, dict) and "type" not in item:
|
||||
item = {**item, "type": "message"}
|
||||
result.append(item)
|
||||
return result
|
||||
return input
|
||||
|
||||
def transform_responses_api_request(
|
||||
self,
|
||||
@@ -259,62 +81,23 @@ class PerplexityResponsesConfig(OpenAIResponsesAPIConfig):
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Dict:
|
||||
"""
|
||||
Transform request to Perplexity Responses API format
|
||||
"""
|
||||
# Check if the model is a preset (format: preset/preset-name)
|
||||
"""Handle preset/ model prefix: send as {"preset": name} instead of {"model": name}."""
|
||||
input = self._ensure_message_type(input)
|
||||
if model.startswith("preset/"):
|
||||
preset_name = model.replace("preset/", "")
|
||||
data = {
|
||||
"preset": preset_name,
|
||||
"input": self._format_input(input),
|
||||
input = self._validate_input_param(input)
|
||||
data: Dict = {
|
||||
"preset": model[len("preset/"):],
|
||||
"input": input,
|
||||
}
|
||||
# Check if preset is explicitly provided in params
|
||||
elif response_api_optional_request_params.get("preset"):
|
||||
data = {
|
||||
"preset": response_api_optional_request_params.pop("preset"),
|
||||
"input": self._format_input(input),
|
||||
}
|
||||
else:
|
||||
# Full request format for third-party models
|
||||
data = {
|
||||
"model": model,
|
||||
"input": self._format_input(input),
|
||||
}
|
||||
|
||||
# Add all optional parameters
|
||||
for key, value in response_api_optional_request_params.items():
|
||||
data[key] = value
|
||||
|
||||
return data
|
||||
|
||||
def _format_input(
|
||||
self, input: Union[str, ResponseInputParam]
|
||||
) -> Union[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Format input for Perplexity Responses API
|
||||
|
||||
The API accepts either:
|
||||
- A simple string for single-turn queries
|
||||
- An array of message objects for multi-turn conversations
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
return input
|
||||
|
||||
# Handle ResponseInputParam format
|
||||
if isinstance(input, list):
|
||||
formatted_messages = []
|
||||
for item in input:
|
||||
if isinstance(item, dict):
|
||||
formatted_message = {
|
||||
"type": "message",
|
||||
"role": item.get("role"),
|
||||
"content": item.get("content", ""),
|
||||
}
|
||||
formatted_messages.append(formatted_message)
|
||||
return formatted_messages
|
||||
|
||||
return str(input)
|
||||
data.update(response_api_optional_request_params)
|
||||
return data
|
||||
return super().transform_responses_api_request(
|
||||
model=model,
|
||||
input=input,
|
||||
response_api_optional_request_params=response_api_optional_request_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
def transform_response_api_response(
|
||||
self,
|
||||
@@ -322,174 +105,27 @@ class PerplexityResponsesConfig(OpenAIResponsesAPIConfig):
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> ResponsesAPIResponse:
|
||||
"""
|
||||
Transform Perplexity Responses API response to OpenAI Responses API format
|
||||
"""
|
||||
"""Check for Perplexity's status:'failed' on HTTP 200 before delegating to base."""
|
||||
try:
|
||||
raw_response_json = raw_response.json()
|
||||
except Exception as e:
|
||||
raise BaseLLMException(
|
||||
status_code=raw_response.status_code,
|
||||
message=f"Failed to parse response: {str(e)}",
|
||||
)
|
||||
|
||||
# Check for error status
|
||||
status = raw_response_json.get("status")
|
||||
if status == "failed":
|
||||
error = raw_response_json.get("error", {})
|
||||
error_message = error.get("message", "Unknown error")
|
||||
raise BaseLLMException(
|
||||
status_code=raw_response.status_code,
|
||||
message=error_message,
|
||||
)
|
||||
|
||||
# Transform usage to handle Perplexity's cost structure
|
||||
usage_data = raw_response_json.get("usage", {})
|
||||
transformed_usage_dict = self._transform_usage(usage_data)
|
||||
|
||||
# Convert usage dict to ResponseAPIUsage object
|
||||
usage_obj = (
|
||||
ResponseAPIUsage(**transformed_usage_dict)
|
||||
if transformed_usage_dict
|
||||
else None
|
||||
)
|
||||
|
||||
# Map Perplexity response to OpenAI Responses API format
|
||||
response = ResponsesAPIResponse(
|
||||
id=raw_response_json.get("id", ""),
|
||||
object="response",
|
||||
created_at=raw_response_json.get("created_at", 0),
|
||||
status=raw_response_json.get("status", "completed"),
|
||||
model=raw_response_json.get("model", model),
|
||||
output=raw_response_json.get("output", []),
|
||||
usage=usage_obj,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _transform_usage(self, usage_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Transform Perplexity usage data to OpenAI format
|
||||
|
||||
Perplexity returns:
|
||||
{
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 200,
|
||||
"total_tokens": 300,
|
||||
"cost": {
|
||||
"currency": "USD",
|
||||
"input_cost": 0.0001,
|
||||
"output_cost": 0.0002,
|
||||
"total_cost": 0.0003
|
||||
}
|
||||
}
|
||||
|
||||
OpenAI expects:
|
||||
{
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 200,
|
||||
"total_tokens": 300,
|
||||
"cost": 0.0003
|
||||
}
|
||||
"""
|
||||
transformed = {
|
||||
"input_tokens": usage_data.get("input_tokens", 0),
|
||||
"output_tokens": usage_data.get("output_tokens", 0),
|
||||
"total_tokens": usage_data.get("total_tokens", 0),
|
||||
}
|
||||
|
||||
# Transform cost from Perplexity format (dict) to OpenAI format (float)
|
||||
cost_obj = usage_data.get("cost")
|
||||
if isinstance(cost_obj, dict) and "total_cost" in cost_obj:
|
||||
transformed["cost"] = cost_obj["total_cost"]
|
||||
verbose_logger.debug(
|
||||
"Transformed Perplexity cost object to float: %s -> %s",
|
||||
cost_obj,
|
||||
cost_obj["total_cost"],
|
||||
)
|
||||
elif cost_obj is not None:
|
||||
# If cost is already a float/number, use it as-is
|
||||
transformed["cost"] = cost_obj
|
||||
|
||||
# Add input_tokens_details if present
|
||||
if "input_tokens_details" in usage_data:
|
||||
transformed["input_tokens_details"] = usage_data["input_tokens_details"]
|
||||
|
||||
# Add output_tokens_details if present
|
||||
if "output_tokens_details" in usage_data:
|
||||
transformed["output_tokens_details"] = usage_data["output_tokens_details"]
|
||||
|
||||
return transformed
|
||||
|
||||
def transform_streaming_response(
|
||||
self,
|
||||
model: str,
|
||||
parsed_chunk: dict,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> ResponsesAPIStreamingResponse:
|
||||
"""
|
||||
Transform a parsed streaming response chunk into a ResponsesAPIStreamingResponse
|
||||
"""
|
||||
# Get the event type from the chunk
|
||||
verbose_logger.debug("Raw Perplexity Chunk=%s", parsed_chunk)
|
||||
event_type = str(parsed_chunk.get("type"))
|
||||
event_pydantic_model = PerplexityResponsesConfig.get_event_model_class(
|
||||
event_type=event_type
|
||||
)
|
||||
|
||||
# Transform Perplexity-specific fields to OpenAI format
|
||||
parsed_chunk = self._transform_perplexity_chunk(parsed_chunk)
|
||||
|
||||
# Defensive: Handle error.code being null (similar to OpenAI implementation)
|
||||
try:
|
||||
error_obj = parsed_chunk.get("error")
|
||||
if isinstance(error_obj, dict) and error_obj.get("code") is None:
|
||||
# Preserve other fields, but ensure `code` is a non-null string
|
||||
parsed_chunk = dict(parsed_chunk)
|
||||
parsed_chunk["error"] = dict(error_obj)
|
||||
parsed_chunk["error"]["code"] = "unknown_error"
|
||||
except Exception:
|
||||
# If anything unexpected happens here, fall back to attempting
|
||||
# instantiation and let higher-level handlers manage errors.
|
||||
verbose_logger.debug("Failed to coalesce error.code in parsed_chunk")
|
||||
raw_response_json = None
|
||||
|
||||
return event_pydantic_model(**parsed_chunk)
|
||||
if (
|
||||
isinstance(raw_response_json, dict)
|
||||
and raw_response_json.get("status") == "failed"
|
||||
):
|
||||
error = raw_response_json.get("error", {})
|
||||
raise BaseLLMException(
|
||||
status_code=raw_response.status_code,
|
||||
message=error.get("message", "Unknown Perplexity error"),
|
||||
)
|
||||
|
||||
def _transform_perplexity_chunk(self, chunk: dict) -> dict:
|
||||
"""
|
||||
Transform Perplexity-specific fields in a streaming chunk to OpenAI format.
|
||||
|
||||
This handles:
|
||||
- Converting Perplexity's cost object to a simple float
|
||||
"""
|
||||
# Make a copy to avoid modifying the original
|
||||
chunk = dict(chunk)
|
||||
|
||||
# Transform usage.cost from Perplexity format to OpenAI format
|
||||
# Perplexity: {"currency": "USD", "input_cost": 0.0001, "output_cost": 0.0002, "total_cost": 0.0003}
|
||||
# OpenAI: 0.0003 (just the total_cost as a float)
|
||||
try:
|
||||
response_obj = chunk.get("response")
|
||||
if isinstance(response_obj, dict):
|
||||
usage_obj = response_obj.get("usage")
|
||||
if isinstance(usage_obj, dict):
|
||||
cost_obj = usage_obj.get("cost")
|
||||
if isinstance(cost_obj, dict) and "total_cost" in cost_obj:
|
||||
# Replace the cost object with just the total_cost value
|
||||
chunk = dict(chunk)
|
||||
chunk["response"] = dict(response_obj)
|
||||
chunk["response"]["usage"] = dict(usage_obj)
|
||||
chunk["response"]["usage"]["cost"] = cost_obj["total_cost"]
|
||||
verbose_logger.debug(
|
||||
"Transformed Perplexity cost object to float: %s -> %s",
|
||||
cost_obj,
|
||||
cost_obj["total_cost"],
|
||||
)
|
||||
except Exception as e:
|
||||
# If transformation fails, log and continue with original chunk
|
||||
verbose_logger.debug("Failed to transform Perplexity cost object: %s", e)
|
||||
|
||||
return chunk
|
||||
return super().transform_response_api_response(
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
def supports_native_websocket(self) -> bool:
|
||||
"""Perplexity does not support native WebSocket for Responses API"""
|
||||
|
||||
@@ -583,17 +583,35 @@ class SagemakerLLM(BaseAWSLLM):
|
||||
### BOTO3 INIT
|
||||
import boto3
|
||||
|
||||
# Use _load_credentials to support role assumption (aws_role_name, aws_session_name)
|
||||
credentials, aws_region_name = self._load_credentials(optional_params)
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
|
||||
# Create boto3 session with the loaded credentials
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=credentials.access_key,
|
||||
aws_secret_access_key=credentials.secret_key,
|
||||
aws_session_token=credentials.token,
|
||||
region_name=aws_region_name,
|
||||
)
|
||||
client = session.client(service_name="sagemaker-runtime")
|
||||
if aws_access_key_id is not None:
|
||||
# uses auth params passed to completion
|
||||
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
|
||||
client = boto3.client(
|
||||
service_name="sagemaker-runtime",
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
region_name=aws_region_name,
|
||||
)
|
||||
else:
|
||||
# aws_access_key_id is None, assume user is trying to auth using env variables
|
||||
# boto3 automaticaly reads env variables
|
||||
|
||||
# we need to read region name from env
|
||||
# I assume majority of users use .env for auth
|
||||
region_name = (
|
||||
get_secret("AWS_REGION_NAME")
|
||||
or aws_region_name # get region from config file if specified
|
||||
or "us-west-2" # default to us-west-2 if region not specified
|
||||
)
|
||||
client = boto3.client(
|
||||
service_name="sagemaker-runtime",
|
||||
region_name=region_name,
|
||||
)
|
||||
|
||||
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
|
||||
inference_params = deepcopy(optional_params)
|
||||
@@ -610,9 +628,7 @@ class SagemakerLLM(BaseAWSLLM):
|
||||
#### EMBEDDING LOGIC
|
||||
# Transform request based on model type
|
||||
provider_config = SagemakerEmbeddingConfig.get_model_config(model)
|
||||
request_data = provider_config.transform_embedding_request(
|
||||
model, input, optional_params, {}
|
||||
)
|
||||
request_data = provider_config.transform_embedding_request(model, input, optional_params, {})
|
||||
data = json.dumps(request_data).encode("utf-8")
|
||||
|
||||
## LOGGING
|
||||
@@ -657,19 +673,19 @@ class SagemakerLLM(BaseAWSLLM):
|
||||
)
|
||||
|
||||
print_verbose(f"raw model_response: {response}")
|
||||
|
||||
|
||||
# Transform response based on model type
|
||||
from httpx import Response as HttpxResponse
|
||||
|
||||
|
||||
# Create a mock httpx Response object for the transformation
|
||||
mock_response = HttpxResponse(
|
||||
status_code=200,
|
||||
content=json.dumps(response).encode("utf-8"),
|
||||
headers={"content-type": "application/json"},
|
||||
content=json.dumps(response).encode('utf-8'),
|
||||
headers={"content-type": "application/json"}
|
||||
)
|
||||
|
||||
|
||||
model_response = EmbeddingResponse()
|
||||
|
||||
|
||||
# Use the request_data that was already transformed above
|
||||
return provider_config.transform_embedding_response(
|
||||
model=model,
|
||||
@@ -679,5 +695,5 @@ class SagemakerLLM(BaseAWSLLM):
|
||||
api_key=None,
|
||||
request_data=request_data,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params or {},
|
||||
litellm_params=litellm_params or {}
|
||||
)
|
||||
|
||||
@@ -208,27 +208,28 @@ class SnowflakeConfig(SnowflakeBaseConfig, OpenAIGPTConfig):
|
||||
|
||||
def _transform_tool_choice(
|
||||
self, tool_choice: Union[str, Dict[str, Any]]
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Transform OpenAI tool_choice format to Snowflake format.
|
||||
|
||||
Snowflake requires tool_choice to be an object, not a string.
|
||||
Ref: https://docs.snowflake.com/en/developer-guide/snowflake-rest-api/reference/cortex-inference#post--api-v2-cortex-inference-complete-req-body-schema
|
||||
|
||||
Args:
|
||||
tool_choice: Tool choice in OpenAI format (str or dict)
|
||||
|
||||
Returns:
|
||||
Tool choice in Snowflake format
|
||||
Tool choice in Snowflake format (always an object)
|
||||
|
||||
OpenAI format:
|
||||
{"type": "function", "function": {"name": "get_weather"}}
|
||||
OpenAI format (string): "auto", "required", "none"
|
||||
OpenAI format (object): {"type": "function", "function": {"name": "get_weather"}}
|
||||
|
||||
Snowflake format:
|
||||
{"type": "tool", "name": ["get_weather"]}
|
||||
|
||||
Note: String values ("auto", "required", "none") pass through unchanged.
|
||||
Snowflake format (string values become objects): {"type": "auto"}
|
||||
Snowflake format (specific tool): {"type": "tool", "name": ["get_weather"]}
|
||||
"""
|
||||
if isinstance(tool_choice, str):
|
||||
# "auto", "required", "none" pass through as-is
|
||||
return tool_choice
|
||||
# Snowflake requires object format: {"type": "auto"} not string "auto"
|
||||
return {"type": tool_choice}
|
||||
|
||||
if isinstance(tool_choice, dict):
|
||||
if tool_choice.get("type") == "function":
|
||||
|
||||
@@ -249,23 +249,27 @@ def _get_embedding_url(
|
||||
- bge/endpoint_id -> strips to endpoint_id for endpoints/ routing
|
||||
- numeric model -> routes to endpoints/
|
||||
- regular model -> routes to publishers/google/models/
|
||||
- models with uses_embed_content flag -> use embedContent endpoint instead of predict
|
||||
"""
|
||||
endpoint = "predict"
|
||||
|
||||
# Strip routing prefixes (bge/, gemma/, etc.) for endpoint URL construction
|
||||
original_model = model
|
||||
model = get_vertex_base_model_name(model=model)
|
||||
|
||||
# Get base URL (handles global vs regional)
|
||||
try:
|
||||
model_info = litellm.get_model_info(
|
||||
model=original_model,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
uses_embed_content = model_info.get("uses_embed_content", False)
|
||||
except Exception:
|
||||
uses_embed_content = False
|
||||
|
||||
endpoint = "embedContent" if uses_embed_content else "predict"
|
||||
|
||||
base_url = get_vertex_base_url(vertex_location)
|
||||
|
||||
if model.isdigit():
|
||||
# https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/endpoints/$ENDPOINT_ID:predict
|
||||
# https://aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/global/endpoints/$ENDPOINT_ID:predict
|
||||
url = f"{base_url}/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
|
||||
else:
|
||||
# Regular model -> publisher model
|
||||
# https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/publishers/google/models/{model}:predict
|
||||
# https://aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/global/publishers/google/models/{model}:predict
|
||||
url = f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
|
||||
|
||||
return url, endpoint
|
||||
@@ -518,6 +522,29 @@ def _build_vertex_schema(parameters: dict, add_property_ordering: bool = False):
|
||||
return parameters
|
||||
|
||||
|
||||
def _build_vertex_schema_for_gemini_2(parameters: dict) -> dict:
|
||||
"""
|
||||
Minimal schema builder for Gemini 2.0+ tool parameters.
|
||||
|
||||
Gemini 2.0+ accepts standard JSON Schema natively in tool parameters,
|
||||
including lowercase types, anyOf with null, and bare {} (TYPE_UNSPECIFIED).
|
||||
The only transformation needed is resolving $ref/$defs, which Gemini does
|
||||
NOT support in tool parameters (returns 400).
|
||||
|
||||
This avoids the harmful transforms in _build_vertex_schema that break
|
||||
JsonValue/Any semantics by coercing {} to {"type": "object"}.
|
||||
"""
|
||||
valid_schema_fields = set(get_type_hints(Schema).keys())
|
||||
|
||||
parameters = dict(parameters) # shallow copy to avoid mutating caller's dict
|
||||
defs = parameters.pop("$defs", {})
|
||||
unpack_defs(parameters, defs)
|
||||
|
||||
parameters = filter_schema_fields(parameters, valid_schema_fields)
|
||||
|
||||
return parameters
|
||||
|
||||
|
||||
def _build_json_schema(parameters: dict) -> dict:
|
||||
"""
|
||||
Build a JSON Schema for use with Gemini's responseJsonSchema parameter.
|
||||
|
||||
@@ -4,13 +4,16 @@ import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.caching.caching import Cache, LiteLLMCacheType
|
||||
from litellm.constants import MINIMUM_PROMPT_CACHE_TOKEN_COUNT
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.openai.openai import AllMessageValues
|
||||
from litellm.utils import is_prompt_caching_valid_prompt
|
||||
from litellm.types.llms.vertex_ai import (
|
||||
CachedContentListAllResponseBody,
|
||||
VertexAICachedContentResponseObject,
|
||||
@@ -315,6 +318,20 @@ class ContextCachingEndpoints(VertexBase):
|
||||
if len(cached_messages) == 0:
|
||||
return messages, optional_params, None
|
||||
|
||||
# Gemini requires a minimum of 1024 tokens for context caching.
|
||||
# Skip caching if the cached content is too small to avoid API errors.
|
||||
if not is_prompt_caching_valid_prompt(
|
||||
model=model,
|
||||
messages=cached_messages,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
):
|
||||
verbose_logger.debug(
|
||||
"Vertex AI context caching: cached content is below minimum token "
|
||||
"count (%d). Skipping context caching.",
|
||||
MINIMUM_PROMPT_CACHE_TOKEN_COUNT,
|
||||
)
|
||||
return messages, optional_params, None
|
||||
|
||||
tools = optional_params.pop("tools", None)
|
||||
|
||||
## AUTHORIZATION ##
|
||||
@@ -447,6 +464,20 @@ class ContextCachingEndpoints(VertexBase):
|
||||
if len(cached_messages) == 0:
|
||||
return messages, optional_params, None
|
||||
|
||||
# Gemini requires a minimum of 1024 tokens for context caching.
|
||||
# Skip caching if the cached content is too small to avoid API errors.
|
||||
if not is_prompt_caching_valid_prompt(
|
||||
model=model,
|
||||
messages=cached_messages,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
):
|
||||
verbose_logger.debug(
|
||||
"Vertex AI context caching: cached content is below minimum token "
|
||||
"count (%d). Skipping context caching.",
|
||||
MINIMUM_PROMPT_CACHE_TOKEN_COUNT,
|
||||
)
|
||||
return messages, optional_params, None
|
||||
|
||||
tools = optional_params.pop("tools", None)
|
||||
|
||||
## AUTHORIZATION ##
|
||||
|
||||
@@ -77,6 +77,60 @@ def _convert_detail_to_media_resolution_enum(
|
||||
return None
|
||||
|
||||
|
||||
def _get_highest_media_resolution(
|
||||
current: Optional[str], new_detail: Optional[str]
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Compare two media resolution values and return the highest one.
|
||||
Resolution hierarchy: ultra_high > high > medium > low > None
|
||||
"""
|
||||
resolution_priority = {"ultra_high": 4, "high": 3, "medium": 2, "low": 1}
|
||||
current_priority = resolution_priority.get(current, 0) if current else 0
|
||||
new_priority = resolution_priority.get(new_detail, 0) if new_detail else 0
|
||||
|
||||
if new_priority > current_priority:
|
||||
return new_detail
|
||||
return current
|
||||
|
||||
|
||||
def _extract_max_media_resolution_from_messages(
|
||||
messages: List[AllMessageValues],
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Extract the highest media resolution (detail) from image content in messages.
|
||||
|
||||
This is used to set the global media_resolution in generation_config for
|
||||
Gemini 2.x models which don't support per-part media resolution.
|
||||
|
||||
Args:
|
||||
messages: List of messages in OpenAI format
|
||||
|
||||
Returns:
|
||||
The highest detail level found ("high", "low", or None)
|
||||
"""
|
||||
max_resolution: Optional[str] = None
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
detail: Optional[str] = None
|
||||
if item.get("type") == "image_url":
|
||||
image_url = item.get("image_url")
|
||||
if isinstance(image_url, dict):
|
||||
detail = image_url.get("detail")
|
||||
elif item.get("type") == "file":
|
||||
file_obj = item.get("file")
|
||||
if isinstance(file_obj, dict):
|
||||
detail = file_obj.get("detail")
|
||||
if detail:
|
||||
max_resolution = _get_highest_media_resolution(
|
||||
max_resolution, detail
|
||||
)
|
||||
return max_resolution
|
||||
|
||||
|
||||
def _apply_gemini_3_metadata(
|
||||
part: PartType,
|
||||
model: Optional[str],
|
||||
@@ -539,10 +593,6 @@ def _gemini_convert_messages_with_history( # noqa: PLR0915
|
||||
raise e
|
||||
|
||||
|
||||
# Keys that LiteLLM consumes internally and must never be forwarded to the
|
||||
_LITELLM_INTERNAL_EXTRA_BODY_KEYS: frozenset = frozenset({"cache", "tags"})
|
||||
|
||||
|
||||
def _pop_and_merge_extra_body(data: RequestBody, optional_params: dict) -> None:
|
||||
"""Pop extra_body from optional_params and shallow-merge into data, deep-merging dict values."""
|
||||
extra_body: Optional[dict] = optional_params.pop("extra_body", None)
|
||||
@@ -561,7 +611,7 @@ def _pop_and_merge_extra_body(data: RequestBody, optional_params: dict) -> None:
|
||||
data_dict[k] = v
|
||||
|
||||
|
||||
def _transform_request_body(
|
||||
def _transform_request_body( # noqa: PLR0915
|
||||
messages: List[AllMessageValues],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
@@ -639,6 +689,19 @@ def _transform_request_body(
|
||||
generation_config: Optional[GenerationConfig] = GenerationConfig(
|
||||
**filtered_params
|
||||
)
|
||||
|
||||
# For Gemini 2.x models, add media_resolution to generation_config (global)
|
||||
# Gemini 3+ supports per-part media_resolution, but 2.x only supports global
|
||||
# Gemini 1.x does not support mediaResolution at all
|
||||
if "gemini-2" in model:
|
||||
max_media_resolution = _extract_max_media_resolution_from_messages(messages)
|
||||
if max_media_resolution:
|
||||
media_resolution_value = _convert_detail_to_media_resolution_enum(
|
||||
max_media_resolution
|
||||
)
|
||||
if media_resolution_value and generation_config is not None:
|
||||
generation_config["mediaResolution"] = media_resolution_value["level"]
|
||||
|
||||
data = RequestBody(contents=content)
|
||||
if system_instructions is not None:
|
||||
data["system_instruction"] = system_instructions
|
||||
|
||||
@@ -97,6 +97,7 @@ from ..common_utils import (
|
||||
VertexAIError,
|
||||
_build_json_schema,
|
||||
_build_vertex_schema,
|
||||
_build_vertex_schema_for_gemini_2,
|
||||
supports_response_json_schema,
|
||||
)
|
||||
from ..vertex_llm_base import VertexBase
|
||||
@@ -467,7 +468,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
||||
return None
|
||||
|
||||
def _map_function( # noqa: PLR0915
|
||||
self, value: List[dict], optional_params: dict
|
||||
self, value: List[dict], optional_params: dict, model: str = ""
|
||||
) -> List[Tools]:
|
||||
"""
|
||||
Map OpenAI-style tools/functions to Vertex AI format.
|
||||
@@ -510,10 +511,21 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
||||
"parameters" in _openai_function_object
|
||||
and _openai_function_object["parameters"] is not None
|
||||
and isinstance(_openai_function_object["parameters"], dict)
|
||||
): # OPENAI accepts JSON Schema, Google accepts OpenAPI schema.
|
||||
_openai_function_object["parameters"] = _build_vertex_schema(
|
||||
_openai_function_object["parameters"]
|
||||
)
|
||||
):
|
||||
if supports_response_json_schema(model):
|
||||
# Gemini 2.0+: minimal transform (resolve $ref only)
|
||||
_openai_function_object["parameters"] = (
|
||||
_build_vertex_schema_for_gemini_2(
|
||||
_openai_function_object["parameters"]
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Gemini 1.5: full OpenAPI-style transform
|
||||
_openai_function_object["parameters"] = (
|
||||
_build_vertex_schema(
|
||||
_openai_function_object["parameters"]
|
||||
)
|
||||
)
|
||||
|
||||
openai_function_object = _openai_function_object
|
||||
|
||||
@@ -1048,7 +1060,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
||||
):
|
||||
# Pass optional_params so _map_function can add toolConfig if needed
|
||||
mapped_tools = self._map_function(
|
||||
value=value, optional_params=optional_params
|
||||
value=value, optional_params=optional_params, model=model
|
||||
)
|
||||
optional_params = self._add_tools_to_optional_params(
|
||||
optional_params, mapped_tools
|
||||
@@ -1227,27 +1239,25 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
||||
"IMAGE_PROHIBITED_CONTENT": "The token generation was stopped as the response was flagged for prohibited image content.",
|
||||
}
|
||||
|
||||
_GEMINI_FINISH_REASON_KEYS = frozenset({
|
||||
"STOP", "MAX_TOKENS", "SAFETY", "RECITATION", "FINISH_REASON_UNSPECIFIED",
|
||||
"MALFORMED_FUNCTION_CALL", "LANGUAGE", "OTHER", "BLOCKLIST",
|
||||
"PROHIBITED_CONTENT", "SPII", "IMAGE_SAFETY", "IMAGE_PROHIBITED_CONTENT",
|
||||
"TOO_MANY_TOOL_CALLS", "MALFORMED_RESPONSE",
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def get_finish_reason_mapping() -> Dict[str, OpenAIChatCompletionFinishReason]:
|
||||
"""
|
||||
Return Dictionary of finish reasons which indicate response was flagged
|
||||
|
||||
and what it means
|
||||
Return Dictionary of Gemini/Vertex AI finish reasons and their
|
||||
OpenAI-compatible mappings.
|
||||
"""
|
||||
from litellm.litellm_core_utils.core_helpers import _FINISH_REASON_MAP
|
||||
|
||||
return {
|
||||
"FINISH_REASON_UNSPECIFIED": "finish_reason_unspecified",
|
||||
"STOP": "stop",
|
||||
"MAX_TOKENS": "length",
|
||||
"SAFETY": "content_filter",
|
||||
"RECITATION": "content_filter",
|
||||
"LANGUAGE": "content_filter",
|
||||
"OTHER": "content_filter",
|
||||
"BLOCKLIST": "content_filter",
|
||||
"PROHIBITED_CONTENT": "content_filter",
|
||||
"SPII": "content_filter",
|
||||
"MALFORMED_FUNCTION_CALL": "malformed_function_call", # openai doesn't have a way of representing this
|
||||
"IMAGE_SAFETY": "content_filter",
|
||||
"IMAGE_PROHIBITED_CONTENT": "content_filter",
|
||||
k: v
|
||||
for k, v in _FINISH_REASON_MAP.items()
|
||||
if k in VertexGeminiConfig._GEMINI_FINISH_REASON_KEYS
|
||||
}
|
||||
|
||||
def translate_exception_str(self, exception_string: str):
|
||||
@@ -1766,15 +1776,14 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
||||
chat_completion_message: Optional[ChatCompletionResponseMessage],
|
||||
finish_reason: Optional[str],
|
||||
) -> OpenAIChatCompletionFinishReason:
|
||||
mapped_finish_reason = VertexGeminiConfig.get_finish_reason_mapping()
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
|
||||
if chat_completion_message and chat_completion_message.get("function_call"):
|
||||
return "function_call"
|
||||
elif chat_completion_message and chat_completion_message.get("tool_calls"):
|
||||
return "tool_calls"
|
||||
elif (
|
||||
finish_reason and finish_reason in mapped_finish_reason.keys()
|
||||
): # vertex ai
|
||||
return mapped_finish_reason[finish_reason]
|
||||
elif finish_reason:
|
||||
return map_finish_reason(finish_reason)
|
||||
else:
|
||||
return "stop"
|
||||
|
||||
@@ -2362,7 +2371,8 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
||||
|
||||
|
||||
async def make_call(
|
||||
client: Optional[AsyncHTTPHandler],
|
||||
client: Optional[AsyncHTTPHandler], # module-level client
|
||||
gemini_client: Optional[AsyncHTTPHandler], # if passed by user
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: str,
|
||||
@@ -2370,6 +2380,8 @@ async def make_call(
|
||||
messages: list,
|
||||
logging_obj,
|
||||
):
|
||||
if gemini_client is not None:
|
||||
client = gemini_client
|
||||
if client is None:
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.VERTEX_AI,
|
||||
@@ -2541,7 +2553,11 @@ class VertexLLM(VertexBase):
|
||||
completion_stream=None,
|
||||
make_call=partial(
|
||||
make_call,
|
||||
client=client,
|
||||
gemini_client=(
|
||||
client
|
||||
if client is not None and isinstance(client, AsyncHTTPHandler)
|
||||
else None
|
||||
),
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
data=request_body_str,
|
||||
|
||||
@@ -3,12 +3,11 @@ Google AI Studio /batchEmbedContents Embeddings Endpoint
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
@@ -19,15 +18,98 @@ from litellm.types.llms.vertex_ai import (
|
||||
VertexAIBatchEmbeddingsRequestBody,
|
||||
VertexAIBatchEmbeddingsResponseObject,
|
||||
)
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from ..gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from .batch_embed_content_transformation import (
|
||||
_is_file_reference,
|
||||
_is_multimodal_input,
|
||||
process_embed_content_response,
|
||||
process_response,
|
||||
transform_openai_input_gemini_content,
|
||||
transform_openai_input_gemini_embed_content,
|
||||
)
|
||||
|
||||
|
||||
class GoogleBatchEmbeddings(VertexLLM):
|
||||
def _resolve_file_references(
|
||||
self,
|
||||
input: EmbeddingInput,
|
||||
api_key: str,
|
||||
sync_handler: HTTPHandler,
|
||||
) -> Dict[str, Dict[str, str]]:
|
||||
"""
|
||||
Resolve Gemini file references (files/...) to get mime_type and uri.
|
||||
|
||||
Args:
|
||||
input: EmbeddingInput that may contain file references
|
||||
api_key: Gemini API key
|
||||
sync_handler: HTTP client
|
||||
|
||||
Returns:
|
||||
Dict mapping file name to {mime_type, uri}
|
||||
"""
|
||||
input_list = [input] if isinstance(input, str) else input
|
||||
resolved_files: Dict[str, Dict[str, str]] = {}
|
||||
|
||||
for element in input_list:
|
||||
if isinstance(element, str) and _is_file_reference(element):
|
||||
url = f"https://generativelanguage.googleapis.com/v1beta/{element}"
|
||||
headers = {"x-goog-api-key": api_key}
|
||||
response = sync_handler.get(url=url, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Error fetching file {element}: {response.status_code} {response.text}"
|
||||
)
|
||||
|
||||
file_data = response.json()
|
||||
resolved_files[element] = {
|
||||
"mime_type": file_data.get("mimeType", ""),
|
||||
"uri": file_data.get("uri", element),
|
||||
}
|
||||
|
||||
return resolved_files
|
||||
|
||||
async def _async_resolve_file_references(
|
||||
self,
|
||||
input: EmbeddingInput,
|
||||
api_key: str,
|
||||
async_handler: AsyncHTTPHandler,
|
||||
) -> Dict[str, Dict[str, str]]:
|
||||
"""
|
||||
Async version of _resolve_file_references.
|
||||
|
||||
Args:
|
||||
input: EmbeddingInput that may contain file references
|
||||
api_key: Gemini API key
|
||||
async_handler: Async HTTP client
|
||||
|
||||
Returns:
|
||||
Dict mapping file name to {mime_type, uri}
|
||||
"""
|
||||
input_list = [input] if isinstance(input, str) else input
|
||||
resolved_files: Dict[str, Dict[str, str]] = {}
|
||||
|
||||
for element in input_list:
|
||||
if isinstance(element, str) and _is_file_reference(element):
|
||||
url = f"https://generativelanguage.googleapis.com/v1beta/{element}"
|
||||
headers = {"x-goog-api-key": api_key}
|
||||
response = await async_handler.get(url=url, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Error fetching file {element}: {response.status_code} {response.text}"
|
||||
)
|
||||
|
||||
file_data = response.json()
|
||||
resolved_files[element] = {
|
||||
"mime_type": file_data.get("mimeType", ""),
|
||||
"uri": file_data.get("uri", element),
|
||||
}
|
||||
|
||||
return resolved_files
|
||||
|
||||
def batch_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
@@ -54,20 +136,6 @@ class GoogleBatchEmbeddings(VertexLLM):
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
auth_header, url = self._get_token_and_url(
|
||||
model=model,
|
||||
auth_header=_auth_header,
|
||||
gemini_api_key=api_key,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=vertex_location,
|
||||
vertex_credentials=vertex_credentials,
|
||||
stream=None,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
api_base=api_base,
|
||||
should_use_v1beta1_features=False,
|
||||
mode="batch_embedding",
|
||||
)
|
||||
|
||||
if client is None:
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
@@ -83,9 +151,25 @@ class GoogleBatchEmbeddings(VertexLLM):
|
||||
|
||||
optional_params = optional_params or {}
|
||||
|
||||
### TRANSFORMATION ###
|
||||
request_data = transform_openai_input_gemini_content(
|
||||
input=input, model=model, optional_params=optional_params
|
||||
is_multimodal = _is_multimodal_input(input)
|
||||
use_embed_content = is_multimodal or (custom_llm_provider == "vertex_ai")
|
||||
if use_embed_content:
|
||||
mode = "embedding"
|
||||
else:
|
||||
mode = "batch_embedding"
|
||||
|
||||
auth_header, url = self._get_token_and_url(
|
||||
model=model,
|
||||
auth_header=_auth_header,
|
||||
gemini_api_key=api_key,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=vertex_location,
|
||||
vertex_credentials=vertex_credentials,
|
||||
stream=None,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
api_base=api_base,
|
||||
should_use_v1beta1_features=False,
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
headers = {
|
||||
@@ -93,14 +177,46 @@ class GoogleBatchEmbeddings(VertexLLM):
|
||||
}
|
||||
if auth_header is not None:
|
||||
if isinstance(auth_header, dict):
|
||||
# For Gemini with custom api_base: auth_header is {"x-goog-api-key": "..."}
|
||||
headers.update(auth_header)
|
||||
else:
|
||||
# For Vertex AI: auth_header is a Bearer token string
|
||||
headers["Authorization"] = f"Bearer {auth_header}"
|
||||
if extra_headers is not None:
|
||||
headers.update(extra_headers)
|
||||
|
||||
if aembedding is True:
|
||||
return self.async_batch_embeddings( # type: ignore
|
||||
model=model,
|
||||
api_base=api_base,
|
||||
url=url,
|
||||
data=None,
|
||||
model_response=model_response,
|
||||
timeout=timeout,
|
||||
headers=headers,
|
||||
input=input,
|
||||
use_embed_content=use_embed_content,
|
||||
api_key=api_key,
|
||||
optional_params=optional_params,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
### TRANSFORMATION (sync path) ###
|
||||
if use_embed_content:
|
||||
resolved_files = {}
|
||||
if api_key:
|
||||
resolved_files = self._resolve_file_references(
|
||||
input=input, api_key=api_key, sync_handler=sync_handler
|
||||
)
|
||||
request_data = transform_openai_input_gemini_embed_content(
|
||||
input=input,
|
||||
model=model,
|
||||
optional_params=optional_params,
|
||||
resolved_files=resolved_files,
|
||||
)
|
||||
else:
|
||||
request_data = transform_openai_input_gemini_content(
|
||||
input=input, model=model, optional_params=optional_params
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
@@ -112,18 +228,6 @@ class GoogleBatchEmbeddings(VertexLLM):
|
||||
},
|
||||
)
|
||||
|
||||
if aembedding is True:
|
||||
return self.async_batch_embeddings( # type: ignore
|
||||
model=model,
|
||||
api_base=api_base,
|
||||
url=url,
|
||||
data=request_data,
|
||||
model_response=model_response,
|
||||
timeout=timeout,
|
||||
headers=headers,
|
||||
input=input,
|
||||
)
|
||||
|
||||
response = sync_handler.post(
|
||||
url=url,
|
||||
headers=headers,
|
||||
@@ -134,26 +238,38 @@ class GoogleBatchEmbeddings(VertexLLM):
|
||||
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||
|
||||
_json_response = response.json()
|
||||
_predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response) # type: ignore
|
||||
|
||||
return process_response(
|
||||
model=model,
|
||||
model_response=model_response,
|
||||
_predictions=_predictions,
|
||||
input=input,
|
||||
)
|
||||
|
||||
if use_embed_content:
|
||||
return process_embed_content_response(
|
||||
input=input,
|
||||
model_response=model_response,
|
||||
model=model,
|
||||
response_json=_json_response,
|
||||
)
|
||||
else:
|
||||
_predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response) # type: ignore
|
||||
return process_response(
|
||||
model=model,
|
||||
model_response=model_response,
|
||||
_predictions=_predictions,
|
||||
input=input,
|
||||
)
|
||||
|
||||
async def async_batch_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
api_base: Optional[str],
|
||||
url: str,
|
||||
data: VertexAIBatchEmbeddingsRequestBody,
|
||||
data: Optional[Union[VertexAIBatchEmbeddingsRequestBody, dict]],
|
||||
model_response: EmbeddingResponse,
|
||||
input: EmbeddingInput,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
headers={},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
use_embed_content: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
optional_params: Optional[dict] = None,
|
||||
logging_obj: Optional[Any] = None,
|
||||
) -> EmbeddingResponse:
|
||||
if client is None:
|
||||
_params = {}
|
||||
@@ -171,6 +287,36 @@ class GoogleBatchEmbeddings(VertexLLM):
|
||||
else:
|
||||
async_handler = client # type: ignore
|
||||
|
||||
### TRANSFORMATION (async path) ###
|
||||
if use_embed_content:
|
||||
resolved_files = {}
|
||||
if api_key:
|
||||
resolved_files = await self._async_resolve_file_references(
|
||||
input=input, api_key=api_key, async_handler=async_handler
|
||||
)
|
||||
data = transform_openai_input_gemini_embed_content(
|
||||
input=input,
|
||||
model=model,
|
||||
optional_params=optional_params or {},
|
||||
resolved_files=resolved_files,
|
||||
)
|
||||
else:
|
||||
data = transform_openai_input_gemini_content(
|
||||
input=input, model=model, optional_params=optional_params or {}
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
if logging_obj is not None:
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": url,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
response = await async_handler.post(
|
||||
url=url,
|
||||
headers=headers,
|
||||
@@ -181,11 +327,19 @@ class GoogleBatchEmbeddings(VertexLLM):
|
||||
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||
|
||||
_json_response = response.json()
|
||||
_predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response) # type: ignore
|
||||
|
||||
return process_response(
|
||||
model=model,
|
||||
model_response=model_response,
|
||||
_predictions=_predictions,
|
||||
input=input,
|
||||
)
|
||||
|
||||
if use_embed_content:
|
||||
return process_embed_content_response(
|
||||
input=input,
|
||||
model_response=model_response,
|
||||
model=model,
|
||||
response_json=_json_response,
|
||||
)
|
||||
else:
|
||||
_predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response) # type: ignore
|
||||
return process_response(
|
||||
model=model,
|
||||
model_response=model_response,
|
||||
_predictions=_predictions,
|
||||
input=input,
|
||||
)
|
||||
|
||||
@@ -4,20 +4,142 @@ Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batc
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
from litellm.types.llms.openai import EmbeddingInput
|
||||
from litellm.types.llms.vertex_ai import (
|
||||
BlobType,
|
||||
ContentType,
|
||||
EmbedContentRequest,
|
||||
FileDataType,
|
||||
PartType,
|
||||
VertexAIBatchEmbeddingsRequestBody,
|
||||
VertexAIBatchEmbeddingsResponseObject,
|
||||
)
|
||||
from litellm.types.utils import Embedding, Usage
|
||||
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
|
||||
from litellm.utils import get_formatted_prompt, token_counter
|
||||
|
||||
SUPPORTED_EMBEDDING_MIME_TYPES = {
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"audio/mpeg",
|
||||
"audio/wav",
|
||||
"video/mp4",
|
||||
"video/quicktime",
|
||||
"application/pdf",
|
||||
}
|
||||
|
||||
|
||||
def _is_file_reference(s: str) -> bool:
|
||||
"""Check if string is a Gemini file reference (files/...)."""
|
||||
return isinstance(s, str) and s.startswith("files/")
|
||||
|
||||
|
||||
def _is_gcs_url(s: str) -> bool:
|
||||
"""Check if string is a GCS URL (gs://...)."""
|
||||
return isinstance(s, str) and s.startswith("gs://")
|
||||
|
||||
|
||||
def _infer_mime_type_from_gcs_url(gcs_url: str) -> str:
|
||||
"""
|
||||
Infer MIME type from GCS URL file extension.
|
||||
|
||||
Args:
|
||||
gcs_url: GCS URL like gs://bucket/path/to/file.png
|
||||
|
||||
Returns:
|
||||
str: Inferred MIME type
|
||||
|
||||
Raises:
|
||||
ValueError: If file extension is not supported
|
||||
"""
|
||||
extension_to_mime = {
|
||||
".png": "image/png",
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".mp3": "audio/mpeg",
|
||||
".wav": "audio/wav",
|
||||
".mp4": "video/mp4",
|
||||
".mov": "video/quicktime",
|
||||
".pdf": "application/pdf",
|
||||
}
|
||||
|
||||
gcs_url_lower = gcs_url.lower()
|
||||
for ext, mime_type in extension_to_mime.items():
|
||||
if gcs_url_lower.endswith(ext):
|
||||
return mime_type
|
||||
|
||||
raise ValueError(
|
||||
f"Unable to infer MIME type from GCS URL: {gcs_url}. "
|
||||
f"Supported extensions: {', '.join(extension_to_mime.keys())}"
|
||||
)
|
||||
|
||||
|
||||
def _parse_data_url(data_url: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Parse a data URL to extract the media type and base64 data.
|
||||
|
||||
Args:
|
||||
data_url: Data URL in format: data:image/jpeg;base64,/9j/4AAQ...
|
||||
|
||||
Returns:
|
||||
tuple: (media_type, base64_data)
|
||||
media_type: e.g., "image/jpeg", "video/mp4", "audio/mpeg"
|
||||
base64_data: The base64-encoded data without the prefix
|
||||
|
||||
Raises:
|
||||
ValueError: If data URL format is invalid or MIME type is unsupported
|
||||
"""
|
||||
if not data_url.startswith("data:"):
|
||||
raise ValueError(f"Invalid data URL format: {data_url[:50]}...")
|
||||
|
||||
if "," not in data_url:
|
||||
raise ValueError(f"Invalid data URL format (missing comma): {data_url[:50]}...")
|
||||
|
||||
metadata, base64_data = data_url.split(",", 1)
|
||||
|
||||
metadata = metadata[5:]
|
||||
|
||||
if ";" in metadata:
|
||||
media_type = metadata.split(";")[0]
|
||||
else:
|
||||
media_type = metadata
|
||||
|
||||
if media_type not in SUPPORTED_EMBEDDING_MIME_TYPES:
|
||||
raise ValueError(
|
||||
f"Unsupported MIME type for embedding: {media_type}. "
|
||||
f"Supported types: {', '.join(sorted(SUPPORTED_EMBEDDING_MIME_TYPES))}"
|
||||
)
|
||||
|
||||
return media_type, base64_data
|
||||
|
||||
|
||||
def _is_multimodal_input(input: EmbeddingInput) -> bool:
|
||||
"""
|
||||
Check if the input contains multimodal data (data URIs, file references, or GCS URLs).
|
||||
|
||||
Args:
|
||||
input: EmbeddingInput (str or List[str])
|
||||
|
||||
Returns:
|
||||
bool: True if any element is a data URI, file reference, or GCS URL
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
input_list = [input]
|
||||
else:
|
||||
input_list = input
|
||||
|
||||
for element in input_list:
|
||||
if isinstance(element, str):
|
||||
if element.startswith("data:") and ";base64," in element:
|
||||
return True
|
||||
if _is_file_reference(element):
|
||||
return True
|
||||
if _is_gcs_url(element):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def transform_openai_input_gemini_content(
|
||||
input: EmbeddingInput, model: str, optional_params: dict
|
||||
@@ -26,12 +148,17 @@ def transform_openai_input_gemini_content(
|
||||
The content to embed. Only the parts.text fields will be counted.
|
||||
"""
|
||||
gemini_model_name = "models/{}".format(model)
|
||||
|
||||
gemini_params = optional_params.copy()
|
||||
if "dimensions" in gemini_params:
|
||||
gemini_params["outputDimensionality"] = gemini_params.pop("dimensions")
|
||||
|
||||
requests: List[EmbedContentRequest] = []
|
||||
if isinstance(input, str):
|
||||
request = EmbedContentRequest(
|
||||
model=gemini_model_name,
|
||||
content=ContentType(parts=[PartType(text=input)]),
|
||||
**optional_params
|
||||
**gemini_params
|
||||
)
|
||||
requests.append(request)
|
||||
else:
|
||||
@@ -39,13 +166,119 @@ def transform_openai_input_gemini_content(
|
||||
request = EmbedContentRequest(
|
||||
model=gemini_model_name,
|
||||
content=ContentType(parts=[PartType(text=i)]),
|
||||
**optional_params
|
||||
**gemini_params
|
||||
)
|
||||
requests.append(request)
|
||||
|
||||
return VertexAIBatchEmbeddingsRequestBody(requests=requests)
|
||||
|
||||
|
||||
def transform_openai_input_gemini_embed_content(
|
||||
input: EmbeddingInput,
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
resolved_files: Optional[Dict[str, Dict[str, str]]] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform OpenAI embedding input to Gemini embedContent format (multimodal).
|
||||
|
||||
Args:
|
||||
input: EmbeddingInput (str or List[str]) with text, data URIs, or file references
|
||||
model: Model name
|
||||
optional_params: Additional parameters (taskType, outputDimensionality, etc.)
|
||||
resolved_files: Dict mapping file names (files/abc) to {mime_type, uri}
|
||||
|
||||
Returns:
|
||||
dict: Gemini embedContent request body with content.parts
|
||||
"""
|
||||
resolved_files = resolved_files or {}
|
||||
|
||||
gemini_params = optional_params.copy()
|
||||
if "dimensions" in gemini_params:
|
||||
gemini_params["outputDimensionality"] = gemini_params.pop("dimensions")
|
||||
|
||||
input_list = [input] if isinstance(input, str) else input
|
||||
parts: List[PartType] = []
|
||||
|
||||
for element in input_list:
|
||||
if not isinstance(element, str):
|
||||
raise ValueError(f"Unsupported input type: {type(element)}")
|
||||
|
||||
if element.startswith("data:") and ";base64," in element:
|
||||
mime_type, base64_data = _parse_data_url(element)
|
||||
blob: BlobType = {"mime_type": mime_type, "data": base64_data}
|
||||
parts.append(PartType(inline_data=blob))
|
||||
elif _is_gcs_url(element):
|
||||
mime_type = _infer_mime_type_from_gcs_url(element)
|
||||
file_data: FileDataType = {
|
||||
"mime_type": mime_type,
|
||||
"file_uri": element,
|
||||
}
|
||||
parts.append(PartType(file_data=file_data))
|
||||
elif _is_file_reference(element):
|
||||
if element not in resolved_files:
|
||||
raise ValueError(f"File reference {element} not resolved")
|
||||
file_info = resolved_files[element]
|
||||
file_data_ref: FileDataType = {
|
||||
"mime_type": file_info["mime_type"],
|
||||
"file_uri": file_info["uri"],
|
||||
}
|
||||
parts.append(PartType(file_data=file_data_ref))
|
||||
else:
|
||||
parts.append(PartType(text=element))
|
||||
|
||||
request_body: dict = {
|
||||
"content": ContentType(parts=parts),
|
||||
**gemini_params,
|
||||
}
|
||||
|
||||
return request_body
|
||||
|
||||
|
||||
def process_embed_content_response(
|
||||
input: EmbeddingInput,
|
||||
model_response: EmbeddingResponse,
|
||||
model: str,
|
||||
response_json: dict,
|
||||
) -> EmbeddingResponse:
|
||||
"""
|
||||
Process Gemini embedContent response (single embedding for multimodal input).
|
||||
|
||||
Args:
|
||||
input: Original input
|
||||
model_response: EmbeddingResponse to populate
|
||||
model: Model name
|
||||
response_json: Raw JSON response from embedContent endpoint
|
||||
|
||||
Returns:
|
||||
EmbeddingResponse with single embedding
|
||||
"""
|
||||
if "embedding" not in response_json:
|
||||
raise ValueError(f"embedContent response missing 'embedding' field: {response_json}")
|
||||
|
||||
embedding_data = response_json["embedding"]
|
||||
|
||||
openai_embedding = Embedding(
|
||||
embedding=embedding_data["values"],
|
||||
index=0,
|
||||
object="embedding",
|
||||
)
|
||||
|
||||
model_response.data = [openai_embedding]
|
||||
model_response.model = model
|
||||
|
||||
if _is_multimodal_input(input):
|
||||
prompt_tokens = 0
|
||||
else:
|
||||
input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
|
||||
prompt_tokens = token_counter(model=model, text=input_text)
|
||||
model_response.usage = Usage(
|
||||
prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
|
||||
)
|
||||
|
||||
return model_response
|
||||
|
||||
|
||||
def process_response(
|
||||
input: EmbeddingInput,
|
||||
model_response: EmbeddingResponse,
|
||||
|
||||
+132
-2
@@ -132,6 +132,7 @@ from litellm.utils import (
|
||||
create_tokenizer,
|
||||
get_api_key,
|
||||
get_llm_provider,
|
||||
get_model_info,
|
||||
get_non_default_completion_params,
|
||||
get_non_default_transcription_params,
|
||||
get_optional_params_embeddings,
|
||||
@@ -5194,13 +5195,37 @@ def embedding( # noqa: PLR0915
|
||||
or get_secret_str("VERTEX_API_BASE")
|
||||
)
|
||||
|
||||
if (
|
||||
try:
|
||||
model_info = get_model_info(model=model, custom_llm_provider="vertex_ai")
|
||||
uses_embed_content = model_info.get("uses_embed_content", False)
|
||||
except Exception:
|
||||
uses_embed_content = False
|
||||
|
||||
if uses_embed_content:
|
||||
response = google_batch_embeddings.batch_embeddings( # type: ignore
|
||||
model=model,
|
||||
input=input,
|
||||
encoding=_get_encoding(),
|
||||
logging_obj=logging,
|
||||
optional_params=optional_params,
|
||||
model_response=EmbeddingResponse(),
|
||||
vertex_project=vertex_ai_project,
|
||||
vertex_location=vertex_ai_location,
|
||||
vertex_credentials=vertex_credentials,
|
||||
aembedding=aembedding,
|
||||
print_verbose=print_verbose,
|
||||
custom_llm_provider="vertex_ai",
|
||||
api_key=None,
|
||||
api_base=api_base,
|
||||
client=client,
|
||||
extra_headers=headers,
|
||||
)
|
||||
elif (
|
||||
"image" in optional_params
|
||||
or "video" in optional_params
|
||||
or model
|
||||
in vertex_multimodal_embedding.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
|
||||
):
|
||||
# multimodal embedding is supported on vertex httpx
|
||||
response = vertex_multimodal_embedding.multimodal_embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
@@ -7575,6 +7600,111 @@ def stream_chunk_builder( # noqa: PLR0915
|
||||
)
|
||||
|
||||
|
||||
########## Token Counting API ##########
|
||||
|
||||
|
||||
async def acount_tokens(
|
||||
model: str,
|
||||
messages: Optional[List[Dict[str, Any]]] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
system: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> "TokenCountResponse":
|
||||
"""
|
||||
Count tokens for a given model and messages using provider-specific APIs.
|
||||
|
||||
Routes to the appropriate provider's token counting API (OpenAI, Anthropic, etc.)
|
||||
for exact token counts. Falls back to local tiktoken-based counting for unsupported providers.
|
||||
|
||||
Args:
|
||||
model: The model identifier (e.g., "openai/gpt-4o", "anthropic/claude-3-5-sonnet-20241022")
|
||||
messages: The messages to count tokens for (standard chat format)
|
||||
tools: Optional tools/functions to include in token count
|
||||
system: Optional system message/instructions
|
||||
api_key: Optional API key (falls back to environment variable)
|
||||
api_base: Optional custom API base URL
|
||||
|
||||
Returns:
|
||||
TokenCountResponse with total_tokens and metadata
|
||||
"""
|
||||
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
|
||||
from litellm.types.utils import LlmProviders, TokenCountResponse
|
||||
from litellm.utils import ProviderConfigManager
|
||||
|
||||
# Determine provider from model string
|
||||
resolved_model, custom_llm_provider, dynamic_api_key, dynamic_api_base = (
|
||||
get_llm_provider(
|
||||
model=model,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
)
|
||||
)
|
||||
|
||||
# Use dynamic key/base if not explicitly provided
|
||||
if api_key is None:
|
||||
api_key = dynamic_api_key
|
||||
if api_base is None:
|
||||
api_base = dynamic_api_base
|
||||
|
||||
# Build deployment dict for the token counter
|
||||
deployment: Dict[str, Any] = {
|
||||
"litellm_params": {
|
||||
"model": model,
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
}
|
||||
}
|
||||
|
||||
# Try to get provider-specific token counter
|
||||
try:
|
||||
llm_provider_enum = LlmProviders(custom_llm_provider)
|
||||
provider_model_info = ProviderConfigManager.get_provider_model_info(
|
||||
model=model, provider=llm_provider_enum
|
||||
)
|
||||
|
||||
if provider_model_info is not None:
|
||||
token_counter_instance = provider_model_info.get_token_counter()
|
||||
if (
|
||||
token_counter_instance is not None
|
||||
and token_counter_instance.should_use_token_counting_api(
|
||||
custom_llm_provider
|
||||
)
|
||||
):
|
||||
result = await token_counter_instance.count_tokens(
|
||||
model_to_use=resolved_model,
|
||||
messages=messages,
|
||||
contents=None,
|
||||
deployment=deployment,
|
||||
request_model=model,
|
||||
tools=tools,
|
||||
system=system,
|
||||
)
|
||||
if result is not None and not result.error:
|
||||
return result
|
||||
except Exception as e:
|
||||
verbose_logger.debug(
|
||||
f"Provider token counting failed for model={model}, falling back to local: {e}"
|
||||
)
|
||||
|
||||
# Fallback to local tiktoken-based token counting
|
||||
fallback_messages = messages or []
|
||||
if system and fallback_messages:
|
||||
fallback_messages = [{"role": "system", "content": system}] + fallback_messages
|
||||
local_count = litellm.token_counter(
|
||||
model=model,
|
||||
messages=fallback_messages,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
return TokenCountResponse(
|
||||
total_tokens=local_count,
|
||||
request_model=model,
|
||||
model_used=resolved_model,
|
||||
tokenizer_type="local_tokenizer",
|
||||
)
|
||||
|
||||
|
||||
# Cache for encoding to avoid repeated __getattr__ calls
|
||||
_encoding_cache: Optional[Any] = None
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -104,6 +104,50 @@ def encrypt_credentials(
|
||||
value=client_secret,
|
||||
new_encryption_key=encryption_key,
|
||||
)
|
||||
# AWS SigV4 credential fields
|
||||
aws_access_key_id = credentials.get("aws_access_key_id")
|
||||
if aws_access_key_id is not None:
|
||||
credentials["aws_access_key_id"] = encrypt_value_helper(
|
||||
value=aws_access_key_id,
|
||||
new_encryption_key=encryption_key,
|
||||
)
|
||||
aws_secret_access_key = credentials.get("aws_secret_access_key")
|
||||
if aws_secret_access_key is not None:
|
||||
credentials["aws_secret_access_key"] = encrypt_value_helper(
|
||||
value=aws_secret_access_key,
|
||||
new_encryption_key=encryption_key,
|
||||
)
|
||||
aws_session_token = credentials.get("aws_session_token")
|
||||
if aws_session_token is not None:
|
||||
credentials["aws_session_token"] = encrypt_value_helper(
|
||||
value=aws_session_token,
|
||||
new_encryption_key=encryption_key,
|
||||
)
|
||||
# aws_region_name and aws_service_name are NOT secrets — stored as-is
|
||||
return credentials
|
||||
|
||||
|
||||
def decrypt_credentials(
|
||||
credentials: MCPCredentials,
|
||||
) -> MCPCredentials:
|
||||
"""Decrypt all secret fields in an MCPCredentials dict using the global salt key."""
|
||||
secret_fields = [
|
||||
"auth_value",
|
||||
"client_id",
|
||||
"client_secret",
|
||||
"aws_access_key_id",
|
||||
"aws_secret_access_key",
|
||||
"aws_session_token",
|
||||
]
|
||||
for field in secret_fields:
|
||||
value = credentials.get(field)
|
||||
if value is not None:
|
||||
credentials[field] = decrypt_value_helper(
|
||||
value=value,
|
||||
key=field,
|
||||
exception_type="debug",
|
||||
return_original_value=True,
|
||||
)
|
||||
return credentials
|
||||
|
||||
|
||||
@@ -354,9 +398,57 @@ async def update_mcp_server(
|
||||
"""
|
||||
Update a new mcp server record in the db
|
||||
"""
|
||||
import json
|
||||
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
|
||||
# Use helper to prepare data with proper JSON serialization
|
||||
data_dict = _prepare_mcp_server_data(data)
|
||||
|
||||
# Pre-fetch existing record once if we need it for auth_type or credential logic
|
||||
existing = None
|
||||
has_credentials = "credentials" in data_dict and data_dict["credentials"] is not None
|
||||
if data.auth_type or has_credentials:
|
||||
existing = await prisma_client.db.litellm_mcpservertable.find_unique(
|
||||
where={"server_id": data.server_id}
|
||||
)
|
||||
|
||||
# Clear stale credentials when auth_type changes but no new credentials provided
|
||||
if (
|
||||
data.auth_type
|
||||
and "credentials" not in data_dict
|
||||
and existing
|
||||
and existing.auth_type is not None
|
||||
and existing.auth_type != data.auth_type
|
||||
):
|
||||
data_dict["credentials"] = None
|
||||
|
||||
# Merge credentials: preserve existing fields not present in the update.
|
||||
# Without this, a partial credential update (e.g. changing only region)
|
||||
# would wipe encrypted secrets that the UI cannot display back.
|
||||
if "credentials" in data_dict and data_dict["credentials"] is not None:
|
||||
if existing and existing.credentials:
|
||||
# Only merge when auth_type is unchanged. Switching auth types
|
||||
# (e.g. oauth2 → api_key) should replace credentials entirely
|
||||
# to avoid stale secrets from the previous auth type lingering.
|
||||
auth_type_unchanged = (
|
||||
data.auth_type is None or data.auth_type == existing.auth_type
|
||||
)
|
||||
if auth_type_unchanged:
|
||||
existing_creds = (
|
||||
json.loads(existing.credentials)
|
||||
if isinstance(existing.credentials, str)
|
||||
else dict(existing.credentials)
|
||||
)
|
||||
new_creds = (
|
||||
json.loads(data_dict["credentials"])
|
||||
if isinstance(data_dict["credentials"], str)
|
||||
else dict(data_dict["credentials"])
|
||||
)
|
||||
# New values override existing; existing keys not in update are preserved
|
||||
merged = {**existing_creds, **new_creds}
|
||||
data_dict["credentials"] = safe_dumps(merged)
|
||||
|
||||
# Add audit fields
|
||||
data_dict["updated_by"] = touched_by
|
||||
|
||||
@@ -378,8 +470,12 @@ async def rotate_mcp_server_credentials_master_key(
|
||||
continue
|
||||
|
||||
credentials_copy = dict(credentials)
|
||||
encrypted_credentials = encrypt_credentials(
|
||||
# Decrypt with current key first, then re-encrypt with new key
|
||||
decrypted_credentials = decrypt_credentials(
|
||||
credentials=cast(MCPCredentials, credentials_copy),
|
||||
)
|
||||
encrypted_credentials = encrypt_credentials(
|
||||
credentials=decrypted_credentials,
|
||||
encryption_key=new_master_key,
|
||||
)
|
||||
|
||||
|
||||
@@ -597,9 +597,10 @@ class MCPServerManager:
|
||||
else:
|
||||
client_secret_value = encrypted_client_secret
|
||||
|
||||
# TODO: Add AWS SigV4 credential decryption here when DB-stored
|
||||
# SigV4 MCP servers are supported. Requires corresponding changes
|
||||
# to encrypt_credentials() in db.py and MCPCredentials TypedDict.
|
||||
# AWS SigV4 credential fields
|
||||
aws_creds = self._extract_aws_credentials(
|
||||
credentials_dict, credentials_are_encrypted
|
||||
)
|
||||
|
||||
scopes: Optional[List[str]] = None
|
||||
if credentials_dict:
|
||||
@@ -679,6 +680,12 @@ class MCPServerManager:
|
||||
is_byok=bool(getattr(mcp_server, "is_byok", False)),
|
||||
byok_description=getattr(mcp_server, "byok_description", None) or [],
|
||||
byok_api_key_help_url=getattr(mcp_server, "byok_api_key_help_url", None),
|
||||
# AWS SigV4 fields
|
||||
aws_access_key_id=aws_creds.get("aws_access_key_id"),
|
||||
aws_secret_access_key=aws_creds.get("aws_secret_access_key"),
|
||||
aws_session_token=aws_creds.get("aws_session_token"),
|
||||
aws_region_name=aws_creds.get("aws_region_name"),
|
||||
aws_service_name=aws_creds.get("aws_service_name"),
|
||||
)
|
||||
return new_server
|
||||
|
||||
@@ -1520,6 +1527,52 @@ class MCPServerManager:
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _decrypt_credential_field(
|
||||
encrypted_value: Optional[str],
|
||||
key: str,
|
||||
credentials_are_encrypted: bool,
|
||||
) -> Optional[str]:
|
||||
"""Decrypt a single credential field, or return as-is if not encrypted."""
|
||||
if not encrypted_value:
|
||||
return None
|
||||
if credentials_are_encrypted:
|
||||
return decrypt_value_helper(
|
||||
value=encrypted_value,
|
||||
key=key,
|
||||
exception_type="debug",
|
||||
return_original_value=True,
|
||||
)
|
||||
return encrypted_value
|
||||
|
||||
def _extract_aws_credentials(
|
||||
self,
|
||||
credentials_dict: Optional[Dict[str, str]],
|
||||
credentials_are_encrypted: bool,
|
||||
) -> Dict[str, Optional[str]]:
|
||||
"""Extract and decrypt AWS SigV4 credential fields from credentials dict."""
|
||||
if not credentials_dict:
|
||||
return {}
|
||||
return {
|
||||
"aws_access_key_id": self._decrypt_credential_field(
|
||||
credentials_dict.get("aws_access_key_id"),
|
||||
"aws_access_key_id",
|
||||
credentials_are_encrypted,
|
||||
),
|
||||
"aws_secret_access_key": self._decrypt_credential_field(
|
||||
credentials_dict.get("aws_secret_access_key"),
|
||||
"aws_secret_access_key",
|
||||
credentials_are_encrypted,
|
||||
),
|
||||
"aws_session_token": self._decrypt_credential_field(
|
||||
credentials_dict.get("aws_session_token"),
|
||||
"aws_session_token",
|
||||
credentials_are_encrypted,
|
||||
),
|
||||
"aws_region_name": credentials_dict.get("aws_region_name"),
|
||||
"aws_service_name": credentials_dict.get("aws_service_name"),
|
||||
}
|
||||
|
||||
def _extract_scopes(self, scopes_value: Any) -> Optional[List[str]]:
|
||||
if isinstance(scopes_value, str):
|
||||
scopes = [s.strip() for s in scopes_value.split() if s.strip()]
|
||||
|
||||
@@ -93,25 +93,7 @@ def get_base_url(spec: Dict[str, Any], spec_path: Optional[str] = None) -> str:
|
||||
"""Extract base URL from OpenAPI spec."""
|
||||
# OpenAPI 3.x
|
||||
if "servers" in spec and spec["servers"]:
|
||||
server_url = spec["servers"][0]["url"]
|
||||
|
||||
# If the server URL is relative (starts with /), derive base from spec_path
|
||||
if server_url.startswith("/") and spec_path:
|
||||
if spec_path.startswith("http://") or spec_path.startswith("https://"):
|
||||
# Extract base URL from spec_path (e.g., https://petstore3.swagger.io/api/v3/openapi.json)
|
||||
# Combine domain with the relative server URL
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(spec_path)
|
||||
base_domain = f"{parsed.scheme}://{parsed.netloc}"
|
||||
full_base_url = base_domain + server_url
|
||||
verbose_logger.info(
|
||||
f"OpenAPI spec has relative server URL '{server_url}'. "
|
||||
f"Deriving base from spec_path: {full_base_url}"
|
||||
)
|
||||
return full_base_url
|
||||
|
||||
return server_url
|
||||
return spec["servers"][0]["url"]
|
||||
# OpenAPI 2.x (Swagger)
|
||||
elif "host" in spec:
|
||||
scheme = spec.get("schemes", ["https"])[0]
|
||||
|
||||
@@ -718,7 +718,6 @@ if MCP_AVAILABLE:
|
||||
|
||||
Checks both the full tool name and unprefixed version (without server prefix).
|
||||
This allows users to configure simple tool names regardless of prefixing.
|
||||
Comparison is case-insensitive to handle OpenAPI operationIds that may be in camelCase.
|
||||
|
||||
Args:
|
||||
tool_name: The tool name to check (may be prefixed like "server-tool_name")
|
||||
@@ -731,15 +730,13 @@ if MCP_AVAILABLE:
|
||||
split_server_prefix_from_name,
|
||||
)
|
||||
|
||||
# Normalize filter list to lowercase for case-insensitive comparison
|
||||
filter_list_lower = [f.lower() for f in filter_list]
|
||||
|
||||
if tool_name.lower() in filter_list_lower:
|
||||
# Check if the full name is in the list
|
||||
if tool_name in filter_list:
|
||||
return True
|
||||
|
||||
# Check if the unprefixed name is in the list (case-insensitive)
|
||||
# Check if the unprefixed name is in the list
|
||||
unprefixed_name, _ = split_server_prefix_from_name(tool_name)
|
||||
return unprefixed_name.lower() in filter_list_lower
|
||||
return unprefixed_name in filter_list
|
||||
|
||||
def filter_tools_by_allowed_tools(
|
||||
tools: List[MCPTool],
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user