Merge remote-tracking branch 'origin/litellm_internal_staging' into litellm_yj_apr15

# Conflicts:
#	litellm/litellm_core_utils/litellm_logging.py
#	uv.lock
This commit is contained in:
Yuneng Jiang
2026-04-16 09:17:20 -07:00
104 changed files with 6753 additions and 2077 deletions
-152
View File
@@ -2916,90 +2916,6 @@ jobs:
- codecov/upload:
file: ./coverage.xml
publish_proxy_extras:
docker:
- image: cimg/python:3.12
working_directory: ~/project/litellm-proxy-extras
environment:
TWINE_USERNAME: __token__
steps:
- checkout:
path: ~/project
- run:
name: Check if litellm-proxy-extras dir or pyproject.toml was modified
command: |
curl -LsSf -o /tmp/uv-install.sh https://astral.sh/uv/0.10.9/install.sh
echo "7fc46e39cb97290b57169c0c813a17970585ac519139f19006453c99b5f2f45f /tmp/uv-install.sh" | sha256sum -c -
env UV_NO_MODIFY_PATH=1 sh /tmp/uv-install.sh
rm -f /tmp/uv-install.sh
echo 'export PATH="$HOME/.local/bin:$PATH"' >> "$BASH_ENV"
echo 'export PATH="$HOME/.local/bin:$PATH"' >> "$BASH_ENV"
export PATH="$HOME/.local/bin:$PATH"
# Get current version from pyproject.toml
CURRENT_VERSION=$(python -c 'import tomllib; from pathlib import Path; data = tomllib.loads(Path("pyproject.toml").read_text()); print(data["project"]["version"])')
# Get last published version from PyPI
LAST_VERSION=$(curl -s https://pypi.org/pypi/litellm-proxy-extras/json | python -c "import json, sys; print(json.load(sys.stdin)['info']['version'])")
echo "Current version: $CURRENT_VERSION"
echo "Last published version: $LAST_VERSION"
# Compare versions using Python's packaging.version
VERSION_COMPARE=$(uv run --with 'packaging==25.0' python -c "from packaging import version; print(1 if version.parse('$CURRENT_VERSION') < version.parse('$LAST_VERSION') else 0)")
echo "Version compare: $VERSION_COMPARE"
if [ "$VERSION_COMPARE" = "1" ]; then
echo "Error: Current version ($CURRENT_VERSION) is less than last published version ($LAST_VERSION)"
exit 1
fi
# If versions are equal or current is greater, compare against the published package contents.
EXTRACTED_DIR=$(uv run --with "litellm-proxy-extras==$LAST_VERSION" python -c 'import importlib.util; from pathlib import Path; spec = importlib.util.find_spec("litellm_proxy_extras"); assert spec is not None and spec.origin is not None, "litellm_proxy_extras not found in uv-run environment"; print(Path(spec.origin).resolve().parent)')
# Compare contents
if ! diff -r "$EXTRACTED_DIR" ./litellm_proxy_extras; then
if [ "$CURRENT_VERSION" = "$LAST_VERSION" ]; then
echo "Error: Changes detected in litellm-proxy-extras but version was not bumped"
echo "Current version: $CURRENT_VERSION"
echo "Last published version: $LAST_VERSION"
echo "Changes:"
diff -r "$EXTRACTED_DIR" ./litellm_proxy_extras
exit 1
fi
else
echo "No changes detected in litellm-proxy-extras. Skipping PyPI publish."
circleci step halt
fi
- run:
name: Get new version
command: |
NEW_VERSION=$(python -c 'import tomllib; from pathlib import Path; data = tomllib.loads(Path("pyproject.toml").read_text()); print(data["project"]["version"])')
echo "export NEW_VERSION=$NEW_VERSION" >> $BASH_ENV
- run:
name: Check if versions match
command: |
cd ~/project
# Check pyproject.toml
CURRENT_VERSION=$(uv run --with 'packaging==25.0' python -c 'import tomllib; from packaging.requirements import Requirement; from pathlib import Path; data = tomllib.loads(Path("pyproject.toml").read_text()); matches = [spec.version for requirement in data["project"]["optional-dependencies"]["proxy"] for parsed in [Requirement(requirement)] if parsed.name == "litellm-proxy-extras" and parsed.specifier for spec in parsed.specifier if spec.operator == "=="]; print(matches[0] if matches else (_ for _ in ()).throw(SystemExit("Could not find exact litellm-proxy-extras pin in project.optional-dependencies.proxy")))')
if [ "$CURRENT_VERSION" != "$NEW_VERSION" ]; then
echo "Error: Version in pyproject.toml ($CURRENT_VERSION) doesn't match new version ($NEW_VERSION)"
exit 1
fi
- run:
name: Publish to PyPI
command: |
echo -e "[pypi]\nusername = $PYPI_PUBLISH_USERNAME\npassword = $PYPI_PUBLISH_PASSWORD" > ~/.pypirc
echo 'export PATH="$HOME/.local/bin:$PATH"' >> "$BASH_ENV"
export PATH="$HOME/.local/bin:$PATH"
rm -rf build dist
uv build
uv tool run --from 'twine==6.2.0' twine upload --verbose dist/*
ui_build:
docker:
- image: cimg/node:20.19
@@ -3214,60 +3130,6 @@ jobs:
- litellm-docker-database.tar.zst
prisma_schema_sync:
machine:
image: ubuntu-2204:2023.10.1
resource_class: medium
working_directory: ~/project
steps:
- checkout
- setup_google_dns
- attach_workspace:
at: ~/project
- run:
name: Start PostgreSQL Database
command: |
docker run -d \
--name postgres-db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=postgres \
-e POSTGRES_DB=litellm_schema_sync \
-p 5432:5432 \
postgres:14
- wait_for_service:
url: tcp://localhost:5432
timeout: "60"
- run:
name: Load Docker Database Image
command: |
zstd -d litellm-docker-database.tar.zst --stdout | docker load
docker images | grep litellm-docker-database
- run:
name: Run schema sync via prisma db push
command: |
docker run -d \
-p 4000:4000 \
-e DATABASE_URL="postgresql://postgres:postgres@host.docker.internal:5432/litellm_schema_sync" \
-e LITELLM_MASTER_KEY="sk-1234" \
--name schema-sync \
--add-host=host.docker.internal:host-gateway \
-v $(pwd)/litellm/proxy/example_config_yaml/simple_config.yaml:/app/config.yaml \
litellm-docker-database:ci \
--config /app/config.yaml \
--port 4000 \
--use_prisma_db_push
- run:
name: Start outputting logs
command: docker logs -f schema-sync
background: true
- wait_for_service:
url: http://localhost:4000
timeout: "300"
- run:
name: Stop schema sync container
command: docker stop schema-sync
test_bad_database_url:
machine:
image: ubuntu-2204:2023.10.1
@@ -3421,14 +3283,6 @@ workflows:
only:
- main
- /litellm_.*/
- prisma_schema_sync:
requires:
- build_docker_database_image
filters:
branches:
only:
- main
- /litellm_.*/
- e2e_ui_testing:
filters:
branches:
@@ -3688,9 +3542,3 @@ workflows:
only:
- main
- /litellm_.*/
- publish_proxy_extras:
filters:
branches:
only:
- main
- /litellm_release_day_.*/
@@ -487,6 +487,7 @@ router_settings:
| AZURE_STORAGE_CLIENT_ID | The Application Client ID to use for Authentication to Azure Blob Storage logging
| AZURE_STORAGE_CLIENT_SECRET | The Application Client Secret to use for Authentication to Azure Blob Storage logging
| AZURE_VECTOR_STORE_COST_PER_GB_PER_DAY | Cost per GB per day for Azure Vector Store service
| BACKGROUND_HEALTH_CHECK_MAX_TOKENS | Optional global default for `max_tokens` on proxy background health checks when a model has no `health_check_max_tokens`. If unset, non-wildcard models default to 1. Applies to wildcard routes when set. Default is unset
| BATCH_STATUS_POLL_INTERVAL_SECONDS | Interval in seconds for polling batch status. Default is 3600 (1 hour)
| BATCH_STATUS_POLL_MAX_ATTEMPTS | Maximum number of attempts for polling batch status. Default is 24 (for 24 hours)
| BEDROCK_MAX_POLICY_SIZE | Maximum size for Bedrock policy. Default is 75
@@ -804,6 +805,8 @@ router_settings:
| LITELLM_ASSETS_PATH | Path to directory for UI assets and logos. Used when running with read-only filesystem (e.g., Kubernetes). Default is `/var/lib/litellm/assets` in Docker.
| LITELLM_BLOG_POSTS_URL | Custom URL for fetching LiteLLM blog posts JSON. Default is the GitHub main branch URL
| LITELLM_CLI_JWT_EXPIRATION_HOURS | Expiration time in hours for CLI-generated JWT tokens. Default is 24 hours
| LITELLM_CORS_ALLOW_CREDENTIALS | Set to `true` to explicitly allow credentials in CORS responses. When not set, credentials are disabled automatically if `LITELLM_CORS_ORIGINS` is `*` (wildcard) to prevent the browser security misconfiguration of reflecting any origin with credentials
| LITELLM_CORS_ORIGINS | Comma-separated list of allowed CORS origins (e.g. `https://app.example.com,https://admin.example.com`). Defaults to `*` (all origins) when not set
| LITELLM_DD_AGENT_HOST | Hostname or IP of DataDog agent for LiteLLM-specific logging. When set, logs are sent to agent instead of direct API
| LITELLM_DEPLOYMENT_ENVIRONMENT | Environment name for the deployment (e.g., "production", "staging"). Used as a fallback when OTEL_ENVIRONMENT_NAME is not set. Sets the `environment` tag in telemetry data
| LITELLM_DETAILED_TIMING | When true, adds detailed per-phase timing headers to responses (`x-litellm-timing-{pre-processing,llm-api,post-processing,message-copy}-ms`). Default is false. See [latency overhead docs](../troubleshoot/latency_overhead.md)
@@ -925,6 +928,7 @@ router_settings:
| OPENAI_CHATGPT_API_BASE | Alternative to CHATGPT_API_BASE. Base URL for ChatGPT API
| OPENAI_FILE_SEARCH_COST_PER_1K_CALLS | Cost per 1000 calls for OpenAI file search. Default is 0.0025
| OPENAI_ORGANIZATION | Organization identifier for OpenAI
| OPENAPI_URL | The path to the OpenAPI JSON endpoint. **By default this is "/openapi.json"**
| OPENID_BASE_URL | Base URL for OpenID Connect services
| OPENID_CLIENT_ID | Client ID for OpenID Connect authentication
| OPENID_CLIENT_SECRET | Client secret for OpenID Connect authentication
@@ -14,6 +14,10 @@ Provider-specific cost tracking (e.g., [Vertex AI PayGo / priority pricing](../p
[Sync model pricing data from GitHub](./sync_models_github.md) to ensure accurate cost tracking.
:::
:::info Cost does not match your provider bill?
Use the step-by-step workflow in [Debugging a cost discrepancy](../troubleshoot/cost_discrepancy): align time ranges, compare token categories (including cache), then decide whether the gap is ingestion, formula, or model-map pricing.
:::
### How to Track Spend with LiteLLM
**Step 1**
@@ -2,6 +2,16 @@ import Image from '@theme/IdealImage';
# Team Soft Budget Alerts
:::info
✨ This is an Enterprise feature. Email budget alerts require an enterprise license.
[Enterprise Pricing](https://www.litellm.ai/#pricing)
[Get free 7-day trial key](https://www.litellm.ai/enterprise#trial)
:::
Set a soft budget on a team and get email alerts when spending crosses the threshold — without blocking any requests.
## Overview
@@ -0,0 +1,205 @@
# Debugging a cost discrepancy
Cost discrepancies between LiteLLM and your provider bill usually come from one of three areas: token ingestion, the cost formula LiteLLM applies, or stale or incorrect pricing in the model map. This page walks through how to tell which case you are in.
## Step 1: Pick a time range
Lock down a specific window where the discrepancy is visible.
- Use at least 7 days of data when you can.
- Prefer a window with stable usage so one-off spikes do not dominate the comparison.
- Set the **same start and end time** on both your provider dashboard and the LiteLLM UI.
![LiteLLM dashboard date range picker](/img/cost-discrepancy-debug/date-range-picker.png)
## Step 2: Confirm traffic only goes through LiteLLM
If any requests hit the provider directly (bypassing LiteLLM), the provider will show higher usage. That is expected, not a LiteLLM bug.
Before continuing, confirm:
- All clients use your LiteLLM proxy base URL.
- No SDK or script uses provider API keys against the provider directly for the models you are comparing.
- During the selected period, the models in question are only called via LiteLLM.
If you are unsure, filter the provider dashboard by the API key or IAM principal LiteLLM uses, rather than comparing to your whole account.
## Step 3: Compare token categories
In the LiteLLM UI, open **Model activity** (under Usage analytics) so you can inspect spend and tokens per model.
![Navigate to Model activity in the LiteLLM UI](/img/cost-discrepancy-debug/go-to-model-activity.png)
Scroll the **Model** list and select the model you are reconciling with your provider bill.
![Scroll to your model in the Model activity table](/img/cost-discrepancy-debug/scroll-to-model.png)
With the same time range on both sides, fill in:
| Category | LiteLLM | Provider | Delta |
| --- | --- | --- | --- |
| Total requests | — | — | — |
| Input tokens | — | — | — |
| Output tokens | — | — | — |
| Cache read tokens | — | — | — |
| Cache write tokens | — | — | — |
LiteLLM surfaces per-category token usage for the selected model—for example prompt, completion, and cache-related tokens.
![LiteLLM usage breakdown by token category](/img/cost-discrepancy-debug/token-categories.png)
Compare these figures with your providers usage view (for example AWS billing tools, Azure Monitor, or the OpenAI usage dashboard) for the same period.
### Cache token reporting
- **OpenAI:** Cache read tokens are typically included inside the reported input token count.
- **Anthropic:** Cache read tokens are often reported separately from non-cached input tokens.
Compare the correct columns on each side so you are not treating “input” differently between dashboards.
### Why use a 10% threshold?
Provider dashboards and LiteLLM do not bucket requests on identical timestamps. A call at 11:59 PM can land in different daily totals on each side. Token counts can also differ slightly due to rounding across SDKs and APIs. A delta **under ~10%** is often explained by boundary effects and rounding. A delta **over ~10%** usually means something is miscounted, dropped, or categorized differently.
## Step 4: Follow the right path
<svg width="100%" viewBox="0 0 680 482" role="img" xmlns="http://www.w3.org/2000/svg" style={{ maxWidth: '100%', fontFamily: 'system-ui, sans-serif' }} aria-labelledby="cost-disc-flow-title">
<title id="cost-disc-flow-title">Cost discrepancy debugging flowchart</title>
<desc>Flowchart branching into Path A (token ingestion) or Path B which splits further into B1 (formula issue) and B2 (model map issue).</desc>
<defs>
<marker id="cd-arrow" viewBox="0 0 10 10" refX="8" refY="5" markerWidth="6" markerHeight="6" orient="auto-start-reverse">
<path d="M2 1L8 5L2 9" fill="none" stroke="#888780" strokeWidth="1.5" strokeLinecap="round" strokeLinejoin="round" />
</marker>
</defs>
<rect x="215" y="24" width="250" height="44" rx="8" fill="#F1EFE8" stroke="#5F5E5A" strokeWidth="0.5" />
<text x="340" y="47" textAnchor="middle" dominantBaseline="central" fill="#444441" fontSize="14" fontWeight="500">Compare provider vs LiteLLM</text>
<line x1="340" y1="68" x2="340" y2="104" stroke="#888780" strokeWidth="1.5" markerEnd="url(#cd-arrow)" />
<rect x="175" y="104" width="330" height="56" rx="8" fill="#F1EFE8" stroke="#5F5E5A" strokeWidth="0.5" />
<text x="340" y="126" textAnchor="middle" dominantBaseline="central" fill="#444441" fontSize="14" fontWeight="500">Any category off by &gt; 10%?</text>
<text x="340" y="148" textAnchor="middle" dominantBaseline="central" fill="#5F5E5A" fontSize="12">requests, input, output, cache tokens</text>
<path d="M220 132 L100 132 L100 250" fill="none" stroke="#0F6E56" strokeWidth="1.5" markerEnd="url(#cd-arrow)" />
<text x="157" y="122" textAnchor="middle" fill="#0F6E56" fontSize="12">YES</text>
<path d="M505 132 L580 132 L580 250" fill="none" stroke="#993C1D" strokeWidth="1.5" markerEnd="url(#cd-arrow)" />
<text x="543" y="122" textAnchor="middle" fill="#993C1D" fontSize="12">NO</text>
<rect x="40" y="250" width="220" height="56" rx="8" fill="#E1F5EE" stroke="#0F6E56" strokeWidth="0.5" />
<text x="150" y="271" textAnchor="middle" dominantBaseline="central" fill="#085041" fontSize="14" fontWeight="500">Path A</text>
<text x="150" y="291" textAnchor="middle" dominantBaseline="central" fill="#0F6E56" fontSize="12">Token ingestion issue</text>
<rect x="420" y="250" width="220" height="56" rx="8" fill="#FAECE7" stroke="#993C1D" strokeWidth="0.5" />
<text x="530" y="271" textAnchor="middle" dominantBaseline="central" fill="#712B13" fontSize="14" fontWeight="500">Path B</text>
<text x="530" y="291" textAnchor="middle" dominantBaseline="central" fill="#993C1D" fontSize="12">Quantities match, cost differs</text>
<line x1="150" y1="306" x2="150" y2="370" stroke="#0F6E56" strokeWidth="1.5" markerEnd="url(#cd-arrow)" />
<line x1="530" y1="306" x2="530" y2="318" stroke="#854F0B" strokeWidth="1.5" />
<line x1="435" y1="318" x2="575" y2="318" stroke="#854F0B" strokeWidth="1.5" />
<line x1="435" y1="318" x2="435" y2="370" stroke="#854F0B" strokeWidth="1.5" markerEnd="url(#cd-arrow)" />
<line x1="575" y1="318" x2="575" y2="370" stroke="#854F0B" strokeWidth="1.5" markerEnd="url(#cd-arrow)" />
<text x="448" y="312" textAnchor="middle" fill="#854F0B" fontSize="11">B1</text>
<text x="562" y="312" textAnchor="middle" fill="#854F0B" fontSize="11">B2</text>
<rect x="40" y="370" width="220" height="56" rx="8" fill="#E1F5EE" stroke="#0F6E56" strokeWidth="0.5" />
<text x="150" y="391" textAnchor="middle" dominantBaseline="central" fill="#085041" fontSize="14" fontWeight="500">Report to LiteLLM team</text>
<text x="150" y="411" textAnchor="middle" dominantBaseline="central" fill="#0F6E56" fontSize="12">endpoints + model + screenshots</text>
<rect x="380" y="370" width="110" height="56" rx="8" fill="#FAEEDA" stroke="#854F0B" strokeWidth="0.5" />
<text x="435" y="391" textAnchor="middle" dominantBaseline="central" fill="#633806" fontSize="14" fontWeight="500">B1</text>
<text x="435" y="411" textAnchor="middle" dominantBaseline="central" fill="#854F0B" fontSize="12">Fix formula</text>
<rect x="510" y="370" width="130" height="56" rx="8" fill="#FAEEDA" stroke="#854F0B" strokeWidth="0.5" />
<text x="575" y="391" textAnchor="middle" dominantBaseline="central" fill="#633806" fontSize="14" fontWeight="500">B2</text>
<text x="575" y="411" textAnchor="middle" dominantBaseline="central" fill="#854F0B" fontSize="12">Fix model map</text>
<path d="M150 426 L150 442 L340 442" fill="none" stroke="#888780" strokeWidth="0.5" strokeDasharray="4 3" />
<path d="M340 442 L435 442 L435 428" fill="none" stroke="#888780" strokeWidth="0.5" strokeDasharray="4 3" />
<path d="M340 442 L575 442 L575 428" fill="none" stroke="#888780" strokeWidth="0.5" strokeDasharray="4 3" />
<text x="340" y="454" textAnchor="middle" fill="#5F5E5A" fontSize="11">if neither path resolves it,</text>
<text x="340" y="470" textAnchor="middle" fill="#5F5E5A" fontSize="11">Open a github issue backing up with all your data</text>
</svg>
## Path A: Token quantity mismatch
If any category is off by more than about 10%, LiteLLM may not be ingesting that category correctly (or the provider dashboard is categorizing tokens differently—recheck Step 3 first).
**What to send the LiteLLM team:**
1. Screenshots of both dashboards with the date range visible.
2. Which category is off (input, output, cache reads, cache writes, or request count).
3. Endpoints used (for example `/chat/completions`, `/responses`, `/embeddings`).
4. Model names as sent in the request (for example `anthropic.claude-opus-4-5`, `gpt-4o`).
### For maintainers debugging ingestion
1. Start the proxy with verbose logging, for example:
```bash
litellm --config config.yaml --detailed_debug
```
2. Reproduce a single request with the reported endpoint and model.
3. Inspect the raw `usage` object in each streamed chunk (if streaming) or in the final response body.
4. Compare that to the standard logging object (or the UI request log for that call).
5. Any gap between raw provider usage and what LiteLLM logs or aggregates is where ingestion may be wrong.
## Path B: Quantities match but cost is wrong
If token and request counts agree within ~10% but dollar amounts differ, focus on how cost is computed.
### B1: Formula issue
Manually compute expected cost using the providers token breakdown and published rates (per million tokens or per token).
Add other billed dimensions your provider applies (for example cache creation, audio, or tier surcharges). If your hand calculation matches the provider bill but not LiteLLM, the implementation in LiteLLM for that provider or modality may be wrong.
### B2: Model map issue
If the formula structure matches how the provider bills, the values in LiteLLMs model map may be stale or incorrect. Cross-check:
- [`model_prices_and_context_window.json`](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
- The providers current public pricing
Inspect `input_cost_per_token`, `output_cost_per_token`, and any cache-related pricing fields for your exact model id (including provider prefix).
### For maintainers
1. Take authoritative token quantities from the users provider report.
2. Derive the formula that reproduces the providers line item.
3. Diff that against LiteLLMs cost path for the same provider and response shape.
4. If the formula matches but numbers differ, update pricing in `model_prices_and_context_window.json` (and follow the projects sync / backup rules for that file).
5. If the formula in code is wrong, fix the calculation and add a regression test using the users token breakdown.
## Still stuck?
1. Open a GitHub issue on [BerriAI/litellm](https://github.com/BerriAI/litellm) with your Step 3 comparison table, endpoints, and model names.
On the issue, it helps to clarify:
- Reproducible on demand or intermittent?
- Single model or many?
- Steady over time, or starting from a specific release date or config change?
### For LiteLLM maintainers
If Path A and Path B do not close the case after triage, **you** should reach out and **schedule a call with the customer** (support or engineering), with the Step 3 table and screenshots—before treating the issue.
## Checklist
```
□ Same time range on both dashboards
□ Confirmed no direct-to-provider traffic for those models
□ Compared: requests, input tokens, output tokens, cache tokens
□ Noted cache reporting differences (OpenAI vs Anthropic, and so on)
□ If > ~10% delta on quantities → Path A: report with screenshots, endpoints, model names
□ If quantities match → Path B: verify formula (B1) and model map pricing (B2)
□ If neither path fits → open a GitHub issue.
```
## See also
- [Spend tracking](../proxy/cost_tracking)
- [Sync model pricing from GitHub](../proxy/sync_models_github)
+1
View File
@@ -1149,6 +1149,7 @@ const sidebars = {
label: "Troubleshooting",
items: [
"troubleshoot/ui_issues",
"troubleshoot/cost_discrepancy",
"mcp_troubleshoot",
{
type: "category",
Binary file not shown.

After

Width:  |  Height:  |  Size: 509 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 445 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 296 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 281 KiB

@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "LiteLLM_MCPServerTable" ADD COLUMN IF NOT EXISTS "instructions" TEXT;
@@ -0,0 +1,12 @@
-- CreateIndex (CONCURRENTLY)
--
-- Disclaimer:
-- - CREATE INDEX CONCURRENTLY cannot run inside a transaction. This migration must stay a
-- single statement so Prisma Migrate on PostgreSQL can apply it outside a transaction.
-- - Builds are slower and use more I/O than a blocking CREATE INDEX; if the build is
-- interrupted, Postgres may leave an INVALID index that must be dropped and recreated.
-- - Do not edit this file after it has been applied to any database: Prisma checksums
-- migrations; add a new migration instead.
-- - Requires PostgreSQL that supports CONCURRENTLY with IF NOT EXISTS (use a new migration
-- without IF NOT EXISTS if you must support older versions).
CREATE INDEX CONCURRENTLY IF NOT EXISTS "LiteLLM_HealthCheckTable_model_id_model_name_checked_at_idx" ON "LiteLLM_HealthCheckTable"("model_id", "model_name", "checked_at" DESC);
@@ -289,6 +289,7 @@ model LiteLLM_MCPServerTable {
server_name String?
alias String?
description String?
instructions String?
url String?
spec_path String?
transport String @default("sse")
@@ -1045,6 +1046,7 @@ model LiteLLM_HealthCheckTable {
@@index([model_name])
@@index([checked_at])
@@index([status])
@@index([model_id, model_name, checked_at(sort: Desc)], map: "LiteLLM_HealthCheckTable_model_id_model_name_checked_at_idx")
}
// Search Tools table for storing search tool configurations
+31 -2
View File
@@ -1,6 +1,6 @@
import os
import sys
from typing import List, Literal
from typing import List, Literal, Optional
from litellm.litellm_core_utils.env_utils import get_env_int
@@ -413,7 +413,20 @@ MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB = int(
)
DEFAULT_MAX_TOKENS_FOR_TRITON = int(os.getenv("DEFAULT_MAX_TOKENS_FOR_TRITON", 2000))
#### Networking settings ####
request_timeout: float = float(os.getenv("REQUEST_TIMEOUT", 6000)) # time in seconds
# Sentinel used when `REQUEST_TIMEOUT` is unset: `litellm.request_timeout` keeps this
# value so longer-running surfaces (Router `timeout or litellm.request_timeout`,
# speech/TTS, responses, vector stores, etc.) get a long HTTP deadline. Chat
# `completion()` maps this sentinel down to 600s when the caller did not set a
# per-request/model timeout—see ``CompletionTimeout.resolve`` in completion_timeout.py. MCP uses
# dedicated timeouts (e.g. `MCP_CLIENT_TIMEOUT`), not `request_timeout`.
DEFAULT_REQUEST_TIMEOUT_SECONDS: float = 6000.0
# Pair used for default httpx clients when no custom timeout is passed: read/write
# deadline and connect handshake (see ``http_handler`` cached handler paths).
COMPLETION_HTTP_FALLBACK_SECONDS: float = 600.0
HTTP_HANDLER_CONNECT_TIMEOUT_SECONDS: float = 5.0
request_timeout: float = float(
os.getenv("REQUEST_TIMEOUT", str(int(DEFAULT_REQUEST_TIMEOUT_SECONDS)))
)
DEFAULT_A2A_AGENT_TIMEOUT: float = float(
os.getenv("DEFAULT_A2A_AGENT_TIMEOUT", 6000)
) # 10 minutes
@@ -1330,6 +1343,22 @@ BATCH_STATUS_POLL_MAX_ATTEMPTS = int(
HEALTH_CHECK_TIMEOUT_SECONDS = int(
os.getenv("HEALTH_CHECK_TIMEOUT_SECONDS", 60)
) # 60 seconds
_background_health_check_max_tokens_env = os.getenv(
"BACKGROUND_HEALTH_CHECK_MAX_TOKENS"
)
try:
_raw_background_health_check_max_tokens = (
_background_health_check_max_tokens_env.strip()
if _background_health_check_max_tokens_env is not None
else ""
)
BACKGROUND_HEALTH_CHECK_MAX_TOKENS: Optional[int] = (
int(_raw_background_health_check_max_tokens)
if _raw_background_health_check_max_tokens
else None
)
except (ValueError, TypeError):
BACKGROUND_HEALTH_CHECK_MAX_TOKENS = None
LITTELM_INTERNAL_HEALTH_SERVICE_ACCOUNT_NAME = "litellm-internal-health-check"
LITTELM_CLI_SERVICE_ACCOUNT_NAME = "litellm-cli"
LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME = "litellm_internal_jobs"
+8 -1
View File
@@ -221,6 +221,7 @@ class MCPClient:
self.extra_headers: Optional[Dict[str, str]] = extra_headers
self.ssl_verify: Optional[VerifyTypes] = ssl_verify
self._aws_auth: Optional[httpx.Auth] = aws_auth
self._last_initialize_instructions: Optional[str] = None
# handle the basic auth value if provided
if auth_value:
self.update_auth_value(auth_value)
@@ -296,7 +297,12 @@ class MCPClient:
session_ctx = ClientSession(read_stream, write_stream)
session = await session_ctx.__aenter__()
try:
await session.initialize()
init_result = await session.initialize()
self._last_initialize_instructions = None
if init_result is not None:
ins = getattr(init_result, "instructions", None)
if isinstance(ins, str) and ins.strip():
self._last_initialize_instructions = ins.strip()
return await operation(session)
finally:
try:
@@ -315,6 +321,7 @@ class MCPClient:
"""Open a session, run the provided coroutine, and clean up."""
http_client: Optional[httpx.AsyncClient] = None
try:
self._last_initialize_instructions = None
transport_ctx, http_client = self._create_transport_context()
return await self._execute_session_operation(transport_ctx, operation)
except Exception:
@@ -0,0 +1,83 @@
"""Completion HTTP timeout resolution (kept out of ``main.py`` to limit import cycles)."""
from __future__ import annotations
from typing import Callable, Optional, Union
import httpx
from litellm.constants import (
COMPLETION_HTTP_FALLBACK_SECONDS,
DEFAULT_REQUEST_TIMEOUT_SECONDS,
)
class CompletionTimeout:
"""Resolves HTTP timeout for ``completion()`` from model vs global settings."""
@staticmethod
def _fallback_when_no_explicit_timeout(
global_timeout: Optional[Union[float, str]],
) -> float:
"""
Used when ``model_timeout`` and kwargs timeouts are all unset.
``global_timeout`` is :attr:`litellm.request_timeout` (numeric / string), not
:class:`httpx.Timeout`.
If it equals :data:`~litellm.constants.DEFAULT_REQUEST_TIMEOUT_SECONDS` (6000),
return :data:`~litellm.constants.COMPLETION_HTTP_FALLBACK_SECONDS`. Same if
``None``. Otherwise return ``float(global_timeout)``.
"""
if global_timeout is None:
return COMPLETION_HTTP_FALLBACK_SECONDS
if float(global_timeout) == float(DEFAULT_REQUEST_TIMEOUT_SECONDS):
return COMPLETION_HTTP_FALLBACK_SECONDS
return float(global_timeout)
@staticmethod
def resolve(
model_timeout: Optional[Union[float, str, httpx.Timeout]],
kwargs: dict,
custom_llm_provider: str,
*,
global_timeout: Optional[Union[float, str]],
supports_httpx_timeout: Callable[[str], bool],
) -> Union[float, httpx.Timeout]:
"""
Resolution order (first non-None wins):
1. ``model_timeout`` (call argument / merged ``litellm_params``)
2. ``kwargs["timeout"]``
3. ``kwargs["request_timeout"]``
4. Fallback from ``global_timeout`` (:attr:`litellm.request_timeout`) if it is
the package default (6000), use 600 instead.
Coerce :class:`httpx.Timeout` when the provider does not support it.
Explicit ``6000`` on the model or in kwargs is kept as ``6000``.
"""
resolved: Union[float, str, httpx.Timeout]
if model_timeout is not None:
resolved = model_timeout
elif kwargs.get("timeout") is not None:
resolved = kwargs["timeout"]
elif kwargs.get("request_timeout") is not None:
resolved = kwargs["request_timeout"]
else:
resolved = CompletionTimeout._fallback_when_no_explicit_timeout(
global_timeout
)
if isinstance(resolved, httpx.Timeout) and not supports_httpx_timeout(
custom_llm_provider
):
read_timeout = resolved.read
resolved = (
float(read_timeout)
if read_timeout is not None
else COMPLETION_HTTP_FALLBACK_SECONDS
) # default 10 min timeout
elif not isinstance(resolved, httpx.Timeout):
resolved = float(resolved) # type: ignore
return resolved
+154 -148
View File
@@ -354,9 +354,9 @@ class Logging(LiteLLMLoggingBaseClass):
)
self.function_id = function_id
self.streaming_chunks: List[Any] = [] # for generating complete stream response
self.sync_streaming_chunks: List[
Any
] = [] # for generating complete stream response
self.sync_streaming_chunks: List[Any] = (
[]
) # for generating complete stream response
self.log_raw_request_response = log_raw_request_response
# Initialize dynamic callbacks
@@ -811,9 +811,9 @@ class Logging(LiteLLMLoggingBaseClass):
prompt_spec=prompt_spec,
dynamic_callback_params=dynamic_callback_params,
):
self.model_call_details[
"prompt_integration"
] = logger.__class__.__name__
self.model_call_details["prompt_integration"] = (
logger.__class__.__name__
)
return logger
except Exception:
# If check fails, continue to next logger
@@ -881,9 +881,9 @@ class Logging(LiteLLMLoggingBaseClass):
if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook(
non_default_params
):
self.model_call_details[
"prompt_integration"
] = anthropic_cache_control_logger.__class__.__name__
self.model_call_details["prompt_integration"] = (
anthropic_cache_control_logger.__class__.__name__
)
return anthropic_cache_control_logger
#########################################################
@@ -895,9 +895,9 @@ class Logging(LiteLLMLoggingBaseClass):
internal_usage_cache=None,
llm_router=None,
)
self.model_call_details[
"prompt_integration"
] = vector_store_custom_logger.__class__.__name__
self.model_call_details["prompt_integration"] = (
vector_store_custom_logger.__class__.__name__
)
# Add to global callbacks so post-call hooks are invoked
if (
vector_store_custom_logger
@@ -957,9 +957,9 @@ class Logging(LiteLLMLoggingBaseClass):
model
): # if model name was changes pre-call, overwrite the initial model call name with the new one
self.model_call_details["model"] = model
self.model_call_details["litellm_params"][
"api_base"
] = self._get_masked_api_base(additional_args.get("api_base", ""))
self.model_call_details["litellm_params"]["api_base"] = (
self._get_masked_api_base(additional_args.get("api_base", ""))
)
def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915
# Log the exact input to the LLM API
@@ -988,10 +988,10 @@ class Logging(LiteLLMLoggingBaseClass):
try:
# [Non-blocking Extra Debug Information in metadata]
if turn_off_message_logging is True:
_metadata[
"raw_request"
] = "redacted by litellm. \
_metadata["raw_request"] = (
"redacted by litellm. \
'litellm.turn_off_message_logging=True'"
)
else:
curl_command = self._get_request_curl_command(
api_base=additional_args.get("api_base", ""),
@@ -1002,34 +1002,34 @@ class Logging(LiteLLMLoggingBaseClass):
_metadata["raw_request"] = str(curl_command)
# split up, so it's easier to parse in the UI
self.model_call_details[
"raw_request_typed_dict"
] = RawRequestTypedDict(
raw_request_api_base=str(
additional_args.get("api_base") or ""
),
raw_request_body=self._get_raw_request_body(
additional_args.get("complete_input_dict", {})
),
# NOTE: setting ignore_sensitive_headers to True will cause
# the Authorization header to be leaked when calls to the health
# endpoint are made and fail.
raw_request_headers=self._get_masked_headers(
additional_args.get("headers", {}) or {},
),
error=None,
self.model_call_details["raw_request_typed_dict"] = (
RawRequestTypedDict(
raw_request_api_base=str(
additional_args.get("api_base") or ""
),
raw_request_body=self._get_raw_request_body(
additional_args.get("complete_input_dict", {})
),
# NOTE: setting ignore_sensitive_headers to True will cause
# the Authorization header to be leaked when calls to the health
# endpoint are made and fail.
raw_request_headers=self._get_masked_headers(
additional_args.get("headers", {}) or {},
),
error=None,
)
)
except Exception as e:
self.model_call_details[
"raw_request_typed_dict"
] = RawRequestTypedDict(
error=str(e),
self.model_call_details["raw_request_typed_dict"] = (
RawRequestTypedDict(
error=str(e),
)
)
_metadata[
"raw_request"
] = "Unable to Log \
_metadata["raw_request"] = (
"Unable to Log \
raw request: {}".format(
str(e)
str(e)
)
)
if getattr(self, "logger_fn", None) and callable(self.logger_fn):
try:
@@ -1330,13 +1330,13 @@ class Logging(LiteLLMLoggingBaseClass):
for callback in callbacks:
try:
if isinstance(callback, CustomLogger):
response: Optional[
MCPPostCallResponseObject
] = await callback.async_post_mcp_tool_call_hook(
kwargs=kwargs,
response_obj=post_mcp_tool_call_response_obj,
start_time=start_time,
end_time=end_time,
response: Optional[MCPPostCallResponseObject] = (
await callback.async_post_mcp_tool_call_hook(
kwargs=kwargs,
response_obj=post_mcp_tool_call_response_obj,
start_time=start_time,
end_time=end_time,
)
)
######################################################################
# if any of the callbacks modify the response, use the modified response
@@ -1543,9 +1543,9 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
f"response_cost_failure_debug_information: {debug_info}"
)
self.model_call_details[
"response_cost_failure_debug_information"
] = debug_info
self.model_call_details["response_cost_failure_debug_information"] = (
debug_info
)
return None
try:
@@ -1571,9 +1571,9 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
f"response_cost_failure_debug_information: {debug_info}"
)
self.model_call_details[
"response_cost_failure_debug_information"
] = debug_info
self.model_call_details["response_cost_failure_debug_information"] = (
debug_info
)
return None
@@ -1722,9 +1722,9 @@ class Logging(LiteLLMLoggingBaseClass):
self.model_call_details["litellm_params"].setdefault("metadata", {})
if self.model_call_details["litellm_params"]["metadata"] is None:
self.model_call_details["litellm_params"]["metadata"] = {}
self.model_call_details["litellm_params"]["metadata"][
"hidden_params"
] = getattr(logging_result, "_hidden_params", {})
self.model_call_details["litellm_params"]["metadata"]["hidden_params"] = (
getattr(logging_result, "_hidden_params", {})
)
def _process_hidden_params_and_response_cost(
self,
@@ -1753,9 +1753,9 @@ class Logging(LiteLLMLoggingBaseClass):
result=logging_result
)
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(logging_result, start_time, end_time)
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(logging_result, start_time, end_time)
)
if (
standard_logging_payload := self.model_call_details.get(
@@ -1833,9 +1833,9 @@ class Logging(LiteLLMLoggingBaseClass):
end_time = datetime.datetime.now()
if self.completion_start_time is None:
self.completion_start_time = end_time
self.model_call_details[
"completion_start_time"
] = self.completion_start_time
self.model_call_details["completion_start_time"] = (
self.completion_start_time
)
self.model_call_details["log_event_type"] = "successful_api_call"
self.model_call_details["end_time"] = end_time
@@ -1872,10 +1872,10 @@ class Logging(LiteLLMLoggingBaseClass):
end_time=end_time,
)
elif isinstance(result, dict) or isinstance(result, list):
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(
result, start_time, end_time
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(
result, start_time, end_time
)
)
if (
standard_logging_payload := self.model_call_details.get(
@@ -1884,9 +1884,9 @@ class Logging(LiteLLMLoggingBaseClass):
) is not None:
emit_standard_logging_payload(standard_logging_payload)
elif standard_logging_object is not None:
self.model_call_details[
"standard_logging_object"
] = standard_logging_object
self.model_call_details["standard_logging_object"] = (
standard_logging_object
)
else:
self.model_call_details["response_cost"] = None
@@ -2044,20 +2044,20 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
"Logging Details LiteLLM-Success Call streaming complete"
)
self.model_call_details[
"complete_streaming_response"
] = complete_streaming_response
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(result=complete_streaming_response)
self.model_call_details["complete_streaming_response"] = (
complete_streaming_response
)
self.model_call_details["response_cost"] = (
self._response_cost_calculator(result=complete_streaming_response)
)
self._merge_hidden_params_from_response_into_metadata(
complete_streaming_response
)
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(
complete_streaming_response, start_time, end_time
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(
complete_streaming_response, start_time, end_time
)
)
if (
standard_logging_payload := self.model_call_details.get(
@@ -2391,10 +2391,10 @@ class Logging(LiteLLMLoggingBaseClass):
)
else:
if self.stream and complete_streaming_response:
self.model_call_details[
"complete_response"
] = self.model_call_details.get(
"complete_streaming_response", {}
self.model_call_details["complete_response"] = (
self.model_call_details.get(
"complete_streaming_response", {}
)
)
result = self.model_call_details["complete_response"]
openMeterLogger.log_success_event(
@@ -2418,10 +2418,10 @@ class Logging(LiteLLMLoggingBaseClass):
)
else:
if self.stream and complete_streaming_response:
self.model_call_details[
"complete_response"
] = self.model_call_details.get(
"complete_streaming_response", {}
self.model_call_details["complete_response"] = (
self.model_call_details.get(
"complete_streaming_response", {}
)
)
result = self.model_call_details["complete_response"]
@@ -2560,9 +2560,9 @@ class Logging(LiteLLMLoggingBaseClass):
if complete_streaming_response is not None:
print_verbose("Async success callbacks: Got a complete streaming response")
self.model_call_details[
"async_complete_streaming_response"
] = complete_streaming_response
self.model_call_details["async_complete_streaming_response"] = (
complete_streaming_response
)
try:
if self.model_call_details.get("cache_hit", False) is True:
@@ -2573,10 +2573,10 @@ class Logging(LiteLLMLoggingBaseClass):
model_call_details=self.model_call_details
)
# base_model defaults to None if not set on model_info
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(
result=complete_streaming_response
self.model_call_details["response_cost"] = (
self._response_cost_calculator(
result=complete_streaming_response
)
)
verbose_logger.debug(
@@ -2593,10 +2593,10 @@ class Logging(LiteLLMLoggingBaseClass):
)
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(
complete_streaming_response, start_time, end_time
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(
complete_streaming_response, start_time, end_time
)
)
# print standard logging payload
@@ -2623,9 +2623,9 @@ class Logging(LiteLLMLoggingBaseClass):
# _success_handler_helper_fn
if self.model_call_details.get("standard_logging_object") is None:
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(result, start_time, end_time)
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(result, start_time, end_time)
)
# print standard logging payload
if (
@@ -2872,18 +2872,18 @@ class Logging(LiteLLMLoggingBaseClass):
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj={},
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="failure",
error_str=_redact_string(str(exception)),
original_exception=exception,
standard_built_in_tools_params=self.standard_built_in_tools_params,
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj={},
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="failure",
error_str=_redact_string(str(exception)),
original_exception=exception,
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
)
return start_time, end_time
@@ -3853,9 +3853,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
service_name=arize_config.project_name,
)
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}"
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}"
)
for callback in _in_memory_loggers:
if (
isinstance(callback, ArizeLogger)
@@ -3881,13 +3881,13 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "")
# Add openinference.project.name attribute
if existing_attrs:
os.environ[
"OTEL_RESOURCE_ATTRIBUTES"
] = f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}"
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}"
)
else:
os.environ[
"OTEL_RESOURCE_ATTRIBUTES"
] = f"openinference.project.name={arize_phoenix_config.project_name}"
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
f"openinference.project.name={arize_phoenix_config.project_name}"
)
# Set Phoenix project name from environment variable
phoenix_project_name = os.environ.get("PHOENIX_PROJECT_NAME", None)
@@ -3895,19 +3895,19 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "")
# Add openinference.project.name attribute
if existing_attrs:
os.environ[
"OTEL_RESOURCE_ATTRIBUTES"
] = f"{existing_attrs},openinference.project.name={phoenix_project_name}"
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
f"{existing_attrs},openinference.project.name={phoenix_project_name}"
)
else:
os.environ[
"OTEL_RESOURCE_ATTRIBUTES"
] = f"openinference.project.name={phoenix_project_name}"
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
f"openinference.project.name={phoenix_project_name}"
)
# auth can be disabled on local deployments of arize phoenix
if arize_phoenix_config.otlp_auth_headers is not None:
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = arize_phoenix_config.otlp_auth_headers
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
arize_phoenix_config.otlp_auth_headers
)
for callback in _in_memory_loggers:
if (
@@ -4094,9 +4094,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
exporter="otlp_http",
endpoint="https://langtrace.ai/api/trace",
)
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = f"api_key={os.getenv('LANGTRACE_API_KEY')}"
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
f"api_key={os.getenv('LANGTRACE_API_KEY')}"
)
for callback in _in_memory_loggers:
if (
isinstance(callback, OpenTelemetry)
@@ -4991,16 +4991,22 @@ class StandardLoggingPayloadSetup:
additional_logging_headers: StandardLoggingAdditionalHeaders = {}
# Populate well-known typed fields with int/str coercion where needed
typed_keys: dict = {}
for key in StandardLoggingAdditionalHeaders.__annotations__.keys():
_key = key.lower()
_key = _key.replace("_", "-")
_key = key.lower().replace("_", "-")
typed_keys[_key] = key
if _key in additiona_headers:
try:
additional_logging_headers[key] = int(additiona_headers[_key]) # type: ignore
except (ValueError, TypeError):
verbose_logger.debug(
f"Could not convert {additiona_headers[_key]} to int for key {key}."
)
additional_logging_headers[key] = additiona_headers[_key] # type: ignore
# Preserve all remaining headers verbatim (e.g. llm_provider-x-request-id)
for k, v in additiona_headers.items():
if k.lower() not in typed_keys:
additional_logging_headers[k] = v # type: ignore
return additional_logging_headers
@staticmethod
@@ -5022,10 +5028,10 @@ class StandardLoggingPayloadSetup:
for key in StandardLoggingHiddenParams.__annotations__.keys():
if key in hidden_params:
if key == "additional_headers":
clean_hidden_params[
"additional_headers"
] = StandardLoggingPayloadSetup.get_additional_headers(
hidden_params[key]
clean_hidden_params["additional_headers"] = (
StandardLoggingPayloadSetup.get_additional_headers(
hidden_params[key]
)
)
else:
clean_hidden_params[key] = hidden_params[key] # type: ignore
@@ -5666,9 +5672,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
):
for k, v in metadata["user_api_key_metadata"].items():
if k == "logging": # prevent logging user logging keys
cleaned_user_api_key_metadata[
k
] = "scrubbed_by_litellm_for_sensitive_keys"
cleaned_user_api_key_metadata[k] = (
"scrubbed_by_litellm_for_sensitive_keys"
)
else:
cleaned_user_api_key_metadata[k] = v
@@ -11,6 +11,10 @@ from openai.types.completion_create_params import (
CompletionCreateParamsStreaming as TextCompletionCreateParamsStreaming,
)
from openai.types.embedding_create_params import EmbeddingCreateParams
from openai.types.responses.response_create_params import (
ResponseCreateParamsNonStreaming,
ResponseCreateParamsStreaming,
)
from litellm._logging import verbose_logger
from litellm.types.rerank import RerankRequest
@@ -65,6 +69,9 @@ class ModelParamHelper:
ModelParamHelper._get_litellm_supported_transcription_kwargs()
)
rerank_kwargs = ModelParamHelper._get_litellm_supported_rerank_kwargs()
responses_api_kwargs = (
ModelParamHelper._get_litellm_supported_responses_api_kwargs()
)
exclude_kwargs = ModelParamHelper._get_exclude_kwargs()
combined_kwargs = chat_completion_kwargs.union(
@@ -72,6 +79,7 @@ class ModelParamHelper:
embedding_kwargs,
transcription_kwargs,
rerank_kwargs,
responses_api_kwargs,
)
combined_kwargs = combined_kwargs.difference(exclude_kwargs)
return combined_kwargs
@@ -93,9 +101,9 @@ class ModelParamHelper:
streaming_params: Set[str] = set(
getattr(CompletionCreateParamsStreaming, "__annotations__", {}).keys()
)
litellm_provider_specific_params: Set[
str
] = ModelParamHelper.get_litellm_provider_specific_params_for_chat_params()
litellm_provider_specific_params: Set[str] = (
ModelParamHelper.get_litellm_provider_specific_params_for_chat_params()
)
all_chat_completion_kwargs: Set[str] = non_streaming_params.union(
streaming_params
).union(litellm_provider_specific_params)
@@ -167,6 +175,21 @@ class ModelParamHelper:
verbose_logger.debug("Error getting transcription kwargs %s", str(e))
return set()
@staticmethod
def _get_litellm_supported_responses_api_kwargs() -> Set[str]:
"""
Get the litellm supported responses API kwargs
This follows the OpenAI API Spec
"""
non_streaming_params: Set[str] = set(
getattr(ResponseCreateParamsNonStreaming, "__annotations__", {}).keys()
)
streaming_params: Set[str] = set(
getattr(ResponseCreateParamsStreaming, "__annotations__", {}).keys()
)
return non_streaming_params.union(streaming_params)
@staticmethod
def _get_exclude_kwargs() -> Set[str]:
"""
+64
View File
@@ -2,6 +2,7 @@
This file contains common utils for anthropic calls.
"""
import copy
from typing import Any, Dict, List, Optional, Union
import httpx
@@ -736,6 +737,69 @@ def strip_advisor_blocks_from_messages(
return messages
def is_anthropic_invalid_thinking_signature_error(error_text: str) -> bool:
"""
Detect Anthropic 400 when encrypted thinking signatures in history do not match
the current deployment (e.g. user rotated API key or switched model endpoint).
Example API message:
messages.N.content.M: Invalid `signature` in `thinking` block
"""
if not error_text:
return False
lower = error_text.lower()
return (
"invalid" in lower
and "signature" in lower
and "thinking" in lower
and "block" in lower
)
def strip_thinking_blocks_from_anthropic_messages(messages: List[Any]) -> List[Any]:
"""
Return a new message list with thinking / redacted_thinking content blocks removed
from each message. Used to recover from invalid thinking signatures on retry.
Messages whose content is a list and becomes empty after stripping are omitted,
since Anthropic rejects empty content arrays.
"""
out: List[Any] = []
for m in messages:
if not isinstance(m, dict):
out.append(m)
continue
mm = copy.deepcopy(m)
content = mm.get("content")
if isinstance(content, list):
filtered = [
b
for b in content
if not (
isinstance(b, dict)
and b.get("type") in ("thinking", "redacted_thinking")
)
]
if not filtered:
continue
mm["content"] = filtered
out.append(mm)
return out
def strip_thinking_blocks_from_anthropic_messages_request_dict(
data: Dict[str, Any],
) -> None:
"""
Mutate an Anthropic Messages-style request dict: strip thinking blocks from
``messages`` and remove the top-level ``thinking`` extended-thinking param.
"""
msgs = data.get("messages")
if isinstance(msgs, list):
data["messages"] = strip_thinking_blocks_from_anthropic_messages(msgs)
data.pop("thinking", None)
def process_anthropic_headers(headers: Union[httpx.Headers, dict]) -> dict:
openai_headers = {}
if "anthropic-ratelimit-requests-limit" in headers:
@@ -27,6 +27,7 @@ class BaseAnthropicMessagesStreamingIterator:
self.litellm_logging_obj = litellm_logging_obj
self.request_body = request_body
self.start_time = datetime.now()
self.completion_start_time: datetime | None = None
async def _handle_streaming_logging(self, collected_chunks: List[bytes]):
"""Handle the logging after all chunks have been collected."""
@@ -35,6 +36,15 @@ class BaseAnthropicMessagesStreamingIterator:
)
end_time = datetime.now()
# Set completion_start_time so TTFT is calculated from the first
# chunk rather than falling back to end_time in async_success_handler.
if self.completion_start_time is not None:
self.litellm_logging_obj.completion_start_time = (
self.completion_start_time
)
self.litellm_logging_obj.model_call_details[
"completion_start_time"
] = self.completion_start_time
asyncio.create_task(
PassThroughStreamingHandler._route_streaming_logging_to_handler(
litellm_logging_obj=self.litellm_logging_obj,
@@ -100,6 +110,8 @@ class BaseAnthropicMessagesStreamingIterator:
collected_chunks = []
async for chunk in completion_stream:
if self.completion_start_time is None:
self.completion_start_time = datetime.now()
encoded_chunk = self._convert_chunk_to_sse_format(chunk)
collected_chunks.append(encoded_chunk)
yield encoded_chunk
@@ -1,7 +1,9 @@
from typing import TYPE_CHECKING, List, Optional, Tuple
import httpx
from httpx import Response
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig
from litellm.secret_managers.main import get_secret_str
@@ -11,6 +13,8 @@ from litellm.types.router import GenericLiteLLMParams
if TYPE_CHECKING:
from httpx import URL
from litellm.types.utils import CostResponseTypes
class AzurePassthroughConfig(BasePassthroughConfig):
def is_streaming_request(self, endpoint: str, request_data: dict) -> bool:
@@ -83,3 +87,36 @@ class AzurePassthroughConfig(BasePassthroughConfig):
self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> List[str]:
return super().get_models(api_key, api_base)
def logging_non_streaming_response(
self,
model: str,
custom_llm_provider: str,
httpx_response: Response,
request_data: dict,
logging_obj: Logging,
endpoint: str,
) -> Optional["CostResponseTypes"]:
from litellm import encoding
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.types.utils import ModelResponse
if "chat/completions" not in endpoint:
return None
openai_chat_config = OpenAIGPTConfig()
litellm_model_response: ModelResponse = openai_chat_config.transform_response(
model=model,
messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
raw_response=httpx_response,
model_response=ModelResponse(),
logging_obj=logging_obj,
optional_params={},
litellm_params={},
api_key="",
request_data=request_data,
encoding=encoding,
)
return litellm_model_response
@@ -120,3 +120,46 @@ class BaseAnthropicMessagesConfig(ABC):
return BaseLLMException(
message=error_message, status_code=status_code, headers=headers
)
@property
def max_retry_on_anthropic_messages_http_error(self) -> int:
"""
Max HTTP attempts for /v1/messages when the handler may mutate the body and
retry (e.g. strip invalid encrypted thinking signatures after a deployment or
credential change).
"""
return 2
def should_retry_anthropic_messages_on_http_error(
self, e: httpx.HTTPStatusError, litellm_params: dict
) -> bool:
"""
When True, async_anthropic_messages_handler will transform the request body
and issue one more attempt (bounded by max_retry_on_anthropic_messages_http_error).
"""
from litellm.llms.anthropic.common_utils import (
is_anthropic_invalid_thinking_signature_error,
)
return (
e.response.status_code == 400
and is_anthropic_invalid_thinking_signature_error(e.response.text)
)
def transform_anthropic_messages_request_on_http_error(
self, e: httpx.HTTPStatusError, request_data: dict
) -> dict:
"""
Mutates request_data in place when retrying after a recoverable HTTP error.
"""
from litellm.llms.anthropic.common_utils import (
is_anthropic_invalid_thinking_signature_error,
strip_thinking_blocks_from_anthropic_messages_request_dict,
)
if (
e.response.status_code == 400
and is_anthropic_invalid_thinking_signature_error(e.response.text)
):
strip_thinking_blocks_from_anthropic_messages_request_dict(request_data)
return request_data
@@ -1003,7 +1003,7 @@ class AmazonConverseConfig(BaseConfig):
description=description,
)
optional_params["outputConfig"] = output_config
else:
elif json_schema is not None:
# Fallback: translate to a synthetic tool call
# https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
_tool = self._create_json_tool_call_for_response_format(
@@ -1025,6 +1025,12 @@ class AmazonConverseConfig(BaseConfig):
)
if non_default_params.get("stream", False) is True:
optional_params["fake_stream"] = True
# else: response_format=json_object with no schema.
# Don't inject the synthetic json_tool_call tool here. When no
# schema is given, _create_json_tool_call_for_response_format
# produces an empty schema (properties: {}), and the model
# returns {} instead of the requested JSON. The model already
# returns JSON when the prompt asks for it.
optional_params["json_mode"] = True
return optional_params
@@ -2030,6 +2036,12 @@ class AmazonConverseConfig(BaseConfig):
_message = Message(**chat_completion_message)
initial_finish_reason = map_finish_reason(completion_response["stopReason"])
# When json_mode filtered out all synthetic tool calls the response
# is plain content, not a pending tool invocation. Fix finish_reason
# so callers (e.g. OpenAI SDK) don't misinterpret it.
if json_mode and not filtered_tools and tools:
initial_finish_reason = "stop"
(
returned_message,
returned_finish_reason,
+8 -3
View File
@@ -30,7 +30,9 @@ from litellm.constants import (
AIOHTTP_KEEPALIVE_TIMEOUT,
AIOHTTP_NEEDS_CLEANUP_CLOSED,
AIOHTTP_TTL_DNS_CACHE,
COMPLETION_HTTP_FALLBACK_SECONDS,
DEFAULT_SSL_CIPHERS,
HTTP_HANDLER_CONNECT_TIMEOUT_SECONDS,
)
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.types.llms.custom_http import *
@@ -70,7 +72,10 @@ def get_default_headers() -> dict:
headers = get_default_headers()
# https://www.python-httpx.org/advanced/timeouts
_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
_DEFAULT_TIMEOUT = httpx.Timeout(
timeout=COMPLETION_HTTP_FALLBACK_SECONDS,
connect=HTTP_HANDLER_CONNECT_TIMEOUT_SECONDS,
)
def _prepare_request_data_and_content(
@@ -1258,7 +1263,7 @@ def get_async_httpx_client(
_new_client = AsyncHTTPHandler(**handler_params)
else:
_new_client = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0),
timeout=_DEFAULT_TIMEOUT,
shared_session=shared_session,
)
@@ -1307,7 +1312,7 @@ def _get_httpx_client(params: Optional[dict] = None) -> HTTPHandler:
}
_new_client = HTTPHandler(**handler_params)
else:
_new_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
_new_client = HTTPHandler(timeout=_DEFAULT_TIMEOUT)
cache.set_cache(
key=_cache_key_name,
+98 -22
View File
@@ -1816,6 +1816,73 @@ class BaseLLMHTTPHandler:
logging_obj=logging_obj,
)
async def _async_post_anthropic_messages_with_http_error_retry(
self,
async_httpx_client: AsyncHTTPHandler,
request_url: str,
headers: dict,
signed_json_body: Optional[bytes],
request_body: dict,
stream: bool,
logging_obj: LiteLLMLoggingObj,
provider_config: BaseAnthropicMessagesConfig,
litellm_params: GenericLiteLLMParams,
api_key: Optional[str],
model: str,
) -> httpx.Response:
max_attempts = max(
provider_config.max_retry_on_anthropic_messages_http_error, 1
)
litellm_params_dict = dict(litellm_params)
optional_params_dict = dict(litellm_params)
for attempt_idx in range(max_attempts):
try:
response = await async_httpx_client.post(
url=request_url,
headers=headers,
data=signed_json_body or json.dumps(request_body),
stream=stream or False,
logging_obj=logging_obj,
)
response.raise_for_status()
return response
except httpx.HTTPStatusError as e:
hit_max_attempt = attempt_idx + 1 == max_attempts
should_retry = (
provider_config.should_retry_anthropic_messages_on_http_error(
e=e, litellm_params=litellm_params_dict
)
)
if should_retry and not hit_max_attempt:
verbose_logger.debug(
"Anthropic /v1/messages: invalid thinking signature; "
"stripping thinking blocks and retrying (attempt %s/%s).",
attempt_idx + 2,
max_attempts,
)
provider_config.transform_anthropic_messages_request_on_http_error(
e=e, request_data=request_body
)
headers, signed_json_body = provider_config.sign_request(
headers=headers,
optional_params=optional_params_dict,
request_data=request_body,
api_base=request_url,
api_key=api_key,
stream=stream,
fake_stream=False,
model=model,
)
logging_obj.model_call_details.update(request_body)
continue
raise self._handle_error(e=e, provider_config=provider_config)
except Exception as e:
raise self._handle_error(e=e, provider_config=provider_config)
raise RuntimeError(
"unreachable: anthropic messages HTTP retry loop exited without return"
)
async def async_anthropic_messages_handler(
self,
model: str,
@@ -1955,19 +2022,19 @@ class BaseLLMHTTPHandler:
},
)
try:
response = await async_httpx_client.post(
url=request_url,
headers=headers,
data=signed_json_body or json.dumps(request_body),
stream=stream or False,
logging_obj=logging_obj,
)
response.raise_for_status()
except Exception as e:
raise self._handle_error(
e=e, provider_config=anthropic_messages_provider_config
)
response = await self._async_post_anthropic_messages_with_http_error_retry(
async_httpx_client=async_httpx_client,
request_url=request_url,
headers=headers,
signed_json_body=signed_json_body,
request_body=request_body,
stream=stream or False,
logging_obj=logging_obj,
provider_config=anthropic_messages_provider_config,
litellm_params=litellm_params,
api_key=api_key,
model=model,
)
# used for logging + cost tracking
logging_obj.model_call_details["httpx_response"] = response
@@ -4496,9 +4563,9 @@ class BaseLLMHTTPHandler:
# Second: Execute agentic loop
# Add custom_llm_provider to kwargs so the agentic loop can reconstruct the full model name
kwargs_with_provider = kwargs.copy() if kwargs else {}
kwargs_with_provider[
"custom_llm_provider"
] = custom_llm_provider
kwargs_with_provider["custom_llm_provider"] = (
custom_llm_provider
)
agentic_response = await callback.async_run_agentic_loop(
tools=tool_calls,
model=model,
@@ -4614,9 +4681,9 @@ class BaseLLMHTTPHandler:
# Second: Execute agentic loop
# Add custom_llm_provider to kwargs so the agentic loop can reconstruct the full model name
kwargs_with_provider = kwargs.copy() if kwargs else {}
kwargs_with_provider[
"custom_llm_provider"
] = custom_llm_provider
kwargs_with_provider["custom_llm_provider"] = (
custom_llm_provider
)
agentic_response = (
await callback.async_run_chat_completion_agentic_loop(
tools=tool_calls,
@@ -5110,7 +5177,10 @@ class BaseLLMHTTPHandler:
_is_async: bool = False,
fake_stream: bool = False,
litellm_metadata: Optional[Dict[str, Any]] = None,
) -> Union[ImageResponse, Coroutine[Any, Any, ImageResponse],]:
) -> Union[
ImageResponse,
Coroutine[Any, Any, ImageResponse],
]:
"""
Handles image edit requests.
@@ -5322,7 +5392,10 @@ class BaseLLMHTTPHandler:
fake_stream: bool = False,
litellm_metadata: Optional[Dict[str, Any]] = None,
api_key: Optional[str] = None,
) -> Union[ImageResponse, Coroutine[Any, Any, ImageResponse],]:
) -> Union[
ImageResponse,
Coroutine[Any, Any, ImageResponse],
]:
"""
Handles image generation requests.
When _is_async=True, returns a coroutine instead of making the call directly.
@@ -5562,7 +5635,10 @@ class BaseLLMHTTPHandler:
fake_stream: bool = False,
litellm_metadata: Optional[Dict[str, Any]] = None,
api_key: Optional[str] = None,
) -> Union[VideoObject, Coroutine[Any, Any, VideoObject],]:
) -> Union[
VideoObject,
Coroutine[Any, Any, VideoObject],
]:
"""
Handles video generation requests.
When _is_async=True, returns a coroutine instead of making the call directly.
+4 -2
View File
@@ -16,6 +16,7 @@ from httpx._models import Headers, Response
from pydantic import BaseModel
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.prompt_templates.common_utils import (
_extract_reasoning_content,
convert_content_list_to_str,
@@ -349,7 +350,8 @@ class OllamaChatConfig(BaseConfig):
response_json = raw_response.json()
## RESPONSE OBJECT
model_response.choices[0].finish_reason = "stop"
_done_reason = map_finish_reason(response_json.get("done_reason") or "stop")
model_response.choices[0].finish_reason = _done_reason
response_json_message = response_json.get("message")
if response_json_message is not None:
if "thinking" in response_json_message:
@@ -535,7 +537,7 @@ class OllamaChatCompletionResponseIterator(BaseModelResponseIterator):
)
if chunk["done"] is True:
finish_reason = chunk.get("done_reason", "stop")
finish_reason = chunk.get("done_reason") or "stop"
# Override finish_reason when tool_calls are present
# Fixes: https://github.com/BerriAI/litellm/issues/18922
if tool_calls is not None:
@@ -78,6 +78,19 @@ class VertexAIPartnerModelsTokenCounter(VertexBase):
return endpoint
@staticmethod
def _strip_version_suffix(model: str) -> str:
"""
Strip version suffixes (e.g. @default, @20251001) from model names.
The Vertex AI count-tokens endpoint rejects model names that include
version suffixes for example, "claude-sonnet-4-6@default" returns
"not supported for token counting" while "claude-sonnet-4-6" works.
"""
if "@" in model:
return model.split("@")[0]
return model
async def handle_count_tokens_request(
self,
model: str,
@@ -98,6 +111,15 @@ class VertexAIPartnerModelsTokenCounter(VertexBase):
Raises:
ValueError: If required parameters are missing or invalid
"""
# Strip version suffixes (@default, @20251001, etc.) — the Vertex AI
# count-tokens endpoint does not accept versioned model names.
model = self._strip_version_suffix(model)
if "model" in request_data:
request_data = {
**request_data,
"model": self._strip_version_suffix(request_data["model"]),
}
# Validate request
if "messages" not in request_data:
raise ValueError("messages required for token counting")
+8 -8
View File
@@ -77,6 +77,7 @@ from litellm.litellm_core_utils.audio_utils.utils import (
calculate_request_duration,
get_audio_file_for_health_check,
)
from litellm.litellm_core_utils.completion_timeout import CompletionTimeout
from litellm.litellm_core_utils.dd_tracing import tracer
from litellm.litellm_core_utils.get_provider_specific_headers import (
ProviderSpecificHeaderUtils,
@@ -1401,14 +1402,13 @@ def completion( # type: ignore # noqa: PLR0915
) # support region-based pricing for bedrock
### TIMEOUT LOGIC ###
timeout = timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if isinstance(timeout, httpx.Timeout) and not supports_httpx_timeout(
custom_llm_provider
):
timeout = timeout.read or 600 # default 10 min timeout
elif not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
timeout = CompletionTimeout.resolve(
timeout,
kwargs,
custom_llm_provider,
global_timeout=getattr(litellm, "request_timeout", None),
supports_httpx_timeout=supports_httpx_timeout,
)
### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
if (
@@ -25179,6 +25179,58 @@
"supports_web_search": true,
"tpm": 800000
},
"openrouter/google/gemini-3.1-flash-lite-preview": {
"cache_read_input_token_cost": 2.5e-08,
"cache_read_input_token_cost_per_audio_token": 5e-08,
"input_cost_per_audio_token": 5e-07,
"input_cost_per_token": 2.5e-07,
"litellm_provider": "openrouter",
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_images_per_prompt": 3000,
"max_input_tokens": 1048576,
"max_output_tokens": 65536,
"max_pdf_size_mb": 30,
"max_tokens": 65536,
"max_video_length": 1,
"max_videos_per_prompt": 10,
"mode": "chat",
"output_cost_per_reasoning_token": 1.5e-06,
"output_cost_per_token": 1.5e-06,
"rpm": 2000,
"source": "https://ai.google.dev/pricing/gemini-3",
"supported_endpoints": [
"/v1/chat/completions",
"/v1/completions",
"/v1/batch"
],
"supported_modalities": [
"text",
"image",
"audio",
"video"
],
"supported_output_modalities": [
"text"
],
"supports_audio_input": true,
"supports_audio_output": false,
"supports_code_execution": true,
"supports_file_search": true,
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_pdf_input": true,
"supports_prompt_caching": true,
"supports_reasoning": true,
"supports_response_schema": true,
"supports_system_messages": true,
"supports_tool_choice": true,
"supports_url_context": true,
"supports_video_input": true,
"supports_vision": true,
"supports_web_search": true,
"tpm": 800000
},
"openrouter/google/gemini-3.1-pro-preview": {
"cache_read_input_token_cost": 2e-07,
"cache_read_input_token_cost_above_200k_tokens": 4e-07,
@@ -14,3 +14,8 @@ from typing import Optional
_mcp_active_toolset_id: ContextVar[Optional[str]] = ContextVar(
"_mcp_active_toolset_id", default=None
)
# Per-request merged InitializeResult.instructions; set in MCP HTTP/SSE handlers.
_mcp_gateway_initialize_instructions: ContextVar[Optional[str]] = ContextVar(
"_mcp_gateway_initialize_instructions", default=None
)
@@ -184,6 +184,16 @@ class MCPServerManager:
"gmail_send_email": "zapier_mcp_server",
}
"""
self._upstream_initialize_instructions_by_server_id: Dict[str, str] = {}
def _remember_upstream_initialize_instructions(
self, server: MCPServer, client: MCPClient
) -> None:
raw = getattr(client, "_last_initialize_instructions", None)
if raw and str(raw).strip():
self._upstream_initialize_instructions_by_server_id[server.server_id] = str(
raw
).strip()
def get_registry(self) -> Dict[str, MCPServer]:
"""
@@ -204,6 +214,7 @@ class MCPServerManager:
mcp_aliases: Optional dictionary mapping aliases to server names from litellm_settings
"""
verbose_logger.debug("Loading MCP Servers from config-----")
self._upstream_initialize_instructions_by_server_id.clear()
# Track which aliases have been used to ensure only first occurrence is used
used_aliases = set()
@@ -351,6 +362,7 @@ class MCPServerManager:
aws_service_name=server_config.get("aws_service_name", None),
aws_role_name=server_config.get("aws_role_name", None),
aws_session_name=server_config.get("aws_session_name", None),
instructions=server_config.get("instructions", None),
)
self.config_mcp_servers[server_id] = new_server
@@ -693,6 +705,7 @@ class MCPServerManager:
aws_service_name=aws_creds.get("aws_service_name"),
aws_role_name=aws_creds.get("aws_role_name"),
aws_session_name=aws_creds.get("aws_session_name"),
instructions=mcp_server.instructions,
)
return new_server
@@ -1247,6 +1260,7 @@ class MCPServerManager:
return tools
else:
tools = await self._fetch_tools_with_timeout(client, server.name)
self._remember_upstream_initialize_instructions(server, client)
prefixed_or_original_tools = self._create_prefixed_tools(
tools, server, add_prefix=add_prefix
@@ -2383,6 +2397,7 @@ class MCPServerManager:
# If proxy_logging_obj is not None, the tool call result is at index 1 (after the during hook task)
result_index = 1 if proxy_logging_obj else 0
result = mcp_responses[result_index]
self._remember_upstream_initialize_instructions(mcp_server, client)
return cast(CallToolResult, result)
@@ -2627,6 +2642,7 @@ class MCPServerManager:
)
verbose_logger.debug("Loading MCP servers from database into registry...")
self._upstream_initialize_instructions_by_server_id.clear()
# perform authz check to filter the mcp servers user has access to
prisma_client = get_prisma_client_or_throw(
@@ -2910,6 +2926,7 @@ class MCPServerManager:
await asyncio.wait_for(
client.run_with_session(_noop), timeout=MCP_HEALTH_CHECK_TIMEOUT
)
self._remember_upstream_initialize_instructions(server, client)
status = "healthy"
except asyncio.TimeoutError:
health_check_error = (
@@ -2951,6 +2968,7 @@ class MCPServerManager:
token_url=server.token_url,
registration_url=server.registration_url,
allow_all_keys=server.allow_all_keys,
instructions=server.instructions,
)
async def get_all_mcp_servers_with_health_and_teams(
@@ -3046,6 +3064,7 @@ class MCPServerManager:
is_byok=server.is_byok,
byok_description=server.byok_description,
byok_api_key_help_url=server.byok_api_key_help_url,
instructions=server.instructions,
)
async def get_all_mcp_servers_unfiltered(self) -> List[LiteLLM_MCPServerTable]:
@@ -933,6 +933,7 @@ if MCP_AVAILABLE:
authorization_url=request.authorization_url,
registration_url=request.registration_url,
oauth2_flow=_oauth2_flow,
instructions=request.instructions,
)
stdio_env = global_mcp_server_manager._build_stdio_env(
@@ -7,6 +7,7 @@ LiteLLM MCP Server Routes
import asyncio
import contextlib
import time
import types
import traceback
import uuid
from datetime import datetime
@@ -37,7 +38,10 @@ from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
get_request_base_url,
)
from litellm.proxy._experimental.mcp_server.mcp_context import _mcp_active_toolset_id
from litellm.proxy._experimental.mcp_server.mcp_context import (
_mcp_active_toolset_id,
_mcp_gateway_initialize_instructions,
)
from litellm.proxy._experimental.mcp_server.mcp_debug import MCPDebug
from litellm.proxy._experimental.mcp_server.utils import (
LITELLM_MCP_SERVER_DESCRIPTION,
@@ -122,6 +126,8 @@ _INITIALIZATION_LOCK = asyncio.Lock()
if MCP_AVAILABLE:
from mcp.server import Server
from mcp.server.lowlevel.server import NotificationOptions
from mcp.server.models import InitializationOptions
# Import auth context variables and middleware
from mcp.server.auth.middleware.auth_context import (
@@ -200,6 +206,21 @@ if MCP_AVAILABLE:
)
return normalized
def _gateway_create_initialization_options(
self,
notification_options: Optional[NotificationOptions] = None,
experimental_capabilities: Optional[Dict[str, Dict[str, Any]]] = None,
) -> InitializationOptions:
opts = Server.create_initialization_options(
self,
notification_options=notification_options,
experimental_capabilities=experimental_capabilities or {},
)
merged = _mcp_gateway_initialize_instructions.get()
if merged is not None:
return opts.model_copy(update={"instructions": merged})
return opts
########################################################
############ Initialize the MCP Server #################
########################################################
@@ -207,6 +228,9 @@ if MCP_AVAILABLE:
name=LITELLM_MCP_SERVER_NAME,
version=LITELLM_MCP_SERVER_VERSION,
)
server.create_initialization_options = types.MethodType( # type: ignore[method-assign]
_gateway_create_initialization_options, server
)
sse: SseServerTransport = SseServerTransport("/mcp/sse/messages")
# Create session managers
@@ -1021,7 +1045,9 @@ if MCP_AVAILABLE:
except (ValueError, TypeError):
pass
ttl = _compute_per_user_token_ttl(server, raw_expires)
await mcp_per_user_token_cache.set(user_id, server_id, access_token, ttl)
await mcp_per_user_token_cache.set(
user_id, server_id, access_token, ttl
)
return {"Authorization": f"Bearer {access_token}"}
except Exception as e:
@@ -1103,6 +1129,57 @@ if MCP_AVAILABLE:
return server_auth_header, extra_headers
def _merge_gateway_initialize_instructions(
allowed_mcp_servers: List[MCPServer],
) -> Optional[str]:
"""YAML/DB override, else in-memory upstream text from list_tools / health_check / call_tool."""
if not allowed_mcp_servers:
return None
texts: List[Tuple[str, str]] = []
for server in allowed_mcp_servers:
label = (
server.alias
or server.server_name
or server.name
or server.server_id
or "mcp"
)
if server.instructions and server.instructions.strip():
texts.append((label, server.instructions.strip()))
continue
if server.spec_path:
continue
cached = global_mcp_server_manager._upstream_initialize_instructions_by_server_id.get(
server.server_id
)
if cached and cached.strip():
texts.append((label, cached.strip()))
if not texts:
return None
if len(texts) == 1:
return texts[0][1]
return "\n\n---\n\n".join(f"[{lbl}]\n{txt}" for lbl, txt in texts)
@contextlib.asynccontextmanager
async def _gateway_initialize_instructions_request_scope(
user_api_key_auth: Optional[UserAPIKeyAuth],
mcp_servers: Optional[List[str]],
client_ip: Optional[str],
) -> AsyncIterator[None]:
allowed = await _get_allowed_mcp_servers(
user_api_key_auth=user_api_key_auth,
mcp_servers=mcp_servers,
client_ip=client_ip,
)
merged = _merge_gateway_initialize_instructions(allowed_mcp_servers=allowed)
tok = _mcp_gateway_initialize_instructions.set(merged)
try:
yield
finally:
_mcp_gateway_initialize_instructions.reset(tok)
async def _get_tools_from_mcp_servers( # noqa: PLR0915
user_api_key_auth: Optional[UserAPIKeyAuth],
mcp_auth_header: Optional[str],
@@ -2670,7 +2747,12 @@ if MCP_AVAILABLE:
# Request was fully handled (e.g., DELETE on non-existent session)
return
await session_manager.handle_request(scope, receive, send)
async with _gateway_initialize_instructions_request_scope(
user_api_key_auth,
mcp_servers,
_client_ip,
):
await session_manager.handle_request(scope, receive, send)
except HTTPException:
# Re-raise HTTP exceptions to preserve status codes and details
raise
@@ -2729,7 +2811,12 @@ if MCP_AVAILABLE:
await initialize_session_managers()
await asyncio.sleep(0.1)
await sse_session_manager.handle_request(scope, receive, send)
async with _gateway_initialize_instructions_request_scope(
user_api_key_auth,
mcp_servers,
_sse_client_ip,
):
await sse_session_manager.handle_request(scope, receive, send)
except Exception as e:
verbose_logger.exception(f"Error handling MCP request: {e}")
# Instead of re-raising, try to send a graceful error response
+51 -48
View File
@@ -904,9 +904,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase):
allowed_cache_controls: Optional[list] = []
config: Optional[dict] = {}
permissions: Optional[dict] = {}
model_max_budget: Optional[
dict
] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
model_max_budget: Optional[dict] = (
{}
) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
model_config = ConfigDict(protected_namespaces=())
model_rpm_limit: Optional[dict] = None
@@ -1048,9 +1048,9 @@ class RegenerateKeyRequest(GenerateKeyRequest):
spend: Optional[float] = None
metadata: Optional[dict] = None
new_master_key: Optional[str] = None
grace_period: Optional[
str
] = None # Duration to keep old key valid (e.g. "24h", "2d"); None = immediate revoke
grace_period: Optional[str] = (
None # Duration to keep old key valid (e.g. "24h", "2d"); None = immediate revoke
)
class ResetSpendRequest(LiteLLMPydanticObjectBase):
@@ -1137,6 +1137,7 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase):
tool_name_to_description: Optional[Dict[str, str]] = None
extra_headers: Optional[List[str]] = None
static_headers: Optional[Dict[str, str]] = None
instructions: Optional[str] = None
# Stdio-specific fields
command: Optional[str] = None
args: List[str] = Field(default_factory=list)
@@ -1219,6 +1220,7 @@ class UpdateMCPServerRequest(LiteLLMPydanticObjectBase):
tool_name_to_description: Optional[Dict[str, str]] = None
extra_headers: Optional[List[str]] = None
static_headers: Optional[Dict[str, str]] = None
instructions: Optional[str] = None
# Stdio-specific fields
command: Optional[str] = None
args: List[str] = Field(default_factory=list)
@@ -1270,6 +1272,7 @@ class LiteLLM_MCPServerTable(LiteLLMPydanticObjectBase):
transport: MCPTransportType
auth_type: Optional[MCPAuthType] = None
credentials: Optional[MCPCredentials] = None
instructions: Optional[str] = None
created_at: Optional[datetime] = None
created_by: Optional[str] = None
updated_at: Optional[datetime] = None
@@ -1574,12 +1577,12 @@ class NewCustomerRequest(BudgetNewRequest):
blocked: bool = False # allow/disallow requests for this end-user
budget_id: Optional[str] = None # give either a budget_id or max_budget
spend: Optional[float] = None
allowed_model_region: Optional[
AllowedModelRegion
] = None # require all user requests to use models in this specific region
default_model: Optional[
str
] = None # if no equivalent model in allowed region - default all requests to this model
allowed_model_region: Optional[AllowedModelRegion] = (
None # require all user requests to use models in this specific region
)
default_model: Optional[str] = (
None # if no equivalent model in allowed region - default all requests to this model
)
object_permission: Optional[LiteLLM_ObjectPermissionBase] = None
@model_validator(mode="before")
@@ -1602,12 +1605,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase):
blocked: bool = False # allow/disallow requests for this end-user
max_budget: Optional[float] = None
budget_id: Optional[str] = None # give either a budget_id or max_budget
allowed_model_region: Optional[
AllowedModelRegion
] = None # require all user requests to use models in this specific region
default_model: Optional[
str
] = None # if no equivalent model in allowed region - default all requests to this model
allowed_model_region: Optional[AllowedModelRegion] = (
None # require all user requests to use models in this specific region
)
default_model: Optional[str] = (
None # if no equivalent model in allowed region - default all requests to this model
)
object_permission: Optional[LiteLLM_ObjectPermissionBase] = None
@@ -1697,15 +1700,15 @@ class NewTeamRequest(TeamBase):
] = None # raise an error if 'guaranteed_throughput' is set and we're overallocating tpm
model_tpm_limit: Optional[Dict[str, int]] = None
team_member_budget: Optional[
float
] = None # allow user to set a budget for all team members
team_member_rpm_limit: Optional[
int
] = None # allow user to set RPM limit for all team members
team_member_tpm_limit: Optional[
int
] = None # allow user to set TPM limit for all team members
team_member_budget: Optional[float] = (
None # allow user to set a budget for all team members
)
team_member_rpm_limit: Optional[int] = (
None # allow user to set RPM limit for all team members
)
team_member_tpm_limit: Optional[int] = (
None # allow user to set TPM limit for all team members
)
team_member_key_duration: Optional[str] = None # e.g. "1d", "1w", "1m"
team_member_budget_duration: Optional[str] = None # e.g. "30d", "1mo"
allowed_vector_store_indexes: Optional[List[AllowedVectorStoreIndexItem]] = None
@@ -1802,9 +1805,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase):
class AddTeamCallback(LiteLLMPydanticObjectBase):
callback_name: str
callback_type: Optional[
Literal["success", "failure", "success_and_failure"]
] = "success_and_failure"
callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = (
"success_and_failure"
)
callback_vars: Dict[str, str]
@model_validator(mode="before")
@@ -2146,9 +2149,9 @@ class ConfigList(LiteLLMPydanticObjectBase):
stored_in_db: Optional[bool]
field_default_value: Any
premium_field: bool = False
nested_fields: Optional[
List[FieldDetail]
] = None # For nested dictionary or Pydantic fields
nested_fields: Optional[List[FieldDetail]] = (
None # For nested dictionary or Pydantic fields
)
class UserHeaderMapping(LiteLLMPydanticObjectBase):
@@ -2507,9 +2510,9 @@ class UserAPIKeyAuth(
user_max_budget: Optional[float] = None
request_route: Optional[str] = None
user: Optional[Any] = None # Expanded user object when expand=user is used
created_by_user: Optional[
Any
] = None # Expanded created_by user when expand=user is used
created_by_user: Optional[Any] = (
None # Expanded created_by user when expand=user is used
)
end_user_object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
# Decoded upstream IdP claims (groups, roles, etc.) propagated by JWT auth machinery
# and forwarded into outbound tokens by guardrails such as MCPJWTSigner.
@@ -2648,9 +2651,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase):
budget_id: Optional[str] = None
created_at: datetime
updated_at: datetime
user: Optional[
Any
] = None # You might want to replace 'Any' with a more specific type if available
user: Optional[Any] = (
None # You might want to replace 'Any' with a more specific type if available
)
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
user_email: Optional[str] = None
@@ -3805,9 +3808,9 @@ class TeamModelDeleteRequest(BaseModel):
# Organization Member Requests
class OrganizationMemberAddRequest(OrgMemberAddRequest):
organization_id: str
max_budget_in_organization: Optional[
float
] = None # Users max budget within the organization
max_budget_in_organization: Optional[float] = (
None # Users max budget within the organization
)
class OrganizationMemberDeleteRequest(MemberDeleteRequest):
@@ -4062,9 +4065,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase):
Maps provider names to their budget configs.
"""
providers: Dict[
str, ProviderBudgetResponseObject
] = {} # Dictionary mapping provider names to their budget configurations
providers: Dict[str, ProviderBudgetResponseObject] = (
{}
) # Dictionary mapping provider names to their budget configurations
class ProxyStateVariables(TypedDict):
@@ -4226,9 +4229,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
enforce_rbac: bool = False
roles_jwt_field: Optional[str] = None # v2 on role mappings
role_mappings: Optional[List[RoleMapping]] = None
object_id_jwt_field: Optional[
str
] = None # can be either user / team, inferred from the role mapping
object_id_jwt_field: Optional[str] = (
None # can be either user / team, inferred from the role mapping
)
scope_mappings: Optional[List[ScopeMapping]] = None
enforce_scope_based_access: bool = False
enforce_team_based_model_access: bool = False
@@ -340,6 +340,35 @@ async def update_credential(
"updated_by": user_api_key_dict.user_id,
},
)
# Sync in-memory credential_list (skip if not in memory - e.g., proxy restarted)
new_name = merged_credential.credential_name
existing_in_memory: Optional[CredentialItem] = None
for cred in litellm.credential_list:
if cred.credential_name == credential_name:
existing_in_memory = cred
break
if existing_in_memory is not None:
in_memory_values = dict(existing_in_memory.credential_values or {})
if credential.credential_values:
in_memory_values.update(credential.credential_values)
in_memory_info = dict(existing_in_memory.credential_info or {})
if credential.credential_info:
in_memory_info.update(credential.credential_info)
updated_in_memory = CredentialItem(
credential_name=new_name,
credential_values=in_memory_values,
credential_info=in_memory_info,
)
# Remove old entry if renamed, then use upsert_credentials to handle duplicates
if new_name != credential_name:
litellm.credential_list = [
c for c in litellm.credential_list
if c.credential_name != credential_name
]
CredentialAccessor.upsert_credentials([updated_in_memory])
return {"success": True, "message": "Credential updated successfully"}
except Exception as e:
return handle_exception_on_proxy(e)
+57 -28
View File
@@ -4,6 +4,12 @@ from litellm import verbose_logger
_db = Any
# Markers that indicate a view/relation does not yet exist in the database.
# Keeping these in one place avoids repeating the check across all view blocks
# and prevents overly broad matches (e.g. bare 'undefined' would also match
# 'undefined function' or 'column undefined_col referenced in query').
_VIEW_NOT_FOUND_MARKERS = ("does not exist", "no such table", "undefined table")
async def create_missing_views(db: _db): # noqa: PLR0915
"""
@@ -18,14 +24,17 @@ async def create_missing_views(db: _db): # noqa: PLR0915
If the view doesn't exist, one will be created.
"""
try:
# Try to select one row from the view
await db.query_raw("""SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1""")
print("LiteLLM_VerificationTokenView Exists!") # noqa
except Exception:
verbose_logger.debug("LiteLLM_VerificationTokenView Exists!")
except Exception as e:
error_msg = str(e).lower()
if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS):
raise
# If an error occurs, the view does not exist, so create it
await db.execute_raw(
"""
await db.execute_raw("""
CREATE VIEW "LiteLLM_VerificationTokenView" AS
SELECT
v.*,
@@ -37,15 +46,17 @@ async def create_missing_views(db: _db): # noqa: PLR0915
FROM "LiteLLM_VerificationToken" v
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id
LEFT JOIN "LiteLLM_ProjectTable" p ON v.project_id = p.project_id;
"""
)
""")
print("LiteLLM_VerificationTokenView Created!") # noqa
verbose_logger.debug("LiteLLM_VerificationTokenView Created!")
try:
await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""")
print("MonthlyGlobalSpend Exists!") # noqa
except Exception:
verbose_logger.debug("MonthlyGlobalSpend Exists!")
except Exception as e:
error_msg = str(e).lower()
if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS):
raise
sql_query = """
CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
SELECT
@@ -60,12 +71,15 @@ async def create_missing_views(db: _db): # noqa: PLR0915
"""
await db.execute_raw(query=sql_query)
print("MonthlyGlobalSpend Created!") # noqa
verbose_logger.debug("MonthlyGlobalSpend Created!")
try:
await db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""")
print("Last30dKeysBySpend Exists!") # noqa
except Exception:
verbose_logger.debug("Last30dKeysBySpend Exists!")
except Exception as e:
error_msg = str(e).lower()
if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS):
raise
sql_query = """
CREATE OR REPLACE VIEW "Last30dKeysBySpend" AS
SELECT
@@ -88,12 +102,15 @@ async def create_missing_views(db: _db): # noqa: PLR0915
"""
await db.execute_raw(query=sql_query)
print("Last30dKeysBySpend Created!") # noqa
verbose_logger.debug("Last30dKeysBySpend Created!")
try:
await db.query_raw("""SELECT 1 FROM "Last30dModelsBySpend" LIMIT 1""")
print("Last30dModelsBySpend Exists!") # noqa
except Exception:
verbose_logger.debug("Last30dModelsBySpend Exists!")
except Exception as e:
error_msg = str(e).lower()
if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS):
raise
sql_query = """
CREATE OR REPLACE VIEW "Last30dModelsBySpend" AS
SELECT
@@ -111,11 +128,14 @@ async def create_missing_views(db: _db): # noqa: PLR0915
"""
await db.execute_raw(query=sql_query)
print("Last30dModelsBySpend Created!") # noqa
verbose_logger.debug("Last30dModelsBySpend Created!")
try:
await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpendPerKey" LIMIT 1""")
print("MonthlyGlobalSpendPerKey Exists!") # noqa
except Exception:
verbose_logger.debug("MonthlyGlobalSpendPerKey Exists!")
except Exception as e:
error_msg = str(e).lower()
if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS):
raise
sql_query = """
CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerKey" AS
SELECT
@@ -132,13 +152,16 @@ async def create_missing_views(db: _db): # noqa: PLR0915
"""
await db.execute_raw(query=sql_query)
print("MonthlyGlobalSpendPerKey Created!") # noqa
verbose_logger.debug("MonthlyGlobalSpendPerKey Created!")
try:
await db.query_raw(
"""SELECT 1 FROM "MonthlyGlobalSpendPerUserPerKey" LIMIT 1"""
)
print("MonthlyGlobalSpendPerUserPerKey Exists!") # noqa
except Exception:
verbose_logger.debug("MonthlyGlobalSpendPerUserPerKey Exists!")
except Exception as e:
error_msg = str(e).lower()
if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS):
raise
sql_query = """
CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerUserPerKey" AS
SELECT
@@ -157,12 +180,15 @@ async def create_missing_views(db: _db): # noqa: PLR0915
"""
await db.execute_raw(query=sql_query)
print("MonthlyGlobalSpendPerUserPerKey Created!") # noqa
verbose_logger.debug("MonthlyGlobalSpendPerUserPerKey Created!")
try:
await db.query_raw("""SELECT 1 FROM "DailyTagSpend" LIMIT 1""")
print("DailyTagSpend Exists!") # noqa
except Exception:
verbose_logger.debug("DailyTagSpend Exists!")
except Exception as e:
error_msg = str(e).lower()
if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS):
raise
sql_query = """
CREATE OR REPLACE VIEW "DailyTagSpend" AS
SELECT
@@ -175,12 +201,15 @@ async def create_missing_views(db: _db): # noqa: PLR0915
"""
await db.execute_raw(query=sql_query)
print("DailyTagSpend Created!") # noqa
verbose_logger.debug("DailyTagSpend Created!")
try:
await db.query_raw("""SELECT 1 FROM "Last30dTopEndUsersSpend" LIMIT 1""")
print("Last30dTopEndUsersSpend Exists!") # noqa
except Exception:
verbose_logger.debug("Last30dTopEndUsersSpend Exists!")
except Exception as e:
error_msg = str(e).lower()
if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS):
raise
sql_query = """
CREATE VIEW "Last30dTopEndUsersSpend" AS
SELECT end_user, COUNT(*) AS total_events, SUM(spend) AS total_spend
@@ -193,7 +222,7 @@ async def create_missing_views(db: _db): # noqa: PLR0915
"""
await db.execute_raw(query=sql_query)
print("Last30dTopEndUsersSpend Created!") # noqa
verbose_logger.debug("Last30dTopEndUsersSpend Created!")
return
@@ -21,9 +21,18 @@ class PodLockManager:
Ensures that only one pod can run a cron job at a time.
"""
_COMPARE_AND_DELETE_LOCK_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("del", KEYS[1])
else
return 0
end
"""
def __init__(self, redis_cache: Optional[RedisCache] = None):
self.pod_id = str(uuid.uuid4())
self.redis_cache = redis_cache
self._release_lock_script: Optional[Any] = None
@staticmethod
def get_redis_lock_key(cronjob_id: str) -> str:
@@ -107,53 +116,35 @@ class PodLockManager:
):
"""
Release the lock if the current pod holds it.
Uses get and delete commands to ensure that only the owner can release the lock.
Uses an atomic Lua compare-and-delete to prevent TOCTOU races where a
stale owner could delete a newly reacquired lock.
Falls back to GET + DEL for cache implementations that don't support
script registration.
"""
if self.redis_cache is None:
verbose_proxy_logger.debug("redis_cache is None, skipping release_lock")
return
try:
cronjob_id = cronjob_id
verbose_proxy_logger.debug(
"Pod %s attempting to release Redis lock for cronjob_id=%s",
self.pod_id,
cronjob_id,
)
lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
current_value = await self.redis_cache.async_get_cache(lock_key)
if current_value is not None:
if isinstance(current_value, bytes):
current_value = current_value.decode("utf-8")
if current_value == self.pod_id:
result = await self.redis_cache.async_delete_cache(lock_key)
if result == 1:
verbose_proxy_logger.info(
"Pod %s successfully released Redis lock for cronjob_id=%s",
self.pod_id,
cronjob_id,
)
self._emit_released_lock_event(
cronjob_id=cronjob_id,
pod_id=self.pod_id,
)
else:
verbose_proxy_logger.warning(
"Pod %s failed to release Redis lock for cronjob_id=%s. "
"Lock will expire after its TTL.",
self.pod_id,
cronjob_id,
)
else:
verbose_proxy_logger.debug(
"Pod %s cannot release Redis lock for cronjob_id=%s because it is held by pod %s",
self.pod_id,
cronjob_id,
current_value,
)
result = await self._compare_and_delete_lock(lock_key=lock_key)
if result == 1:
verbose_proxy_logger.info(
"Pod %s successfully released Redis lock for cronjob_id=%s",
self.pod_id,
cronjob_id,
)
self._emit_released_lock_event(
cronjob_id=cronjob_id,
pod_id=self.pod_id,
)
else:
verbose_proxy_logger.debug(
"Pod %s attempted to release Redis lock for cronjob_id=%s, but no lock was found",
"Pod %s failed to release Redis lock for cronjob_id=%s (lock missing or held by another pod)",
self.pod_id,
cronjob_id,
)
@@ -162,6 +153,42 @@ class PodLockManager:
f"Error releasing Redis lock for {cronjob_id}: {e}"
)
async def _compare_and_delete_lock(self, lock_key: str) -> int:
"""
Atomically delete lock key only if current pod owns it.
Falls back to get/delete for non-RedisCache implementations that do not
expose Lua script registration.
"""
script_register = getattr(self.redis_cache, "async_register_script", None)
if callable(script_register):
try:
if self._release_lock_script is None:
self._release_lock_script = script_register(
self._COMPARE_AND_DELETE_LOCK_SCRIPT
)
result = await self._release_lock_script(
keys=[lock_key], args=[self.pod_id]
)
return int(result or 0)
except Exception:
# Lua execution failed (e.g. Redis restart cleared loaded scripts,
# or scripting is disabled). Reset cached script handle and fall
# through to the GET + DEL fallback so the lock is still released.
self._release_lock_script = None
verbose_proxy_logger.warning(
"Lua compare-and-delete failed for lock_key=%s, falling back to GET+DEL",
lock_key,
)
current_value = await self.redis_cache.async_get_cache(lock_key) # type: ignore
if isinstance(current_value, bytes):
current_value = current_value.decode("utf-8")
if current_value != self.pod_id:
return 0
result = await self.redis_cache.async_delete_cache(lock_key) # type: ignore
return int(result or 0)
@staticmethod
def _emit_acquired_lock_event(cronjob_id: str, pod_id: str):
asyncio.create_task(
@@ -82,7 +82,7 @@ class SpendLogCleanup:
break
# Step 1: Find logs and delete them in one go without fetching to application
# Delete in batches, limited by self.batch_size
deleted_count = await prisma_client.db.execute_raw(
deleted_result = await prisma_client.db.execute_raw(
"""
DELETE FROM "LiteLLM_SpendLogs"
WHERE "request_id" IN (
@@ -94,6 +94,17 @@ class SpendLogCleanup:
cutoff_date,
self.batch_size,
)
deleted_count = 0
if isinstance(deleted_result, int):
deleted_count = deleted_result
else:
verbose_proxy_logger.error(
f"Unexpected execute_raw return type for spend log cleanup: {type(deleted_result)}; "
"aborting cleanup to avoid infinite loop"
)
break
verbose_proxy_logger.info(f"Deleted {deleted_count} logs in this batch")
if deleted_count == 0:
@@ -83,7 +83,12 @@ def _redact_pii_matches(response_json: dict) -> dict:
redacted_response = copy.deepcopy(response_json)
# Get assessments from the response
assessments = redacted_response.get("assessments", [])
# NOTE: We use `.get("key") or []` instead of `.get("key", [])` because
# the Bedrock API can return explicit `null` for list fields (e.g. "regexes": null).
# In Python, dict.get("key", []) returns None (not []) when the key exists
# with a None/null value. The `or []` ensures we always get an iterable,
# preventing "TypeError: 'NoneType' object is not iterable".
assessments = redacted_response.get("assessments") or []
if not assessments:
return redacted_response
@@ -91,13 +96,13 @@ def _redact_pii_matches(response_json: dict) -> dict:
# Redact PII entities in sensitive information policy
sensitive_info_policy = assessment.get("sensitiveInformationPolicy")
if sensitive_info_policy:
pii_entities = sensitive_info_policy.get("piiEntities", [])
pii_entities = sensitive_info_policy.get("piiEntities") or []
for pii_entity in pii_entities:
if "match" in pii_entity:
pii_entity["match"] = "[REDACTED]"
# Redact regex matches
regexes = sensitive_info_policy.get("regexes", [])
regexes = sensitive_info_policy.get("regexes") or []
for regex_match in regexes:
if "match" in regex_match:
regex_match["match"] = "[REDACTED]"
@@ -105,12 +110,12 @@ def _redact_pii_matches(response_json: dict) -> dict:
# Redact custom word matches in word policy
word_policy = assessment.get("wordPolicy")
if word_policy:
custom_words = word_policy.get("customWords", [])
custom_words = word_policy.get("customWords") or []
for custom_word in custom_words:
if "match" in custom_word:
custom_word["match"] = "[REDACTED]"
managed_words = word_policy.get("managedWordLists", [])
managed_words = word_policy.get("managedWordLists") or []
for managed_word in managed_words:
if "match" in managed_word:
managed_word["match"] = "[REDACTED]"
@@ -825,7 +830,9 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
return False
# Check assessments to determine if any actions were BLOCKED (vs ANONYMIZED)
assessments = response.get("assessments", [])
# NOTE: Use `or []` instead of default param to handle explicit null from Bedrock API.
# See _redact_pii_matches() for detailed explanation of the null safety pattern.
assessments = response.get("assessments") or []
if not assessments:
return False
@@ -833,7 +840,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
# Check topic policy
topic_policy = assessment.get("topicPolicy")
if topic_policy:
topics = topic_policy.get("topics", [])
topics = topic_policy.get("topics") or []
for topic in topics:
if topic.get("action") == "BLOCKED":
return True
@@ -841,7 +848,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
# Check content policy
content_policy = assessment.get("contentPolicy")
if content_policy:
filters = content_policy.get("filters", [])
filters = content_policy.get("filters") or []
for filter_item in filters:
if filter_item.get("action") == "BLOCKED":
return True
@@ -849,11 +856,11 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
# Check word policy
word_policy = assessment.get("wordPolicy")
if word_policy:
custom_words = word_policy.get("customWords", [])
custom_words = word_policy.get("customWords") or []
for custom_word in custom_words:
if custom_word.get("action") == "BLOCKED":
return True
managed_words = word_policy.get("managedWordLists", [])
managed_words = word_policy.get("managedWordLists") or []
for managed_word in managed_words:
if managed_word.get("action") == "BLOCKED":
return True
@@ -861,12 +868,12 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
# Check sensitive information policy
sensitive_info_policy = assessment.get("sensitiveInformationPolicy")
if sensitive_info_policy:
pii_entities = sensitive_info_policy.get("piiEntities", [])
pii_entities = sensitive_info_policy.get("piiEntities") or []
if pii_entities:
for pii_entity in pii_entities:
if pii_entity.get("action") == "BLOCKED":
return True
regexes = sensitive_info_policy.get("regexes", [])
regexes = sensitive_info_policy.get("regexes") or []
if regexes:
for regex in regexes:
if regex.get("action") == "BLOCKED":
@@ -875,7 +882,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
# Check contextual grounding policy
contextual_grounding_policy = assessment.get("contextualGroundingPolicy")
if contextual_grounding_policy:
grounding_filters = contextual_grounding_policy.get("filters", [])
grounding_filters = contextual_grounding_policy.get("filters") or []
for grounding_filter in grounding_filters:
if grounding_filter.get("action") == "BLOCKED":
return True
@@ -1534,7 +1541,9 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
Raises:
Exception: If content is blocked by Bedrock guardrail
"""
texts = inputs.get("texts", [])
# NOTE: Use `or []` to handle case where inputs["texts"] is explicitly None.
# dict.get("texts", []) would return None if the key exists with a None value.
texts = inputs.get("texts") or []
try:
verbose_proxy_logger.debug(
f"Bedrock Guardrail: Applying guardrail to {len(texts)} text(s)"
@@ -274,6 +274,15 @@ class NomaV2Guardrail(CustomGuardrail):
if application_id is None:
application_id = self._get_non_empty_str(self.application_id)
# Fall back to API key alias for per-key traceability in Noma dashboard
# (ports v1 fallback from PR #16832).
if application_id is None:
application_id = self._get_non_empty_str(
request_data.get("litellm_metadata", {}).get("user_api_key_alias")
) or self._get_non_empty_str(
request_data.get("metadata", {}).get("user_api_key_alias")
)
try:
payload = self._build_scan_payload(
inputs=inputs,
@@ -52,6 +52,20 @@ def _get_a2a_request_id(
endpoint_guardrail_translation_mappings = None
def _ensure_litellm_metadata(data: dict, user_api_key_dict: UserAPIKeyAuth) -> None:
"""Populate data['litellm_metadata'] from user_api_key_dict if absent."""
if "litellm_metadata" not in data:
from litellm.llms.base_llm.guardrail_translation.base_translation import (
BaseTranslation,
)
user_metadata = BaseTranslation.transform_user_api_key_dict_to_metadata(
user_api_key_dict
)
if user_metadata:
data["litellm_metadata"] = user_metadata
class UnifiedLLMGuardrails(CustomLogger):
def __init__(
self,
@@ -120,6 +134,8 @@ class UnifiedLLMGuardrails(CustomLogger):
CallTypes(call_type)
]()
_ensure_litellm_metadata(data, user_api_key_dict)
data = await endpoint_translation.process_input_messages(
data=data,
guardrail_to_apply=guardrail_to_apply,
@@ -177,6 +193,8 @@ class UnifiedLLMGuardrails(CustomLogger):
CallTypes(call_type)
]()
_ensure_litellm_metadata(data, user_api_key_dict)
return await endpoint_translation.process_input_messages(
data=data,
guardrail_to_apply=guardrail_to_apply,
+10 -2
View File
@@ -11,7 +11,11 @@ from typing import List, Optional
import litellm
logger = logging.getLogger(__name__)
from litellm.constants import DEFAULT_HEALTH_CHECK_PROMPT, HEALTH_CHECK_TIMEOUT_SECONDS
from litellm.constants import (
BACKGROUND_HEALTH_CHECK_MAX_TOKENS,
DEFAULT_HEALTH_CHECK_PROMPT,
HEALTH_CHECK_TIMEOUT_SECONDS,
)
ILLEGAL_DISPLAY_PARAMS = [
"messages",
@@ -242,7 +246,9 @@ async def _perform_health_check(
cleaned["model_id"] = _model_id
if isinstance(is_healthy, Exception):
exceptions_by_model_id[_model_id] = is_healthy
cleaned["exception_status"] = getattr(is_healthy, "status_code", 500)
cleaned["exception_status"] = getattr(
is_healthy, "status_code", 500
)
unhealthy_endpoints.append(cleaned)
return healthy_endpoints, unhealthy_endpoints, exceptions_by_model_id
@@ -301,6 +307,8 @@ def _update_litellm_params_for_health_check(
_health_check_max_tokens = model_info.get("health_check_max_tokens", None)
if _health_check_max_tokens is not None:
litellm_params["max_tokens"] = _health_check_max_tokens
elif BACKGROUND_HEALTH_CHECK_MAX_TOKENS is not None:
litellm_params["max_tokens"] = BACKGROUND_HEALTH_CHECK_MAX_TOKENS
elif "*" not in (
model_info.get("health_check_model") or litellm_params.get("model") or ""
):
@@ -84,7 +84,7 @@ class SharedHealthCheckManager:
"Pod %s failed to acquire health check lock", self.pod_id
)
return acquired
return bool(acquired)
except Exception as e:
verbose_proxy_logger.error("Error acquiring health check lock: %s", str(e))
return False
@@ -98,11 +98,20 @@ class TeamMemberPermissionChecks:
)
# 5. Check if the team member has permissions for the endpoint
TeamMemberPermissionChecks.does_team_member_have_permissions_for_endpoint(
team_member_object=key_assigned_user_in_team,
team_table=team_table,
route=route,
has_permission = (
TeamMemberPermissionChecks.does_team_member_have_permissions_for_endpoint(
team_member_object=key_assigned_user_in_team,
team_table=team_table,
route=route,
)
)
if not has_permission:
raise ProxyException(
message=f"User {user_api_key_dict.user_id} does not belong to team {team_table.team_id}. Team-scoped key management endpoints can only be used for keys in your own team.",
type=ProxyErrorTypes.team_member_permission_error,
param=route,
code=401,
)
@staticmethod
def does_team_member_have_permissions_for_endpoint(
+74 -27
View File
@@ -54,7 +54,7 @@ from litellm.constants import (
LITELLM_SETTINGS_SAFE_DB_OVERRIDES,
LITELLM_UI_ALLOW_HEADERS,
LITELLM_UI_SESSION_DURATION,
DAILY_TAG_SPEND_BATCH_MULTIPLIER
DAILY_TAG_SPEND_BATCH_MULTIPLIER,
)
from litellm.litellm_core_utils.litellm_logging import (
_init_custom_logger_compatible_class,
@@ -1139,7 +1139,54 @@ async def openai_exception_handler(request: Request, exc: ProxyException):
router = APIRouter()
origins = ["*"]
def _get_cors_config(
cors_origins_env: Optional[str] = None,
cors_credentials_env: Optional[str] = None,
):
"""
Compute CORS allowed origins and credentials flag from environment variables.
Extracted into a function so it can be unit-tested without reloading the module.
Args:
cors_origins_env: Value of LITELLM_CORS_ORIGINS (defaults to os.getenv).
cors_credentials_env: Value of LITELLM_CORS_ALLOW_CREDENTIALS (defaults to os.getenv).
Returns:
Tuple[List[str], bool]: (origins, allow_credentials)
"""
_origins_raw = (
cors_origins_env
if cors_origins_env is not None
else os.getenv("LITELLM_CORS_ORIGINS")
)
if _origins_raw is None or _origins_raw.strip() == "":
computed_origins = ["*"]
else:
computed_origins = [o.strip() for o in _origins_raw.split(",") if o.strip()]
# Disable credentials by default when wildcard origins are used — combining
# allow_origins=["*"] with allow_credentials=True causes Starlette to reflect
# the incoming Origin header, allowing any site to make credentialed requests.
# Set LITELLM_CORS_ALLOW_CREDENTIALS=true to explicitly restore the old behaviour
# (e.g. for non-browser clients that relied on the Access-Control-Allow-Credentials
# header being present regardless of origin).
_credentials_raw = (
cors_credentials_env
if cors_credentials_env is not None
else os.getenv("LITELLM_CORS_ALLOW_CREDENTIALS")
)
if _credentials_raw is not None:
computed_credentials = _credentials_raw.strip().lower() == "true"
else:
computed_credentials = "*" not in computed_origins
return computed_origins, computed_credentials
origins, allow_cors_credentials = _get_cors_config()
# get current directory
@@ -1466,7 +1513,7 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_credentials=allow_cors_credentials,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=LITELLM_UI_ALLOW_HEADERS,
@@ -1870,26 +1917,20 @@ async def update_cache( # noqa: PLR0915
## CHECK IF USER PROJECTED SPEND > SOFT LIMIT
if (
existing_spend_obj.soft_budget_cooldown is False
and existing_spend_obj.litellm_budget_table is not None
and existing_spend_obj.soft_budget is not None
and (
_is_projected_spend_over_limit(
current_spend=new_spend,
soft_budget_limit=existing_spend_obj.litellm_budget_table[
"soft_budget"
],
soft_budget_limit=existing_spend_obj.soft_budget,
)
is True
)
):
projected_spend, projected_exceeded_date = _get_projected_spend_over_limit(
current_spend=new_spend,
soft_budget_limit=existing_spend_obj.litellm_budget_table.get(
"soft_budget", None
),
soft_budget_limit=existing_spend_obj.soft_budget,
) # type: ignore
soft_limit = existing_spend_obj.litellm_budget_table.get(
"soft_budget", float("inf")
)
soft_limit = existing_spend_obj.soft_budget
call_info = CallInfo(
token=existing_spend_obj.token or "",
spend=new_spend,
@@ -1897,7 +1938,7 @@ async def update_cache( # noqa: PLR0915
max_budget=soft_limit,
user_id=existing_spend_obj.user_id,
projected_spend=projected_spend,
projected_exceeded_date=projected_exceeded_date,
projected_exceeded_date=str(projected_exceeded_date),
event_group=Litellm_EntityType.KEY,
)
# alert user
@@ -2315,9 +2356,13 @@ def _write_health_state_to_router_cache(
exception_status = getattr(original_exception, "status_code", 500)
if llm_router.health_check_ignore_transient_errors and exception_status in (
429,
408,
if (
llm_router.health_check_ignore_transient_errors
and exception_status
in (
429,
408,
)
):
continue
@@ -6286,7 +6331,9 @@ class ProxyStartupEvent:
### UPDATE DAILY TAG SPEND (separate scheduler job with longer interval) ###
## Reduces QPS as there are more tags for a single request
tag_spend_update_interval = int(batch_writing_interval * DAILY_TAG_SPEND_BATCH_MULTIPLIER)
tag_spend_update_interval = int(
batch_writing_interval * DAILY_TAG_SPEND_BATCH_MULTIPLIER
)
from litellm.proxy.utils import update_daily_tag_spend
scheduler.add_job(
@@ -7131,9 +7178,9 @@ async def chat_completion( # noqa: PLR0915
hasattr(user_api_key_dict, "organization_alias")
and user_api_key_dict.organization_alias is not None
):
data["metadata"]["user_api_key_org_alias"] = (
user_api_key_dict.organization_alias
)
data["metadata"][
"user_api_key_org_alias"
] = user_api_key_dict.organization_alias
if (
hasattr(user_api_key_dict, "agent_id")
and user_api_key_dict.agent_id is not None
@@ -7312,9 +7359,9 @@ async def completion( # noqa: PLR0915
hasattr(user_api_key_dict, "organization_alias")
and user_api_key_dict.organization_alias is not None
):
data["metadata"]["user_api_key_org_alias"] = (
user_api_key_dict.organization_alias
)
data["metadata"][
"user_api_key_org_alias"
] = user_api_key_dict.organization_alias
if (
hasattr(user_api_key_dict, "agent_id")
and user_api_key_dict.agent_id is not None
@@ -7561,9 +7608,9 @@ async def embeddings( # noqa: PLR0915
hasattr(user_api_key_dict, "organization_alias")
and user_api_key_dict.organization_alias is not None
):
data["metadata"]["user_api_key_org_alias"] = (
user_api_key_dict.organization_alias
)
data["metadata"][
"user_api_key_org_alias"
] = user_api_key_dict.organization_alias
if (
hasattr(user_api_key_dict, "agent_id")
and user_api_key_dict.agent_id is not None
+2
View File
@@ -289,6 +289,7 @@ model LiteLLM_MCPServerTable {
server_name String?
alias String?
description String?
instructions String?
url String?
spec_path String?
transport String @default("sse")
@@ -1045,6 +1046,7 @@ model LiteLLM_HealthCheckTable {
@@index([model_name])
@@index([checked_at])
@@index([status])
@@index([model_id, model_name, checked_at(sort: Desc)], map: "LiteLLM_HealthCheckTable_model_id_model_name_checked_at_idx")
}
// Search Tools table for storing search tool configurations
+27 -33
View File
@@ -4526,29 +4526,20 @@ class PrismaClient:
async def get_all_latest_health_checks(self):
"""
Get the latest health check for each model
Get the latest health check for each model.
Uses DB-level DISTINCT ON (model_id, model_name) with ORDER BY checked_at DESC
(via Prisma ``distinct`` + ``order``) so we never load the full history into memory.
"""
try:
# Get all unique model names first
all_checks = await self.db.litellm_healthchecktable.find_many(
order={"checked_at": "desc"}
return await self.db.litellm_healthchecktable.find_many(
distinct=["model_id", "model_name"],
order=[
{"model_id": "asc"},
{"model_name": "asc"},
{"checked_at": "desc"},
],
)
# Group by model_name and get the latest for each
latest_checks = {}
for check in all_checks:
# Create a unique key: prefer model_id if available, otherwise use model_name
# This ensures we get the latest check for each unique model
if check.model_id:
key = (check.model_id, check.model_name)
else:
key = (None, check.model_name)
# Only add if we haven't seen this key yet (since checks are ordered by checked_at desc)
if key not in latest_checks:
latest_checks[key] = check
return list(latest_checks.values())
except Exception as e:
verbose_proxy_logger.error(f"Error getting all latest health checks: {e}")
return []
@@ -5322,19 +5313,6 @@ def get_error_message_str(e: Exception) -> str:
return error_message
def _get_openapi_url() -> Optional[str]:
"""
Get the OpenAPI schema URL from the environment variables.
- If NO_OPENAPI is True, return None.
- Otherwise, default to "/openapi.json".
"""
if str_to_bool(os.getenv("NO_OPENAPI")) is True:
return None
return "/openapi.json"
def _get_redoc_url() -> Optional[str]:
"""
Get the Redoc URL from the environment variables.
@@ -5368,6 +5346,22 @@ def _get_docs_url() -> Optional[str]:
return "/"
def _get_openapi_url() -> Optional[str]:
"""
Get the OpenAPI JSON URL from the environment variables.
- If OPENAPI_URL is set, return it.
- If NO_OPENAPI is True, return None.
- Otherwise, default to "/openapi.json".
"""
if openapi_url := os.getenv("OPENAPI_URL"):
return openapi_url
if str_to_bool(os.getenv("NO_OPENAPI")) is True:
return None
return "/openapi.json"
def handle_exception_on_proxy(e: Exception) -> ProxyException:
"""
+2
View File
@@ -396,6 +396,7 @@ class PrometheusMetricLabels:
UserAPIKeyLabelNames.CLIENT_IP.value,
UserAPIKeyLabelNames.USER_AGENT.value,
UserAPIKeyLabelNames.MODEL_ID.value,
UserAPIKeyLabelNames.API_PROVIDER.value,
]
litellm_spend_metric = [
@@ -410,6 +411,7 @@ class PrometheusMetricLabels:
UserAPIKeyLabelNames.CLIENT_IP.value,
UserAPIKeyLabelNames.USER_AGENT.value,
UserAPIKeyLabelNames.MODEL_ID.value,
UserAPIKeyLabelNames.API_PROVIDER.value,
]
litellm_input_tokens_metric = [
+10 -9
View File
@@ -27,20 +27,21 @@ class MCPServer(BaseModel):
spec_path: Optional[str] = None
auth_type: Optional[MCPAuthType] = None
authentication_token: Optional[str] = None
instructions: Optional[str] = None
mcp_info: Optional[MCPInfo] = None
extra_headers: Optional[
List[str]
] = None # allow admin to specify which headers to forward from client to the MCP server
extra_headers: Optional[List[str]] = (
None # allow admin to specify which headers to forward from client to the MCP server
)
allowed_tools: Optional[List[str]] = None
disallowed_tools: Optional[List[str]] = None
tool_name_to_display_name: Optional[Dict[str, str]] = None
tool_name_to_description: Optional[Dict[str, str]] = None
allowed_params: Optional[
Dict[str, List[str]]
] = None # map of tool names to allowed parameter lists
static_headers: Optional[
Dict[str, str]
] = None # static headers to forward to the MCP server
allowed_params: Optional[Dict[str, List[str]]] = (
None # map of tool names to allowed parameter lists
)
static_headers: Optional[Dict[str, str]] = (
None # static headers to forward to the MCP server
)
# OAuth-specific fields
client_id: Optional[str] = None
client_secret: Optional[str] = None
+2
View File
@@ -2647,6 +2647,8 @@ class StandardLoggingAdditionalHeaders(TypedDict, total=False):
x_ratelimit_limit_tokens: int
x_ratelimit_remaining_requests: int
x_ratelimit_remaining_tokens: int
x_ratelimit_reset_requests: str
x_ratelimit_reset_tokens: str
class StandardLoggingHiddenParams(TypedDict):
+12 -2
View File
@@ -4815,11 +4815,21 @@ def _apply_openai_param_overrides(
If user passes in allowed_openai_params, apply them to optional_params
These params will get passed as is to the LLM API since the user opted in to passing them in the request
Only params the caller actually sent are forwarded. Previously this
function unconditionally wrote `None` for any allowed param missing from
the request, which then reached the provider SDK as a top-level kwarg it
did not recognize (e.g. openai SDK raising
`AsyncCompletions.create() got an unexpected keyword argument 'enable_thinking'`).
See https://github.com/BerriAI/litellm/issues/25697
"""
if allowed_openai_params:
for param in allowed_openai_params:
if param not in optional_params:
optional_params[param] = non_default_params.pop(param, None)
if param in optional_params:
continue
if param not in non_default_params:
continue
optional_params[param] = non_default_params.pop(param)
return optional_params
+52
View File
@@ -25164,6 +25164,58 @@
"supports_web_search": true,
"tpm": 800000
},
"openrouter/google/gemini-3.1-flash-lite-preview": {
"cache_read_input_token_cost": 2.5e-08,
"cache_read_input_token_cost_per_audio_token": 5e-08,
"input_cost_per_audio_token": 5e-07,
"input_cost_per_token": 2.5e-07,
"litellm_provider": "openrouter",
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_images_per_prompt": 3000,
"max_input_tokens": 1048576,
"max_output_tokens": 65536,
"max_pdf_size_mb": 30,
"max_tokens": 65536,
"max_video_length": 1,
"max_videos_per_prompt": 10,
"mode": "chat",
"output_cost_per_reasoning_token": 1.5e-06,
"output_cost_per_token": 1.5e-06,
"rpm": 2000,
"source": "https://ai.google.dev/pricing/gemini-3",
"supported_endpoints": [
"/v1/chat/completions",
"/v1/completions",
"/v1/batch"
],
"supported_modalities": [
"text",
"image",
"audio",
"video"
],
"supported_output_modalities": [
"text"
],
"supports_audio_input": true,
"supports_audio_output": false,
"supports_code_execution": true,
"supports_file_search": true,
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_pdf_input": true,
"supports_prompt_caching": true,
"supports_reasoning": true,
"supports_response_schema": true,
"supports_system_messages": true,
"supports_tool_choice": true,
"supports_url_context": true,
"supports_video_input": true,
"supports_vision": true,
"supports_web_search": true,
"tpm": 800000
},
"openrouter/google/gemini-3.1-pro-preview": {
"cache_read_input_token_cost": 2e-07,
"cache_read_input_token_cost_above_200k_tokens": 4e-07,
+2
View File
@@ -289,6 +289,7 @@ model LiteLLM_MCPServerTable {
server_name String?
alias String?
description String?
instructions String?
url String?
spec_path String?
transport String @default("sse")
@@ -1045,6 +1046,7 @@ model LiteLLM_HealthCheckTable {
@@index([model_name])
@@index([checked_at])
@@index([status])
@@index([model_id, model_name, checked_at(sort: Desc)], map: "LiteLLM_HealthCheckTable_model_id_model_name_checked_at_idx")
}
// Search Tools table for storing search tool configurations
@@ -0,0 +1,182 @@
#!/usr/bin/env python3
"""
Bench LiteLLM_HealthCheckTable + PrismaClient
- set DATABASE_URL to your Postgres
- Run ```prisma generate``` to install prisma client before running test )
- This test writes to the default "public" database. Make sure to run cleanup after testing
"""
from __future__ import annotations
import argparse
import asyncio
import gc
import os
import sys
import time
import tracemalloc
from datetime import datetime, timedelta, timezone
from typing import Any, List
SEED_MARKER = "benchmark_get_all_latest_health_checks.py" # Utility Marker for cleanup process.
def _rss_kb_linux() -> int:
try:
with open("/proc/self/status", encoding="utf-8") as f:
for line in f:
if line.startswith("VmRSS:"):
return int(line.split()[1])
except OSError:
pass
return 0
def _fmt_kb(kb: int) -> str:
if kb <= 0:
return "n/a"
return f"{kb} KiB (~{kb / 1024.0:.1f} MiB)"
def _build_batch(
*,
batch_index: int,
batch_size: int,
num_models: int,
base_time: datetime,
) -> List[dict[str, Any]]:
rows: List[dict[str, Any]] = []
for i in range(batch_size):
global_i = batch_index * batch_size + i
model_idx = global_i % max(num_models, 1)
model_name = f"bench-model-{model_idx}"
model_id = f"bench-mid-{model_idx}" if model_idx % 2 == 0 else None
checked_at = base_time - timedelta(seconds=global_i)
rows.append(
{
"model_name": model_name,
"model_id": model_id,
"status": "healthy" if global_i % 3 else "unhealthy",
"healthy_count": 1,
"unhealthy_count": 0,
"checked_by": SEED_MARKER,
"checked_at": checked_at,
}
)
return rows
async def _seed(
prisma: Any,
*,
total_rows: int,
batch_size: int,
num_models: int,
) -> None:
db = prisma.db
base_time = datetime.now(timezone.utc)
inserted = 0
batch_idx = 0
while inserted < total_rows:
n = min(batch_size, total_rows - inserted)
await db.litellm_healthchecktable.create_many(
data=_build_batch(
batch_index=batch_idx,
batch_size=n,
num_models=num_models,
base_time=base_time,
)
)
inserted += n
batch_idx += 1
if batch_idx % 10 == 0:
print(f" {inserted}/{total_rows}", flush=True)
print(f"Seeded {inserted} rows ({SEED_MARKER}).")
async def _cleanup(prisma: Any) -> None:
result = await prisma.db.litellm_healthchecktable.delete_many(
where={"checked_by": SEED_MARKER},
)
n = getattr(result, "count", result)
print(f"Deleted {n} rows.")
async def _bench(prisma: Any) -> None:
gc.collect()
rss0 = _rss_kb_linux()
print(f"RSS (after gc): {_fmt_kb(rss0)}")
tracemalloc.start()
t0 = time.perf_counter()
try:
rows = await prisma.get_all_latest_health_checks()
finally:
elapsed = time.perf_counter() - t0
_, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
gc.collect()
rss1 = _rss_kb_linux()
print(f"get_all_latest_health_checks: {len(rows)} rows in {elapsed:.2f}s")
print(f"tracemalloc peak: {peak / 1e6:.2f} MiB")
print(f"RSS after: {_fmt_kb(rss1)}")
async def _amain() -> int:
p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
p.add_argument("action", choices=("seed", "bench", "cleanup"))
p.add_argument("--rows", type=int, default=10_000)
p.add_argument("--batch-size", type=int, default=1000)
p.add_argument("--num-models", type=int, default=50)
args = p.parse_args()
database_url = os.getenv("DATABASE_URL")
if not database_url:
print("Set DATABASE_URL.", file=sys.stderr)
return 1
repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if repo_root not in sys.path:
sys.path.insert(0, repo_root)
from litellm.caching.caching import DualCache
from litellm.proxy.proxy_cli import append_query_params
from litellm.proxy.utils import PrismaClient, ProxyLogging
db_url = append_query_params(
database_url, {"connection_limit": 100, "pool_timeout": 60}
)
prisma = PrismaClient(
database_url=db_url,
proxy_logging_obj=ProxyLogging(user_api_key_cache=DualCache()),
)
try:
await prisma.connect()
except Exception as e:
print(f"Connect failed: {e}", file=sys.stderr)
return 1
try:
if args.action == "seed":
await _seed(
prisma,
total_rows=args.rows,
batch_size=args.batch_size,
num_models=args.num_models,
)
elif args.action == "bench":
await _bench(prisma)
else:
await _cleanup(prisma)
finally:
try:
await prisma.disconnect()
except Exception:
pass
return 0
if __name__ == "__main__":
raise SystemExit(asyncio.run(_amain()))
@@ -605,6 +605,7 @@ def test_increment_top_level_request_and_spend_metrics(prometheus_logger):
org_alias=None,
model="gpt-3.5-turbo",
model_id="model-123",
api_provider="openai",
client_ip=None,
user_agent=None,
)
@@ -623,6 +624,7 @@ def test_increment_top_level_request_and_spend_metrics(prometheus_logger):
org_alias=None,
model="gpt-3.5-turbo",
model_id="model-123",
api_provider="openai",
client_ip=None,
user_agent=None,
)
@@ -5,7 +5,10 @@ import pytest
sys.path.insert(0, os.path.abspath("../.."))
import litellm
from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import BedrockGuardrail
from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
BedrockGuardrail,
_redact_pii_matches,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from unittest.mock import MagicMock, AsyncMock, patch
@@ -1601,3 +1604,128 @@ async def test_bedrock_guardrail_post_call_success_hook_no_output_text():
# If no error is raised and result is None, then the test passes
assert result is None
print("✅ No output text in response test passed")
@pytest.mark.asyncio
async def test__redact_pii_matches_null_list_fields():
"""Test that explicit null values from Bedrock API are handled correctly.
The Bedrock API can return explicit JSON null for list fields like
piiEntities, regexes, customWords, managedWordLists. This would cause
TypeError: 'NoneType' object is not iterable if not handled.
"""
# Test 1: null piiEntities and regexes
response_with_null_pii = {
"action": "GUARDRAIL_INTERVENED",
"assessments": [
{
"sensitiveInformationPolicy": {
"piiEntities": None,
"regexes": None,
}
}
],
}
redacted = _redact_pii_matches(response_with_null_pii)
assert redacted is not None
assert redacted["assessments"][0]["sensitiveInformationPolicy"]["piiEntities"] is None
assert redacted["assessments"][0]["sensitiveInformationPolicy"]["regexes"] is None
# Test 2: null customWords and managedWordLists
response_with_null_words = {
"action": "GUARDRAIL_INTERVENED",
"assessments": [
{
"wordPolicy": {
"customWords": None,
"managedWordLists": None,
}
}
],
}
redacted = _redact_pii_matches(response_with_null_words)
assert redacted is not None
assert redacted["assessments"][0]["wordPolicy"]["customWords"] is None
assert redacted["assessments"][0]["wordPolicy"]["managedWordLists"] is None
# Test 3: null assessments at top level
response_with_null_assessments = {
"action": "GUARDRAIL_INTERVENED",
"assessments": None,
}
redacted = _redact_pii_matches(response_with_null_assessments)
assert redacted is not None
@pytest.mark.asyncio
async def test__redact_pii_matches_malformed_response():
"""Test _redact_pii_matches with malformed response (should not crash)"""
# Test with completely malformed response
malformed_response = {
"action": "GUARDRAIL_INTERVENED",
"assessments": "not_a_list",
}
redacted_response = _redact_pii_matches(malformed_response)
assert redacted_response == malformed_response
# Test with missing keys
missing_keys_response = {
"action": "GUARDRAIL_INTERVENED",
}
redacted_response = _redact_pii_matches(missing_keys_response)
assert redacted_response == missing_keys_response
@pytest.mark.asyncio
async def test_should_raise_guardrail_blocked_exception_null_fields():
"""Test that _should_raise_guardrail_blocked_exception handles null list fields.
Validates the or [] null-safety pattern works for all policy fields
in _should_raise_guardrail_blocked_exception.
"""
guardrail = BedrockGuardrail(
guardrailIdentifier="test-guardrail", guardrailVersion="DRAFT"
)
# Test with null assessments
response_null_assessments = {
"action": "GUARDRAIL_INTERVENED",
"assessments": None,
}
assert guardrail._should_raise_guardrail_blocked_exception(response_null_assessments) is False
# Test with null topics in topicPolicy
response_null_topics = {
"action": "GUARDRAIL_INTERVENED",
"assessments": [{"topicPolicy": {"topics": None}}],
}
assert guardrail._should_raise_guardrail_blocked_exception(response_null_topics) is False
# Test with null filters in contentPolicy
response_null_filters = {
"action": "GUARDRAIL_INTERVENED",
"assessments": [{"contentPolicy": {"filters": None}}],
}
assert guardrail._should_raise_guardrail_blocked_exception(response_null_filters) is False
# Test with null customWords and managedWordLists in wordPolicy
response_null_words = {
"action": "GUARDRAIL_INTERVENED",
"assessments": [{"wordPolicy": {"customWords": None, "managedWordLists": None}}],
}
assert guardrail._should_raise_guardrail_blocked_exception(response_null_words) is False
# Test with null piiEntities and regexes in sensitiveInformationPolicy
response_null_pii = {
"action": "GUARDRAIL_INTERVENED",
"assessments": [{"sensitiveInformationPolicy": {"piiEntities": None, "regexes": None}}],
}
assert guardrail._should_raise_guardrail_blocked_exception(response_null_pii) is False
# Test with null filters in contextualGroundingPolicy
response_null_grounding = {
"action": "GUARDRAIL_INTERVENED",
"assessments": [{"contextualGroundingPolicy": {"filters": None}}],
}
assert guardrail._should_raise_guardrail_blocked_exception(response_null_grounding) is False
@@ -5,6 +5,7 @@ sys.path.insert(
0, os.path.abspath("../../")
) # Adds the parent directory to the system path
import httpx
import pytest
from litellm.llms.azure.common_utils import process_azure_headers
from httpx import Headers
@@ -144,6 +144,45 @@ def test_get_optional_params_with_allowed_openai_params():
assert optional_params["reasoning_effort"] == reasoning_effort
def test_allowed_openai_params_does_not_forward_unset_params():
"""
Regression test for https://github.com/BerriAI/litellm/issues/25697
When a user lists a param in ``allowed_openai_params`` but does not
actually send that param in the request, litellm must not forward it
to the provider SDK as ``None``. The openai SDK rejects unknown
top-level kwargs with
``AsyncCompletions.create() got an unexpected keyword argument 'enable_thinking'``.
Reproduces the reported config where the user listed both
``chat_template_kwargs`` and ``enable_thinking`` in
``allowed_openai_params`` and only sent ``chat_template_kwargs``
(with ``enable_thinking`` nested inside it). Previously the loop
added ``optional_params["enable_thinking"] = None`` which then
crashed the openai client.
"""
from litellm.utils import _apply_openai_param_overrides
chat_template_kwargs = {"enable_thinking": False}
optional_params: dict = {}
non_default_params = {"chat_template_kwargs": chat_template_kwargs}
result = _apply_openai_param_overrides(
optional_params=optional_params,
non_default_params=non_default_params,
allowed_openai_params=["chat_template_kwargs", "enable_thinking"],
)
assert result["chat_template_kwargs"] == chat_template_kwargs
# enable_thinking was NOT sent as a top-level param — it must not be
# forwarded to the provider SDK (openai AsyncCompletions.create would
# reject an unknown kwarg, even if its value is None).
assert "enable_thinking" not in result
# And the only entry actually moved out of non_default_params is
# the one the caller sent.
assert "chat_template_kwargs" not in non_default_params
def test_bedrock_optional_params_embeddings():
litellm.drop_params = True
optional_params = get_optional_params_embeddings(
@@ -0,0 +1,46 @@
"""
``_get_httpx_client`` + ``HTTPHandler.post`` (same pattern as Azure Anthropic sync path:
``_get_httpx_client(params={"timeout": ...})`` then ``post(..., timeout=...)``).
Uses https://httpbin.org/delay/10 with ``timeout=5`` the handler must raise :class:`~litellm.exceptions.Timeout`
before the 10s delay completes. Skips if httpbin is unreachable.
Lives under ``local_testing`` (not ``make test-unit``).
"""
import json
import os
import sys
import httpx
import pytest
sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
)
from litellm.exceptions import Timeout as LitellmTimeout
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
_HTTPBIN_DELAY_S = 10
_PER_REQUEST_TIMEOUT_S = 5.0
_CLIENT_DEFAULT_TIMEOUT_S = 60.0
def test_post_delay_exceeds_per_request_timeout_raises():
try:
httpx.get("https://httpbin.org/get", timeout=5.0)
except Exception as e:
pytest.skip(f"httpbin.org unreachable: {e}")
handler = _get_httpx_client(params={"timeout": _CLIENT_DEFAULT_TIMEOUT_S})
try:
with pytest.raises(LitellmTimeout):
handler.post(
f"https://httpbin.org/delay/{_HTTPBIN_DELAY_S}",
headers={"content-type": "application/json"},
data=json.dumps({"model": "claude", "messages": []}),
timeout=_PER_REQUEST_TIMEOUT_S,
)
finally:
handler.close()
@@ -131,6 +131,55 @@ def test_get_cache_key_text_completion():
assert cache_key_2 == cache_key_3
def test_get_cache_key_responses_api():
"""
Regression test: two /v1/responses calls that differ only in
`instructions` (or any Responses-API-only param) must produce
different cache keys. Mirrors the chat / embedding / text-completion
cache-key tests above.
"""
cache = Cache()
base_kwargs = {
"model": "openai/gpt-4.1",
"input": [{"role": "user", "content": "what is the weather"}],
"temperature": 0.3,
}
kwargs_a = {
**base_kwargs,
"instructions": "summarize the weather on 10th May",
}
kwargs_b = {
**base_kwargs,
"instructions": "summarize the weather on 7th May",
}
key_a = cache.get_cache_key(**kwargs_a)
key_b = cache.get_cache_key(**kwargs_b)
assert isinstance(key_a, str) and len(key_a) > 0
assert key_a != key_b, "instructions must be part of the Responses API cache key"
# Sanity: identical payloads must still collide (cache hits still work)
key_a_again = cache.get_cache_key(**kwargs_a)
assert key_a == key_a_again
# Spot-check a handful of other Responses-only params individually.
for param, value_x, value_y in [
("previous_response_id", "resp_aaa", "resp_bbb"),
("reasoning", {"effort": "low"}, {"effort": "high"}),
("include", ["reasoning.encrypted_content"], []),
("max_output_tokens", 100, 500),
("background", True, False),
]:
kx = {**base_kwargs, param: value_x}
ky = {**base_kwargs, param: value_y}
assert cache.get_cache_key(**kx) != cache.get_cache_key(
**ky
), f"Responses-API param `{param}` is not part of the cache key"
def test_get_hashed_cache_key():
cache = Cache()
cache_key = "model:gpt-3.5-turbo,messages:Hello world"
@@ -158,12 +158,15 @@ def test_get_additional_headers():
additional_logging_headers = StandardLoggingPayloadSetup.get_additional_headers(
additional_headers
)
assert additional_logging_headers == {
"x_ratelimit_limit_requests": 2000,
"x_ratelimit_remaining_requests": 1999,
"x_ratelimit_limit_tokens": 160000,
"x_ratelimit_remaining_tokens": 160000,
}
# Typed rate-limit fields are coerced to int
assert additional_logging_headers is not None
assert additional_logging_headers.get("x_ratelimit_limit_requests") == 2000
assert additional_logging_headers.get("x_ratelimit_remaining_requests") == 1999
assert additional_logging_headers.get("x_ratelimit_limit_tokens") == 160000
assert additional_logging_headers.get("x_ratelimit_remaining_tokens") == 160000
# Provider-specific headers are preserved verbatim (not dropped)
assert additional_logging_headers.get("llm_provider-request-id") == "req_01F6CycZZPSHKRCCctcS1Vto"
assert additional_logging_headers.get("llm_provider-anthropic-ratelimit-requests-reset") == "2024-10-29T23:57:40Z"
def all_fields_present(standard_logging_metadata: StandardLoggingMetadata):
+117 -78
View File
@@ -417,17 +417,22 @@ async def test_streamable_http_mcp_handler_mock():
# Mock extract_mcp_auth_context to bypass auth checks in the handler
mock_auth_context = (None, None, None, {}, {}, {})
with patch(
"litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED",
True,
), patch(
"litellm.proxy._experimental.mcp_server.server.session_manager",
mock_session_manager,
), patch(
"litellm.proxy._experimental.mcp_server.server.extract_mcp_auth_context",
AsyncMock(return_value=mock_auth_context),
), patch(
"litellm.proxy._experimental.mcp_server.server.set_auth_context",
with (
patch(
"litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED",
True,
),
patch(
"litellm.proxy._experimental.mcp_server.server.session_manager",
mock_session_manager,
),
patch(
"litellm.proxy._experimental.mcp_server.server.extract_mcp_auth_context",
AsyncMock(return_value=mock_auth_context),
),
patch(
"litellm.proxy._experimental.mcp_server.server.set_auth_context",
),
):
from litellm.proxy._experimental.mcp_server.server import (
handle_streamable_http_mcp,
@@ -471,17 +476,22 @@ async def test_sse_mcp_handler_mock():
[],
)
with patch(
"litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED",
True,
), patch(
"litellm.proxy._experimental.mcp_server.server.sse_session_manager",
mock_sse_session_manager,
), patch(
"litellm.proxy._experimental.mcp_server.server.extract_mcp_auth_context",
new=AsyncMock(return_value=mock_auth_result),
), patch(
"litellm.proxy._experimental.mcp_server.server.set_auth_context",
with (
patch(
"litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED",
True,
),
patch(
"litellm.proxy._experimental.mcp_server.server.sse_session_manager",
mock_sse_session_manager,
),
patch(
"litellm.proxy._experimental.mcp_server.server.extract_mcp_auth_context",
new=AsyncMock(return_value=mock_auth_result),
),
patch(
"litellm.proxy._experimental.mcp_server.server.set_auth_context",
),
):
from litellm.proxy._experimental.mcp_server.server import handle_sse_mcp
@@ -833,7 +843,9 @@ async def test_get_tools_from_mcp_servers():
mock_manager.get_allowed_mcp_servers = AsyncMock(
return_value=["server1_id", "server2_id"]
)
mock_manager.get_mcp_server_by_id = lambda server_id: mock_server_1 if server_id == "server1_id" else mock_server_2
mock_manager.get_mcp_server_by_id = lambda server_id: (
mock_server_1 if server_id == "server1_id" else mock_server_2
)
mock_manager._get_tools_from_server = AsyncMock(return_value=[mock_tool_1])
# Mock filter_server_ids_by_ip_with_info to return input unchanged (no IP filtering in test)
mock_manager.filter_server_ids_by_ip_with_info = MagicMock(
@@ -859,7 +871,10 @@ async def test_get_tools_from_mcp_servers():
mock_manager_2.get_allowed_mcp_servers = AsyncMock(
return_value=["server1_id", "server2_id"]
)
mock_manager_2.get_mcp_server_by_id = lambda server_id: mock_server_1 if server_id == "server1_id" else mock_server_2
mock_manager_2.get_mcp_server_by_id = lambda server_id: (
mock_server_1 if server_id == "server1_id" else mock_server_2
)
async def mock_get_tools_side_effect(
server,
mcp_auth_header=None,
@@ -900,7 +915,11 @@ async def test_get_tools_from_mcp_servers():
mock_manager.get_allowed_mcp_servers = AsyncMock(
return_value=["server1_id", "server2_id", "server3_id"]
)
mock_manager.get_mcp_server_by_id = lambda server_id: mock_server_1 if server_id == "server1_id" else (mock_server_2 if server_id == "server2_id" else mock_server_3)
mock_manager.get_mcp_server_by_id = lambda server_id: (
mock_server_1
if server_id == "server1_id"
else (mock_server_2 if server_id == "server2_id" else mock_server_3)
)
mock_manager._get_tools_from_server = AsyncMock(return_value=[mock_tool_1])
# Mock filter_server_ids_by_ip_with_info to return input unchanged (no IP filtering in test)
mock_manager.filter_server_ids_by_ip_with_info = MagicMock(
@@ -1050,15 +1069,15 @@ async def test_mcp_server_manager_access_groups_from_config():
# Should find config_server for group-a, both for group-b, other_server for group-c
import asyncio
server_ids_a = await MCPRequestHandler._get_mcp_servers_from_access_groups([
"group-a"
])
server_ids_b = await MCPRequestHandler._get_mcp_servers_from_access_groups([
"group-b"
])
server_ids_c = await MCPRequestHandler._get_mcp_servers_from_access_groups([
"group-c"
])
server_ids_a = await MCPRequestHandler._get_mcp_servers_from_access_groups(
["group-a"]
)
server_ids_b = await MCPRequestHandler._get_mcp_servers_from_access_groups(
["group-b"]
)
server_ids_c = await MCPRequestHandler._get_mcp_servers_from_access_groups(
["group-c"]
)
assert any(config_server.server_id == sid for sid in server_ids_a)
assert set(server_ids_b) == set(
[
@@ -1474,6 +1493,7 @@ async def test_add_update_server_with_alias():
mock_mcp_server.byok_api_key_help_url = None
mock_mcp_server.created_at = None
mock_mcp_server.updated_at = None
mock_mcp_server.instructions = None
# Add server to manager
await test_manager.add_server(mock_mcp_server)
@@ -1530,6 +1550,7 @@ async def test_add_update_server_without_alias():
mock_mcp_server.byok_api_key_help_url = None
mock_mcp_server.created_at = None
mock_mcp_server.updated_at = None
mock_mcp_server.instructions = None
# Add server to manager
await test_manager.add_server(mock_mcp_server)
@@ -1587,7 +1608,7 @@ async def test_add_update_server_fallback_to_server_id():
mock_mcp_server.byok_api_key_help_url = None
mock_mcp_server.created_at = None
mock_mcp_server.updated_at = None
mock_mcp_server.instructions = None
# Add server to manager
await test_manager.add_server(mock_mcp_server)
@@ -2151,8 +2172,12 @@ async def test_list_tool_rest_api_all_servers_with_auth():
for call_args in mock_get_tools.call_args_list
}
assert server_auth_map.get(mock_zapier_server) == "Bearer zapier_token"
assert server_auth_map.get(mock_slack_server) == "Bearer slack_token"
assert (
server_auth_map.get(mock_zapier_server) == "Bearer zapier_token"
)
assert (
server_auth_map.get(mock_slack_server) == "Bearer slack_token"
)
@pytest.mark.asyncio
@@ -2690,26 +2715,33 @@ async def test_call_mcp_tool_uses_manager_permission_lookup():
expected_response = [TextContent(type="text", text="ok")]
with patch.object(
global_mcp_server_manager,
"get_allowed_mcp_servers",
new_callable=AsyncMock,
) as mock_get_allowed, patch.object(
global_mcp_server_manager,
"get_mcp_server_by_id",
return_value=mock_server,
), patch.object(
global_mcp_server_manager,
"_get_mcp_server_from_tool_name",
return_value=mock_server,
) as mock_get_server, patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_tool_registry"
) as mock_tool_registry, patch(
"litellm.proxy._experimental.mcp_server.server._handle_managed_mcp_tool",
new_callable=AsyncMock,
) as mock_handle_managed, patch(
"litellm.proxy._experimental.mcp_server.server.MCPRequestHandler.is_tool_allowed",
return_value=True,
with (
patch.object(
global_mcp_server_manager,
"get_allowed_mcp_servers",
new_callable=AsyncMock,
) as mock_get_allowed,
patch.object(
global_mcp_server_manager,
"get_mcp_server_by_id",
return_value=mock_server,
),
patch.object(
global_mcp_server_manager,
"_get_mcp_server_from_tool_name",
return_value=mock_server,
) as mock_get_server,
patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_tool_registry"
) as mock_tool_registry,
patch(
"litellm.proxy._experimental.mcp_server.server._handle_managed_mcp_tool",
new_callable=AsyncMock,
) as mock_handle_managed,
patch(
"litellm.proxy._experimental.mcp_server.server.MCPRequestHandler.is_tool_allowed",
return_value=True,
),
):
mock_get_allowed.return_value = [mock_server.server_id]
mock_tool_registry.get_tool.return_value = None
@@ -2759,27 +2791,34 @@ async def test_call_mcp_tool_resolves_unprefixed_tool_name_and_checks_permission
expected_response = [TextContent(type="text", text="ok")]
with patch.object(
global_mcp_server_manager,
"get_allowed_mcp_servers",
new_callable=AsyncMock,
) as mock_get_allowed, patch.object(
global_mcp_server_manager,
"get_mcp_server_by_id",
return_value=mock_server,
), patch.object(
global_mcp_server_manager,
"_get_mcp_server_from_tool_name",
return_value=mock_server,
) as mock_get_server, patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_tool_registry"
) as mock_tool_registry, patch(
"litellm.proxy._experimental.mcp_server.server._handle_managed_mcp_tool",
new_callable=AsyncMock,
) as mock_handle_managed, patch(
"litellm.proxy._experimental.mcp_server.server.MCPRequestHandler.is_tool_allowed",
return_value=True,
) as mock_is_allowed:
with (
patch.object(
global_mcp_server_manager,
"get_allowed_mcp_servers",
new_callable=AsyncMock,
) as mock_get_allowed,
patch.object(
global_mcp_server_manager,
"get_mcp_server_by_id",
return_value=mock_server,
),
patch.object(
global_mcp_server_manager,
"_get_mcp_server_from_tool_name",
return_value=mock_server,
) as mock_get_server,
patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_tool_registry"
) as mock_tool_registry,
patch(
"litellm.proxy._experimental.mcp_server.server._handle_managed_mcp_tool",
new_callable=AsyncMock,
) as mock_handle_managed,
patch(
"litellm.proxy._experimental.mcp_server.server.MCPRequestHandler.is_tool_allowed",
return_value=True,
) as mock_is_allowed,
):
mock_get_allowed.return_value = [mock_server.server_id]
mock_tool_registry.get_tool.return_value = None
mock_handle_managed.return_value = expected_response
+25 -1
View File
@@ -10,7 +10,7 @@ import pytest
from fastapi import Request
from starlette.datastructures import State
from litellm.proxy.utils import _get_docs_url, _get_redoc_url
from litellm.proxy.utils import _get_docs_url, _get_openapi_url, _get_redoc_url
sys.path.insert(
0, os.path.abspath("../..")
@@ -735,6 +735,30 @@ def test_get_docs_url(env_vars, expected_url):
result = _get_docs_url()
assert result == expected_url
@pytest.mark.parametrize(
"env_vars, expected_url",
[
({}, "/openapi.json"), # default case
({"OPENAPI_URL": "/custom-openapi.json"}, "/custom-openapi.json"), # custom URL
(
{"OPENAPI_URL": "https://example.com/openapi.json"},
"https://example.com/openapi.json",
), # full URL
({"NO_OPENAPI": "True"}, None), # openapi disabled
],
)
def test_get_openapi_url(env_vars, expected_url):
# Clear relevant environment variables
for key in ["OPENAPI_URL", "NO_OPENAPI"]:
os.environ.pop(key, None)
# Set test environment variables
for key, value in env_vars.items():
os.environ[key] = value
result = _get_openapi_url()
assert result == expected_url
@pytest.mark.parametrize(
"request_tags, tags_to_add, expected_tags",
@@ -40,6 +40,7 @@ class TestMCPClient:
with pytest.raises(
ValueError, match="stdio_config is required for stdio transport"
):
async def _noop(session):
return None
@@ -251,11 +252,11 @@ class TestMCPClient:
server_url="http://example.com/sse",
transport_type="sse",
auth_type=MCPAuth.token,
auth_value="my-secret-token"
auth_value="my-secret-token",
)
headers = client._get_auth_headers()
assert "Authorization" in headers
assert headers["Authorization"] == "token my-secret-token"
@@ -266,27 +267,27 @@ class TestMCPClient:
server_url="http://example.com/sse",
transport_type="sse",
auth_type=MCPAuth.bearer_token,
auth_value="bearer-token"
auth_value="bearer-token",
)
headers = client._get_auth_headers()
assert headers["Authorization"] == "Bearer bearer-token"
# Test API key
client = MCPClient(
server_url="http://example.com/sse",
transport_type="sse",
auth_type=MCPAuth.api_key,
auth_value="api-key"
auth_value="api-key",
)
headers = client._get_auth_headers()
assert headers["X-API-Key"] == "api-key"
# Test basic auth (gets base64 encoded)
client = MCPClient(
server_url="http://example.com/sse",
transport_type="sse",
auth_type=MCPAuth.basic,
auth_value="user:pass"
auth_value="user:pass",
)
headers = client._get_auth_headers()
assert headers["Authorization"].startswith("Basic ")
@@ -298,11 +299,11 @@ class TestMCPClient:
transport_type="sse",
auth_type=MCPAuth.token,
auth_value="my-token",
extra_headers={"X-Custom-Header": "custom-value"}
extra_headers={"X-Custom-Header": "custom-value"},
)
headers = client._get_auth_headers()
assert headers["Authorization"] == "token my-token"
assert headers["X-Custom-Header"] == "custom-value"
@@ -312,5 +313,80 @@ class TestMCPClient:
assert MCPAuth.token.value == "token"
# ---------------------------------------------------------------------------
# _last_initialize_instructions capture
# ---------------------------------------------------------------------------
class TestMCPClientInstructionsCapture:
"""Tests for _last_initialize_instructions capture during session init."""
def test_initial_value_is_none(self):
"""Fresh client has no cached instructions."""
client = MCPClient(
server_url="http://example.com/mcp",
transport_type="http",
)
assert client._last_initialize_instructions is None
@pytest.mark.asyncio
@patch("litellm.experimental_mcp_client.client.ClientSession")
async def test_captures_instructions_from_initialize(self, mock_session_cls):
"""Instructions from upstream initialize() are captured and stripped."""
client = MCPClient(
server_url="http://example.com/mcp",
transport_type="http",
)
mock_session = AsyncMock()
init_result = MagicMock()
init_result.instructions = " upstream says hello "
mock_session.initialize = AsyncMock(return_value=init_result)
session_ctx = MagicMock()
session_ctx.__aenter__ = AsyncMock(return_value=mock_session)
session_ctx.__aexit__ = AsyncMock(return_value=False)
mock_session_cls.return_value = session_ctx
transport_ctx = MagicMock()
transport_ctx.__aenter__ = AsyncMock(return_value=(MagicMock(), MagicMock()))
transport_ctx.__aexit__ = AsyncMock(return_value=False)
async def _op(session):
return "done"
await client._execute_session_operation(transport_ctx, _op)
assert client._last_initialize_instructions == "upstream says hello"
@pytest.mark.asyncio
@patch("litellm.experimental_mcp_client.client.ClientSession")
async def test_none_instructions_stays_none(self, mock_session_cls):
"""When upstream returns no instructions the field stays None."""
client = MCPClient(
server_url="http://example.com/mcp",
transport_type="http",
)
mock_session = AsyncMock()
init_result = MagicMock()
init_result.instructions = None
mock_session.initialize = AsyncMock(return_value=init_result)
session_ctx = MagicMock()
session_ctx.__aenter__ = AsyncMock(return_value=mock_session)
session_ctx.__aexit__ = AsyncMock(return_value=False)
mock_session_cls.return_value = session_ctx
transport_ctx = MagicMock()
transport_ctx.__aenter__ = AsyncMock(return_value=(MagicMock(), MagicMock()))
transport_ctx.__aexit__ = AsyncMock(return_value=False)
async def _op(session):
return "done"
await client._execute_session_operation(transport_ctx, _op)
assert client._last_initialize_instructions is None
if __name__ == "__main__":
pytest.main([__file__])
@@ -69,6 +69,21 @@ def test_model_id_in_required_metrics():
print(f"{metric_name} contains model_id label")
def test_api_provider_in_spend_and_requests_metrics():
"""
Test that api_provider label is present in spend and requests metrics
so users can build spend-by-provider and request-count-by-provider dashboards.
"""
api_provider_label = UserAPIKeyLabelNames.API_PROVIDER.value
for metric_name in ["litellm_spend_metric", "litellm_requests_metric"]:
labels = PrometheusMetricLabels.get_labels(metric_name)
assert (
api_provider_label in labels
), f"Metric {metric_name} should contain api_provider label"
print(f"{metric_name} contains api_provider label")
def test_user_email_label_exists():
"""Test that the USER_EMAIL label is properly defined"""
assert UserAPIKeyLabelNames.USER_EMAIL.value == "user_email"
@@ -11,8 +11,7 @@ sys.path.insert(
import time
from litellm.constants import SENTRY_DENYLIST, SENTRY_PII_DENYLIST
from litellm.litellm_core_utils.litellm_logging import \
Logging as LitellmLogging
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
from litellm.litellm_core_utils.litellm_logging import set_callbacks
from litellm.types.utils import ModelResponse, TextCompletionResponse
@@ -140,8 +139,7 @@ def test_sentry_environment():
def test_use_custom_pricing_for_model():
from litellm.litellm_core_utils.litellm_logging import \
use_custom_pricing_for_model
from litellm.litellm_core_utils.litellm_logging import use_custom_pricing_for_model
litellm_params = {
"custom_llm_provider": "azure",
@@ -156,8 +154,7 @@ def test_use_custom_pricing_for_model_via_litellm_metadata():
Generic API call routes (/messages, /responses) store model_info
under litellm_metadata, not metadata. Regression test for #23185.
"""
from litellm.litellm_core_utils.litellm_logging import \
use_custom_pricing_for_model
from litellm.litellm_core_utils.litellm_logging import use_custom_pricing_for_model
litellm_params = {
"litellm_metadata": {
@@ -173,8 +170,7 @@ def test_use_custom_pricing_for_model_via_litellm_metadata():
def test_use_custom_pricing_not_detected_litellm_metadata_no_pricing():
"""Should return False when litellm_metadata.model_info has no pricing keys."""
from litellm.litellm_core_utils.litellm_logging import \
use_custom_pricing_for_model
from litellm.litellm_core_utils.litellm_logging import use_custom_pricing_for_model
litellm_params = {
"litellm_metadata": {
@@ -190,8 +186,7 @@ def test_response_cost_calculator_uses_router_model_id_from_litellm_metadata():
does not carry _hidden_params (e.g. ResponsesAPIResponse from /v1/responses
streaming). Regression test for custom pricing on streaming responses."""
import litellm
from litellm.litellm_core_utils.litellm_logging import \
Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.llms.openai import ResponsesAPIResponse
custom_model_id = "gpt-5-custom-pricing"
@@ -301,8 +296,9 @@ class TestGetRouterModelId:
def test_returns_none_when_no_litellm_params(self):
"""Should return None when litellm_params is not set."""
from litellm.litellm_core_utils.litellm_logging import \
Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import (
Logging as LiteLLMLoggingObj,
)
obj = LiteLLMLoggingObj(
model="test",
@@ -326,10 +322,12 @@ class TestAnthropicPassthroughCustomPricing:
when the logging object carries custom pricing in model_info."""
from unittest.mock import patch
from litellm.litellm_core_utils.litellm_logging import \
Logging as LiteLLMLoggingObj
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.anthropic_passthrough_logging_handler import \
AnthropicPassthroughLoggingHandler
from litellm.litellm_core_utils.litellm_logging import (
Logging as LiteLLMLoggingObj,
)
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.anthropic_passthrough_logging_handler import (
AnthropicPassthroughLoggingHandler,
)
logging_obj = LiteLLMLoggingObj(
model="claude-sonnet-4-20250514",
@@ -438,7 +436,10 @@ class TestUpdateFromKwargs:
)
# kwargs metadata is preserved, caller metadata is merged in
assert logging_obj.litellm_params["metadata"] == {"from_kwargs": True, "from_caller": True}
assert logging_obj.litellm_params["metadata"] == {
"from_kwargs": True,
"from_caller": True,
}
def test_kwargs_metadata_wins_over_caller_metadata_in_conflict(self, logging_obj):
"""kwargs metadata takes precedence; caller litellm_params metadata is merged without overwriting."""
@@ -446,7 +447,10 @@ class TestUpdateFromKwargs:
logging_obj.update_from_kwargs(
kwargs=kwargs,
litellm_params={"metadata": {"from_caller": True, "shared_key": "caller_value"}, "litellm_call_id": "x"},
litellm_params={
"metadata": {"from_caller": True, "shared_key": "caller_value"},
"litellm_call_id": "x",
},
)
# kwargs metadata is preserved (shared_key keeps the kwargs value), caller-only keys are added
@@ -458,8 +462,9 @@ class TestUpdateFromKwargs:
def test_custom_pricing_detected_via_litellm_metadata(self, logging_obj):
"""Custom pricing in litellm_metadata.model_info should set custom_pricing flag."""
from litellm.litellm_core_utils.litellm_logging import \
use_custom_pricing_for_model
from litellm.litellm_core_utils.litellm_logging import (
use_custom_pricing_for_model,
)
lm_meta = {
"model_info": {
@@ -518,8 +523,7 @@ async def test_datadog_logger_not_shadowed_by_llm_obs(monkeypatch):
monkeypatch.setenv("DD_SITE", "us5.datadoghq.com")
from litellm.integrations.datadog.datadog import DataDogLogger
from litellm.integrations.datadog.datadog_llm_obs import \
DataDogLLMObsLogger
from litellm.integrations.datadog.datadog_llm_obs import DataDogLLMObsLogger
from litellm.litellm_core_utils import litellm_logging as logging_module
logging_module._in_memory_loggers.clear()
@@ -560,8 +564,7 @@ async def test_logfire_logger_accepts_env_vars_for_base_url(monkeypatch):
) # no trailing slash on purpose
# Import after env vars are set (important if module-level caching exists)
from litellm.integrations.opentelemetry import \
OpenTelemetry # logger class
from litellm.integrations.opentelemetry import OpenTelemetry # logger class
from litellm.litellm_core_utils import litellm_logging as logging_module
logging_module._in_memory_loggers.clear()
@@ -890,8 +893,7 @@ def test_success_handler_runs_guardrail_logging_hook_when_enabled(logging_obj):
def test_get_user_agent_tags():
from litellm.litellm_core_utils.litellm_logging import \
StandardLoggingPayloadSetup
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
tags = StandardLoggingPayloadSetup._get_user_agent_tags(
proxy_server_request={
@@ -906,8 +908,7 @@ def test_get_user_agent_tags():
def test_get_request_tags():
from litellm.litellm_core_utils.litellm_logging import \
StandardLoggingPayloadSetup
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
tags = StandardLoggingPayloadSetup._get_request_tags(
litellm_params={"metadata": {"tags": ["test-tag"]}},
@@ -934,8 +935,7 @@ def test_get_request_tags_from_metadata_and_litellm_metadata():
4. No tags in either
5. None values for metadata/litellm_metadata
"""
from litellm.litellm_core_utils.litellm_logging import \
StandardLoggingPayloadSetup
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
# Test case 1: Tags in metadata only
tags = StandardLoggingPayloadSetup._get_request_tags(
@@ -1016,8 +1016,7 @@ def test_get_request_tags_does_not_mutate_original_tags():
would cause User-Agent tags to be duplicated because the function was mutating
the original tags list instead of creating a copy.
"""
from litellm.litellm_core_utils.litellm_logging import \
StandardLoggingPayloadSetup
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
# Create metadata with original tags
original_tags = ["custom-tag-1", "custom-tag-2"]
@@ -1077,8 +1076,7 @@ def test_get_request_tags_does_not_mutate_original_tags():
def test_get_extra_header_tags():
"""Test the _get_extra_header_tags method with various scenarios."""
import litellm
from litellm.litellm_core_utils.litellm_logging import \
StandardLoggingPayloadSetup
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
# Store original value to restore later
original_extra_headers = getattr(litellm, "extra_spend_tag_headers", None)
@@ -1299,17 +1297,17 @@ async def test_e2e_generate_cold_storage_object_key_successful():
from datetime import datetime, timezone
from unittest.mock import patch
from litellm.litellm_core_utils.litellm_logging import \
StandardLoggingPayloadSetup
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
# Create test data
start_time = datetime(2025, 1, 15, 10, 30, 45, 123456, timezone.utc)
response_id = "chatcmpl-test-12345"
team_alias = "test-team"
with patch("litellm.cold_storage_custom_logger", return_value="s3"), patch(
"litellm.integrations.s3.get_s3_object_key"
) as mock_get_s3_key:
with (
patch("litellm.cold_storage_custom_logger", return_value="s3"),
patch("litellm.integrations.s3.get_s3_object_key") as mock_get_s3_key,
):
# Mock the S3 object key generation to return a predictable result
mock_get_s3_key.return_value = (
"2025-01-15/time-10-30-45-123456_chatcmpl-test-12345.json"
@@ -1342,8 +1340,7 @@ async def test_e2e_generate_cold_storage_object_key_with_custom_logger_s3_path()
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from litellm.litellm_core_utils.litellm_logging import \
StandardLoggingPayloadSetup
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
# Create test data
start_time = datetime(2025, 1, 15, 10, 30, 45, 123456, timezone.utc)
@@ -1353,11 +1350,13 @@ async def test_e2e_generate_cold_storage_object_key_with_custom_logger_s3_path()
mock_custom_logger = MagicMock()
mock_custom_logger.s3_path = "storage"
with patch("litellm.cold_storage_custom_logger", "s3_v2"), patch(
"litellm.logging_callback_manager.get_active_custom_logger_for_callback_name"
) as mock_get_logger, patch(
"litellm.integrations.s3.get_s3_object_key"
) as mock_get_s3_key:
with (
patch("litellm.cold_storage_custom_logger", "s3_v2"),
patch(
"litellm.logging_callback_manager.get_active_custom_logger_for_callback_name"
) as mock_get_logger,
patch("litellm.integrations.s3.get_s3_object_key") as mock_get_s3_key,
):
# Setup mocks
mock_get_logger.return_value = mock_custom_logger
mock_get_s3_key.return_value = (
@@ -1394,8 +1393,7 @@ async def test_e2e_generate_cold_storage_object_key_with_logger_no_s3_path():
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from litellm.litellm_core_utils.litellm_logging import \
StandardLoggingPayloadSetup
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
# Create test data
start_time = datetime(2025, 1, 15, 10, 30, 45, 123456, timezone.utc)
@@ -1405,11 +1403,13 @@ async def test_e2e_generate_cold_storage_object_key_with_logger_no_s3_path():
mock_custom_logger = MagicMock()
mock_custom_logger.s3_path = None # or could be missing attribute
with patch("litellm.cold_storage_custom_logger", "s3_v2"), patch(
"litellm.logging_callback_manager.get_active_custom_logger_for_callback_name"
) as mock_get_logger, patch(
"litellm.integrations.s3.get_s3_object_key"
) as mock_get_s3_key:
with (
patch("litellm.cold_storage_custom_logger", "s3_v2"),
patch(
"litellm.logging_callback_manager.get_active_custom_logger_for_callback_name"
) as mock_get_logger,
patch("litellm.integrations.s3.get_s3_object_key") as mock_get_s3_key,
):
# Setup mocks
mock_get_logger.return_value = mock_custom_logger
mock_get_s3_key.return_value = (
@@ -1442,8 +1442,7 @@ async def test_e2e_generate_cold_storage_object_key_not_configured():
from unittest.mock import patch
import litellm
from litellm.litellm_core_utils.litellm_logging import \
StandardLoggingPayloadSetup
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
# Create test data
start_time = datetime(2025, 1, 15, 10, 30, 45, 123456, timezone.utc)
@@ -1467,8 +1466,7 @@ def test_get_final_response_obj_with_empty_response_obj_and_list_init():
When response_obj is empty (falsy), the method should return init_response_obj if it's a list.
"""
from litellm.litellm_core_utils.litellm_logging import \
StandardLoggingPayloadSetup
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
# Create test objects
class TestObject1:
@@ -1504,8 +1502,7 @@ def test_get_usage_as_dict():
"""
Test get_usage_as_dict returns usage as plain dict from response_obj or combined_usage_object.
"""
from litellm.litellm_core_utils.litellm_logging import \
StandardLoggingPayloadSetup
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
from litellm.types.utils import Usage
# Test case 1: None response_obj returns empty usage dict
@@ -1543,8 +1540,7 @@ def test_append_system_prompt_messages():
"""
Test append_system_prompt_messages prepends system message from kwargs to messages list.
"""
from litellm.litellm_core_utils.litellm_logging import \
StandardLoggingPayloadSetup
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
# Test case 1: system in kwargs with existing messages
kwargs = {"system": "You are a helpful assistant"}
@@ -1615,8 +1611,7 @@ async def test_async_success_handler_sets_standard_logging_object_for_pass_throu
from datetime import datetime
from unittest.mock import patch
from litellm.litellm_core_utils.litellm_logging import \
Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.utils import StandardPassThroughResponseObject
# Create a logging object for a pass-through endpoint
@@ -1697,8 +1692,7 @@ async def test_async_success_handler_prevents_reprocessing_for_pass_through_endp
from datetime import datetime
from unittest.mock import patch
from litellm.litellm_core_utils.litellm_logging import \
Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.utils import StandardPassThroughResponseObject
# Create a logging object for a pass-through endpoint
@@ -1774,8 +1768,7 @@ async def test_async_success_handler_sets_standard_logging_object_for_streaming_
from datetime import datetime
from unittest.mock import patch
from litellm.litellm_core_utils.litellm_logging import \
Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.utils import StandardPassThroughResponseObject
# Create a logging object for a streaming pass-through endpoint
@@ -1831,8 +1824,7 @@ def test_get_error_information_error_code_priority():
Test get_error_information prioritizes 'code' attribute over 'status_code' attribute
and handles edge cases like empty strings and "None" string values.
"""
from litellm.litellm_core_utils.litellm_logging import \
StandardLoggingPayloadSetup
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
# Test case 1: Exception with 'code' attribute (ProxyException style)
class ProxyException(Exception):
@@ -2025,8 +2017,7 @@ async def test_async_success_handler_preserves_response_cost_for_pass_through_en
by pass-through handlers (Gemini/Vertex)."""
from datetime import datetime
from litellm.litellm_core_utils.litellm_logging import \
Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.utils import ModelResponse, Usage
logging_obj = LiteLLMLoggingObj(
@@ -2366,6 +2357,56 @@ def test_merge_hidden_params_from_response_into_metadata_no_op_when_empty():
_hidden_params = {}
logging_obj._merge_hidden_params_from_response_into_metadata(_NoHp())
assert "hidden_params" not in logging_obj.model_call_details["litellm_params"][
"metadata"
]
assert (
"hidden_params"
not in logging_obj.model_call_details["litellm_params"]["metadata"]
)
# ── StandardLoggingPayloadSetup.get_additional_headers ───────────────────────
def test_get_additional_headers_preserves_provider_request_id():
"""llm_provider-x-request-id must survive the get_additional_headers filter."""
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
raw = {
"x-ratelimit-remaining-requests": "29999",
"x-ratelimit-remaining-tokens": "149999970",
"llm_provider-x-request-id": "req_85f49b546c7b4d3180755621f36631a1",
"llm_provider-openai-organization": "my-org",
"llm_provider-openai-processing-ms": "649",
}
result = StandardLoggingPayloadSetup.get_additional_headers(raw)
assert result is not None
# well-known fields parsed as ints
assert result["x_ratelimit_remaining_requests"] == 29999 # type: ignore
assert result["x_ratelimit_remaining_tokens"] == 149999970 # type: ignore
# provider-specific headers must be preserved verbatim
assert result["llm_provider-x-request-id"] == "req_85f49b546c7b4d3180755621f36631a1" # type: ignore
assert result["llm_provider-openai-organization"] == "my-org" # type: ignore
assert result["llm_provider-openai-processing-ms"] == "649" # type: ignore
def test_get_additional_headers_returns_none_for_none_input():
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
assert StandardLoggingPayloadSetup.get_additional_headers(None) is None
def test_get_additional_headers_reset_fields_preserved():
"""x-ratelimit-reset-* fields (added to the TypedDict) must be captured."""
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
raw = {
"x-ratelimit-reset-requests": "1s",
"x-ratelimit-reset-tokens": "100ms",
}
result = StandardLoggingPayloadSetup.get_additional_headers(raw)
assert result is not None
assert result["x_ratelimit_reset_requests"] == "1s" # type: ignore
assert result["x_ratelimit_reset_tokens"] == "100ms" # type: ignore
@@ -440,14 +440,18 @@ class TestProxyOAuthHeaderForwarding:
(b"content-type", b"application/json"),
]
)
# Should preserve OAuth even with flag=False
cleaned_without_flag = clean_headers(raw_headers, forward_llm_provider_auth_headers=False)
cleaned_without_flag = clean_headers(
raw_headers, forward_llm_provider_auth_headers=False
)
assert "authorization" in cleaned_without_flag
assert cleaned_without_flag["authorization"] == f"Bearer {FAKE_OAUTH_TOKEN}"
# Should also preserve OAuth with flag=True
cleaned_with_flag = clean_headers(raw_headers, forward_llm_provider_auth_headers=True)
cleaned_with_flag = clean_headers(
raw_headers, forward_llm_provider_auth_headers=True
)
assert "authorization" in cleaned_with_flag
assert cleaned_with_flag["authorization"] == f"Bearer {FAKE_OAUTH_TOKEN}"
@@ -867,8 +871,6 @@ class TestValidateEnvironmentAuthToken:
assert "authorization" not in headers
class TestGetAuthToken:
"""Tests for AnthropicModelInfo.get_auth_token() static method."""
@@ -1092,7 +1094,10 @@ class TestPassthroughAuthToken:
config = AnthropicMessagesConfig()
with mock_patch.dict(
"os.environ",
{"ANTHROPIC_API_KEY": FAKE_REGULAR_KEY, "ANTHROPIC_AUTH_TOKEN": FAKE_AUTH_TOKEN},
{
"ANTHROPIC_API_KEY": FAKE_REGULAR_KEY,
"ANTHROPIC_AUTH_TOKEN": FAKE_AUTH_TOKEN,
},
clear=True,
):
updated_headers, _ = config.validate_anthropic_messages_environment(
@@ -1131,3 +1136,147 @@ class TestPassthroughAuthToken:
)
assert url == "https://custom.example.com/v1/messages"
class TestAnthropicThinkingSignatureSelfHeal:
"""Helpers for retrying after invalid encrypted thinking signatures."""
def test_is_anthropic_invalid_thinking_signature_error_positive(self):
from litellm.llms.anthropic.common_utils import (
is_anthropic_invalid_thinking_signature_error,
)
raw = (
'{"type":"error","error":{"type":"invalid_request_error",'
'"message":"messages.3.content.3: Invalid `signature` in `thinking` block"},'
'"request_id":"req_011Ca2EtQDxp7x6RGUY2jVn9"}'
)
assert is_anthropic_invalid_thinking_signature_error(raw) is True
def test_is_anthropic_invalid_thinking_signature_error_negative(self):
from litellm.llms.anthropic.common_utils import (
is_anthropic_invalid_thinking_signature_error,
)
assert is_anthropic_invalid_thinking_signature_error("") is False
assert (
is_anthropic_invalid_thinking_signature_error("rate limit exceeded")
is False
)
def test_strip_thinking_blocks_from_anthropic_messages(self):
from litellm.llms.anthropic.common_utils import (
strip_thinking_blocks_from_anthropic_messages,
)
messages = [
{"role": "user", "content": "hi"},
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "plan", "signature": "sig"},
{"type": "text", "text": "hello"},
],
},
]
out = strip_thinking_blocks_from_anthropic_messages(messages)
assert len(out) == 2
assert out[0] == messages[0]
assert len(out[1]["content"]) == 1
assert out[1]["content"][0]["type"] == "text"
assert messages[1]["content"][0]["type"] == "thinking"
def test_strip_thinking_blocks_drops_message_when_only_thinking_blocks(self):
from litellm.llms.anthropic.common_utils import (
strip_thinking_blocks_from_anthropic_messages,
)
messages = [
{"role": "user", "content": "hi"},
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "plan", "signature": "sig"},
],
},
]
out = strip_thinking_blocks_from_anthropic_messages(messages)
assert len(out) == 1
assert out[0]["role"] == "user"
def test_strip_thinking_blocks_from_anthropic_messages_request_dict(self):
from litellm.llms.anthropic.common_utils import (
strip_thinking_blocks_from_anthropic_messages_request_dict,
)
data = {
"model": "claude-sonnet-4-20250514",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "thinking",
"thinking": "x",
"signature": "y",
},
],
}
],
"thinking": {"type": "enabled", "budget_tokens": 1024},
}
strip_thinking_blocks_from_anthropic_messages_request_dict(data)
assert "thinking" not in data
assert data["messages"] == []
def test_anthropic_messages_config_http_retry_helpers(self):
import httpx
from litellm.llms.anthropic.experimental_pass_through.messages.transformation import (
AnthropicMessagesConfig,
)
config = AnthropicMessagesConfig()
assert config.max_retry_on_anthropic_messages_http_error == 2
req = httpx.Request("POST", "https://api.anthropic.com/v1/messages")
err_text = (
'{"type":"error","error":{"type":"invalid_request_error",'
'"message":"messages.3.content.3: Invalid `signature` in `thinking` block"},'
'"request_id":"req_011Ca2EtQDxp7x6RGUY2jVn9"}'
)
resp = httpx.Response(400, request=req, text=err_text)
err = httpx.HTTPStatusError("bad", request=req, response=resp)
assert config.should_retry_anthropic_messages_on_http_error(err, {}) is True
resp_bad = httpx.Response(400, request=req, text="rate limit exceeded")
err_bad = httpx.HTTPStatusError("bad", request=req, response=resp_bad)
assert (
config.should_retry_anthropic_messages_on_http_error(err_bad, {}) is False
)
resp_500 = httpx.Response(500, request=req, text=err_text)
err_500 = httpx.HTTPStatusError("bad", request=req, response=resp_500)
assert (
config.should_retry_anthropic_messages_on_http_error(err_500, {}) is False
)
data = {
"model": "claude-sonnet-4-20250514",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "thinking",
"thinking": "x",
"signature": "y",
},
],
}
],
"thinking": {"type": "enabled", "budget_tokens": 1024},
}
config.transform_anthropic_messages_request_on_http_error(err, data)
assert "thinking" not in data
assert data["messages"] == []
@@ -0,0 +1,97 @@
import json
import os
import sys
from unittest.mock import MagicMock
import httpx
sys.path.insert(0, os.path.abspath("../../../../.."))
from litellm.llms.azure.passthrough.transformation import AzurePassthroughConfig
from litellm.types.utils import ModelResponse
def _azure_chat_completion_body():
return {
"id": "chatcmpl-abc123",
"object": "chat.completion",
"created": 1700000000,
"model": "gpt-4.1-mini-2025-04-14",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! How can I assist you today?",
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 8,
"total_tokens": 18,
},
}
def _make_httpx_response(body: dict) -> httpx.Response:
return httpx.Response(
status_code=200,
headers={"content-type": "application/json"},
content=json.dumps(body).encode("utf-8"),
request=httpx.Request(
"POST",
"https://example.openai.azure.com/openai/deployments/gpt-4.1-mini/chat/completions",
),
)
def test_azure_passthrough_logging_non_streaming_response_chat_completions():
"""
Returns a populated ModelResponse (with usage + content) for a chat/completions
endpoint. This is what _success_handler_helper_fn needs to build
standard_logging_object without it, Datadog/cost-tracking/router-success all
raise on every Azure passthrough request.
"""
config = AzurePassthroughConfig()
logging_obj = MagicMock()
result = config.logging_non_streaming_response(
model="gpt-4.1-mini",
custom_llm_provider="azure",
httpx_response=_make_httpx_response(_azure_chat_completion_body()),
request_data={
"model": "gpt-4.1-mini",
"messages": [{"role": "user", "content": "hi"}],
},
logging_obj=logging_obj,
endpoint="openai/deployments/gpt-4.1-mini/chat/completions",
)
assert isinstance(result, ModelResponse)
assert result.choices[0].message.content == "Hello! How can I assist you today?"
assert result.usage.prompt_tokens == 10
assert result.usage.completion_tokens == 8
assert result.usage.total_tokens == 18
def test_azure_passthrough_logging_non_streaming_response_unknown_endpoint_returns_none():
"""
Endpoints other than chat/completions (responses, messages, images) fall
through to None matches base-class behavior and Bedrock's "unknown
endpoint" handling. Not a regression; just scoping.
"""
config = AzurePassthroughConfig()
logging_obj = MagicMock()
result = config.logging_non_streaming_response(
model="gpt-4.1-mini",
custom_llm_provider="azure",
httpx_response=_make_httpx_response(_azure_chat_completion_body()),
request_data={},
logging_obj=logging_obj,
endpoint="openai/responses",
)
assert result is None
@@ -222,5 +222,7 @@ class TestAzureAnthropicChatCompletion:
# Verify non-streaming was handled
mock_client.post.assert_called_once()
mock_get_client.assert_called_once_with(params={"timeout": timeout})
assert mock_client.post.call_args.kwargs["timeout"] == timeout
assert result is not None
@@ -0,0 +1,42 @@
"""
Ensure litellm.completion() forwards timeout to Azure Anthropic handler (main.py dispatch).
"""
import os
import sys
from unittest.mock import MagicMock, patch
sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../.."))
)
from litellm import completion
from litellm.types.utils import ModelResponse
def test_main_azure_ai_claude_completion_passes_timeout_to_azure_anthropic_handler():
captured: dict = {}
def fake_azure_anthropic_completion(**kwargs):
captured.update(kwargs)
return ModelResponse()
with patch(
"litellm.main.azure_anthropic_chat_completions"
) as mock_azure_anthropic:
mock_azure_anthropic.completion = MagicMock(
side_effect=fake_azure_anthropic_completion
)
completion(
model="azure_ai/claude-sonnet-4-5",
messages=[{"role": "user", "content": "hi"}],
api_base="https://example.services.ai.azure.com/anthropic",
api_key="test-key",
timeout=42.5,
)
mock_azure_anthropic.completion.assert_called_once()
assert captured["timeout"] == 42.5
assert captured["model"] == "claude-sonnet-4-5"
assert captured["custom_llm_provider"] == "azure_ai"
@@ -3140,9 +3140,14 @@ def test_add_additional_properties_definitions():
assert result["definitions"]["Item"]["properties"]["details"]["additionalProperties"] is False
def test_json_object_no_schema_falls_back_to_tool_call():
"""response_format: {type: json_object} with no schema should use tool-call fallback,
even for models that support native structured outputs."""
def test_json_object_no_schema_skips_tool_injection():
"""response_format: {type: json_object} with no schema should NOT inject
the synthetic json_tool_call tool.
When no schema is given, _create_json_tool_call_for_response_format builds
a tool with an empty schema (properties: {}). The model follows the schema
and returns {} instead of the requested JSON. Skipping tool injection lets
the model respond naturally with the JSON the caller asked for."""
old_env = os.environ.get("LITELLM_LOCAL_MODEL_COST_MAP")
old_cost = litellm.model_cost
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
@@ -3162,8 +3167,9 @@ def test_json_object_no_schema_falls_back_to_tool_call():
# Should NOT use native outputConfig (no schema provided)
assert "outputConfig" not in result
# Should use tool-call fallback
assert "tools" in result
# Should NOT inject tools - empty schema causes model to return {}
assert "tools" not in result
assert "tool_choice" not in result
assert result["json_mode"] is True
finally:
litellm.model_cost = old_cost
@@ -3655,3 +3661,140 @@ def test_cache_control_injection_tool_config_not_added_without_injection_point()
tools = result["toolConfig"]["tools"]
# No cachePoint should be appended
assert all("cachePoint" not in tool for tool in tools)
def test_translate_response_format_json_schema_still_injects_tool():
"""
response_format with an explicit json_schema should still use the
synthetic tool call approach (for models that don't support native
structured outputs).
"""
config = AmazonConverseConfig()
response_format = {
"type": "json_schema",
"json_schema": {
"name": "FactResult",
"schema": {
"type": "object",
"properties": {
"facts": {
"type": "array",
"items": {"type": "string"},
},
},
"required": ["facts"],
},
},
}
optional_params: dict = {}
result = config._translate_response_format_param(
value=response_format,
model="anthropic.claude-3-haiku-20240307-v1:0",
optional_params=optional_params,
non_default_params={"response_format": response_format},
is_thinking_enabled=False,
)
assert result["json_mode"] is True
assert "tools" in result
assert "tool_choice" in result
def test_transform_response_finish_reason_stop_when_json_mode_filters_all_tools():
"""
When json_mode is True and _filter_json_mode_tools strips all synthetic
tool calls, finish_reason should be "stop", not "tool_calls".
Bedrock returns stopReason="tool_use" for json_tool_call responses.
After filtering, the response is plain JSON content and should not look
like a pending tool invocation to callers.
"""
from litellm.llms.bedrock.chat.converse_transformation import AmazonConverseConfig
from litellm.types.utils import ModelResponse
response_json = {
"metrics": {"latencyMs": 100},
"output": {
"message": {
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": "tooluse_001",
"name": "json_tool_call",
"input": {
"facts": ["Bob is a software engineer"],
},
}
}
],
}
},
"stopReason": "tool_use",
"usage": {
"inputTokens": 50,
"outputTokens": 20,
"totalTokens": 70,
"cacheReadInputTokenCount": 0,
"cacheReadInputTokens": 0,
"cacheWriteInputTokenCount": 0,
"cacheWriteInputTokens": 0,
},
}
class MockResponse:
def json(self):
return response_json
@property
def text(self):
return json.dumps(response_json)
config = AmazonConverseConfig()
model_response = ModelResponse()
# Simulate what happens when json_tool_call was injected for a
# json_schema request: optional_params has the synthetic tool
optional_params = {
"json_mode": True,
"tools": [
{
"type": "function",
"function": {
"name": "json_tool_call",
"parameters": {
"type": "object",
"additionalProperties": True,
"properties": {},
},
},
}
],
}
result = config._transform_response(
model="bedrock/us.anthropic.claude-haiku-4-5-20251001-v1:0",
response=MockResponse(),
model_response=model_response,
stream=False,
logging_obj=None,
optional_params=optional_params,
api_key=None,
data=None,
messages=[],
encoding=None,
)
# Content should have the JSON from the tool call arguments
content = result.choices[0].message.content
assert content is not None
parsed = json.loads(content)
assert parsed["facts"] == ["Bob is a software engineer"]
# No tool_calls on the message
assert result.choices[0].message.tool_calls is None
# finish_reason must be "stop", not "tool_calls"
assert result.choices[0].finish_reason == "stop"
@@ -15,7 +15,12 @@ sys.path.insert(
) # Adds the parent directory to the system path
import litellm
from litellm.llms.custom_httpx.aiohttp_transport import LiteLLMAiohttpTransport
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, get_ssl_configuration
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_ssl_configuration,
)
@pytest.mark.asyncio
@@ -658,3 +663,26 @@ async def test_httpx_handler_uses_env_user_agent(monkeypatch):
assert req.headers.get("User-Agent") == "Claude Code"
finally:
await handler.close()
def test_get_httpx_client_applies_float_timeout_without_mocking_handler():
"""
Exercise real _get_httpx_client + HTTPHandler: params={'timeout': x} must reach httpx.Client(timeout=...).
Uses an uncommon timeout value to avoid colliding with other cached clients in-process.
"""
timeout = 3847.291
handler = _get_httpx_client(params={"timeout": timeout})
try:
assert isinstance(handler, HTTPHandler)
assert handler.client.timeout == httpx.Timeout(timeout)
finally:
handler.close()
def test_get_httpx_client_applies_httpx_timeout_object_without_mocking_handler():
t = httpx.Timeout(40.0, connect=5.0)
handler = _get_httpx_client(params={"timeout": t})
try:
assert handler.client.timeout == t
finally:
handler.close()
@@ -15,6 +15,12 @@ from litellm.llms.ollama.chat.transformation import OllamaChatConfig, OllamaChat
from litellm.types.llms.openai import AllMessageValues
from litellm.utils import get_optional_params
import json
from unittest.mock import MagicMock
import litellm
from litellm.types.utils import Choices, Message, ModelResponse
class TestEvent(BaseModel):
name: str
@@ -427,12 +433,6 @@ class TestOllamaToolCalling:
def test_finish_reason_stop_when_no_tool_calls(self):
"""Test that finish_reason remains 'stop' when no tool_calls present."""
import json
from unittest.mock import MagicMock
import litellm
from litellm.types.utils import Choices, Message, ModelResponse
config = OllamaChatConfig()
# Simulated Ollama response without tool_calls
@@ -476,6 +476,140 @@ class TestOllamaToolCalling:
assert result.choices[0].message.tool_calls is None
class TestOllamaFinishReasonLength:
"""Tests for done_reason 'length' → finish_reason 'length' mapping.
Ollama returns done_reason='length' when a response is truncated by num_predict
(max_tokens). Previously finish_reason was hardcoded to 'stop', hiding truncation.
The Anthropic pass-through adapter then maps OpenAI 'length' 'max_tokens'.
"""
def test_finish_reason_length_non_streaming(self):
"""Non-streaming: done_reason='length' must propagate as finish_reason='length'."""
config = OllamaChatConfig()
ollama_response = {
"model": "qwen3:2b",
"created_at": "2025-01-11T00:00:00.000000Z",
"message": {
"role": "assistant",
"content": "A neural network learns through",
},
"done": True,
"done_reason": "length",
"prompt_eval_count": 20,
"eval_count": 20,
}
mock_response = MagicMock()
mock_response.json.return_value = ollama_response
mock_response.text = json.dumps(ollama_response)
mock_logging = MagicMock()
model_response = ModelResponse()
model_response.choices = [Choices(message=Message(content=""), index=0)]
result = config.transform_response(
model="qwen3:2b",
raw_response=mock_response,
model_response=model_response,
logging_obj=mock_logging,
request_data={},
messages=[{"role": "user", "content": "Explain neural networks."}],
optional_params={},
litellm_params={},
encoding=None,
api_key=None,
json_mode=False,
)
assert result.choices[0].finish_reason == "length", (
f"Expected 'length' when done_reason='length', got '{result.choices[0].finish_reason}'"
)
def test_finish_reason_stop_non_streaming(self):
"""Non-streaming: done_reason='stop' (natural finish) must stay 'stop'."""
config = OllamaChatConfig()
ollama_response = {
"model": "qwen3:2b",
"created_at": "2025-01-11T00:00:00.000000Z",
"message": {"role": "assistant", "content": "2 + 2 = 4."},
"done": True,
"done_reason": "stop",
"prompt_eval_count": 10,
"eval_count": 8,
}
mock_response = MagicMock()
mock_response.json.return_value = ollama_response
mock_response.text = json.dumps(ollama_response)
mock_logging = MagicMock()
model_response = ModelResponse()
model_response.choices = [Choices(message=Message(content=""), index=0)]
result = config.transform_response(
model="qwen3:2b",
raw_response=mock_response,
model_response=model_response,
logging_obj=mock_logging,
request_data={},
messages=[{"role": "user", "content": "What is 2+2?"}],
optional_params={},
litellm_params={},
encoding=None,
api_key=None,
json_mode=False,
)
assert result.choices[0].finish_reason == "stop", (
f"Expected 'stop' for natural finish, got '{result.choices[0].finish_reason}'"
)
def test_finish_reason_length_streaming(self):
"""Streaming: done_reason='length' in final chunk must produce finish_reason='length'."""
iterator = OllamaChatCompletionResponseIterator(
streaming_response=iter([]),
sync_stream=True,
)
done_chunk = {
"model": "qwen3:2b",
"message": {"role": "assistant", "content": "A neural network learns through"},
"done": True,
"done_reason": "length",
}
result = iterator.chunk_parser(done_chunk)
assert result.choices[0].finish_reason == "length", (
f"Expected 'length' when done_reason='length', got '{result.choices[0].finish_reason}'"
)
def test_finish_reason_stop_streaming(self):
"""Streaming: done_reason='stop' in final chunk must produce finish_reason='stop'."""
iterator = OllamaChatCompletionResponseIterator(
streaming_response=iter([]),
sync_stream=True,
)
done_chunk = {
"model": "qwen3:2b",
"message": {"role": "assistant", "content": "2 + 2 = 4."},
"done": True,
"done_reason": "stop",
}
result = iterator.chunk_parser(done_chunk)
assert result.choices[0].finish_reason == "stop", (
f"Expected 'stop' for natural finish, got '{result.choices[0].finish_reason}'"
)
class TestOllamaReasoningContentStreaming:
"""Test that reasoning_content is properly extracted from all thinking chunks."""
@@ -162,3 +162,69 @@ class TestCountTokensLocationResolution:
)
assert captured["vertex_location"] == "asia-southeast1"
class TestCountTokensVersionSuffixStripping:
"""Verify that version suffixes (@default, @20251001, etc.) are stripped
from model names before sending to the Vertex AI count-tokens endpoint.
The Vertex AI count-tokens API rejects versioned model names with:
"claude-sonnet-4-6@default is not supported for token counting"
while "claude-sonnet-4-6" (without suffix) works correctly.
"""
def test_strip_version_suffix_at_default(self):
counter = VertexAIPartnerModelsTokenCounter()
assert counter._strip_version_suffix("claude-sonnet-4-6@default") == "claude-sonnet-4-6"
def test_strip_version_suffix_at_date(self):
counter = VertexAIPartnerModelsTokenCounter()
assert counter._strip_version_suffix("claude-haiku-4-5@20251001") == "claude-haiku-4-5"
def test_strip_version_suffix_no_suffix(self):
counter = VertexAIPartnerModelsTokenCounter()
assert counter._strip_version_suffix("claude-sonnet-4-6") == "claude-sonnet-4-6"
@pytest.mark.asyncio
async def test_handle_count_tokens_strips_version_from_request_data(self, monkeypatch):
"""The model name in request_data sent to the API must have @suffix stripped."""
counter = VertexAIPartnerModelsTokenCounter()
captured_json = {}
async def fake_ensure_access_token(self, credentials, project_id, custom_llm_provider):
return "fake-token", "fake-project"
def fake_build_endpoint(self, model, project_id, vertex_location, api_base=None):
return "https://fake-endpoint"
monkeypatch.setattr(
VertexAIPartnerModelsTokenCounter, "_ensure_access_token_async", fake_ensure_access_token
)
monkeypatch.setattr(
VertexAIPartnerModelsTokenCounter, "_build_count_tokens_endpoint", fake_build_endpoint
)
class FakeResponse:
status_code = 200
def json(self):
return {"input_tokens": 10}
class FakeClient:
async def post(self, url, headers=None, json=None, **kwargs):
captured_json.update(json or {})
return FakeResponse()
import litellm.llms.vertex_ai.vertex_ai_partner_models.count_tokens.handler as handler_mod
monkeypatch.setattr(handler_mod, "get_async_httpx_client", lambda **kwargs: FakeClient())
await counter.handle_count_tokens_request(
model="claude-sonnet-4-6@default",
request_data={
"model": "claude-sonnet-4-6@default",
"messages": [{"role": "user", "content": "hi"}],
},
litellm_params={"vertex_location": "us-east5"},
)
# The model name sent to the API must NOT have the @default suffix
assert captured_json["model"] == "claude-sonnet-4-6"
@@ -5,7 +5,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
from mcp import ReadResourceResult, Resource
from mcp.types import BlobResourceContents, Prompt, ResourceTemplate, TextResourceContents
from mcp.types import (
BlobResourceContents,
Prompt,
ResourceTemplate,
TextResourceContents,
)
from litellm.proxy._types import (
LiteLLM_MCPServerTable,
@@ -157,15 +162,19 @@ async def test_get_prompts_from_mcp_servers_success():
server_b.auth_type = None
server_b.extra_headers = None
with patch(
"litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers",
AsyncMock(return_value=[server_a, server_b]),
) as mock_allowed, patch(
"litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers",
return_value=(None, None),
) as mock_headers, patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
) as mock_manager:
with (
patch(
"litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers",
AsyncMock(return_value=[server_a, server_b]),
) as mock_allowed,
patch(
"litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers",
return_value=(None, None),
) as mock_headers,
patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
) as mock_manager,
):
mock_manager.get_prompts_from_server = AsyncMock(
side_effect=[
[Prompt(name="hello", description="hi")],
@@ -213,15 +222,19 @@ async def test_get_resources_from_mcp_servers_success():
server_b.auth_type = None
server_b.extra_headers = None
with patch(
"litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers",
AsyncMock(return_value=[server_a, server_b]),
) as mock_allowed, patch(
"litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers",
return_value=(None, None),
) as mock_headers, patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
) as mock_manager:
with (
patch(
"litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers",
AsyncMock(return_value=[server_a, server_b]),
) as mock_allowed,
patch(
"litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers",
return_value=(None, None),
) as mock_headers,
patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
) as mock_manager,
):
mock_manager.get_resources_from_server = AsyncMock(
side_effect=[
[
@@ -274,15 +287,19 @@ async def test_get_resource_templates_from_mcp_servers_success():
server.auth_type = None
server.extra_headers = None
with patch(
"litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers",
AsyncMock(return_value=[server]),
) as mock_allowed, patch(
"litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers",
return_value=(None, None),
) as mock_headers, patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
) as mock_manager:
with (
patch(
"litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers",
AsyncMock(return_value=[server]),
) as mock_allowed,
patch(
"litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers",
return_value=(None, None),
) as mock_headers,
patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
) as mock_manager,
):
mock_manager.get_resource_templates_from_server = AsyncMock(
return_value=[
ResourceTemplate(
@@ -320,15 +337,19 @@ async def test_mcp_get_prompt_success():
prompt_result = MagicMock(name="prompt_result")
with patch(
"litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers",
AsyncMock(return_value=[server]),
) as mock_allowed, patch(
"litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers",
return_value=({"Authorization": "token"}, {"X-Test": "1"}),
) as mock_headers, patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
) as mock_manager:
with (
patch(
"litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers",
AsyncMock(return_value=[server]),
) as mock_allowed,
patch(
"litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers",
return_value=({"Authorization": "token"}, {"X-Test": "1"}),
) as mock_headers,
patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
) as mock_manager,
):
mock_manager.get_prompt_from_server = AsyncMock(return_value=prompt_result)
result = await mcp_get_prompt(
@@ -378,15 +399,19 @@ async def test_mcp_read_resource_success():
]
)
with patch(
"litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers",
AsyncMock(return_value=[server]),
) as mock_allowed, patch(
"litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers",
return_value=({"Authorization": "token"}, {"X-Test": "1"}),
) as mock_headers, patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
) as mock_manager:
with (
patch(
"litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers",
AsyncMock(return_value=[server]),
) as mock_allowed,
patch(
"litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers",
return_value=({"Authorization": "token"}, {"X-Test": "1"}),
) as mock_headers,
patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
) as mock_manager,
):
mock_manager.read_resource_from_server = AsyncMock(return_value=read_result)
result = await mcp_read_resource(
@@ -591,7 +616,10 @@ async def test_get_tools_from_mcp_servers_continues_when_one_server_fails():
working_server if server_id == "working_server" else failing_server
)
# Mock filter_server_ids_by_ip to return server_ids unchanged (no IP filtering)
mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (server_ids, 0)
mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (
server_ids,
0,
)
async def mock_get_tools_from_server(
server,
@@ -693,7 +721,10 @@ async def test_get_tools_from_mcp_servers_handles_all_servers_failing():
failing_server1 if server_id == "failing_server1" else failing_server2
)
# Mock filter_server_ids_by_ip to return server_ids unchanged (no IP filtering)
mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (server_ids, 0)
mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (
server_ids,
0,
)
async def mock_get_tools_from_server(
server,
@@ -830,12 +861,14 @@ async def test_concurrent_initialize_session_managers():
mcp_server._sse_session_manager_cm = None
# Mock the session managers to avoid actual MCP initialization
with patch(
"litellm.proxy._experimental.mcp_server.server.session_manager"
) as mock_session_manager, patch(
"litellm.proxy._experimental.mcp_server.server.sse_session_manager"
) as mock_sse_session_manager, patch(
"litellm.proxy._experimental.mcp_server.server.verbose_logger"
with (
patch(
"litellm.proxy._experimental.mcp_server.server.session_manager"
) as mock_session_manager,
patch(
"litellm.proxy._experimental.mcp_server.server.sse_session_manager"
) as mock_sse_session_manager,
patch("litellm.proxy._experimental.mcp_server.server.verbose_logger"),
):
# Mock the run() method to return a mock context manager
mock_cm = AsyncMock()
@@ -961,15 +994,19 @@ async def test_mcp_routing_with_conflicting_alias_and_group_name():
return_value=[specific_server.server_id, other_server.server_id]
)
with patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager.get_allowed_mcp_servers",
mock_get_allowed,
), patch(
"litellm.proxy._experimental.mcp_server.server.MCPRequestHandler._get_mcp_servers_from_access_groups",
mock_db_lookup,
), patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager._get_tools_from_server",
mock_get_tools_spy,
with (
patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager.get_allowed_mcp_servers",
mock_get_allowed,
),
patch(
"litellm.proxy._experimental.mcp_server.server.MCPRequestHandler._get_mcp_servers_from_access_groups",
mock_db_lookup,
),
patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager._get_tools_from_server",
mock_get_tools_spy,
),
):
mcp_servers_from_path = _get_mcp_servers_in_path(test_path)
@@ -1062,17 +1099,21 @@ async def test_oauth2_headers_passed_to_mcp_client():
async def mock_fetch_tools_with_timeout(client, server_name):
return [] # Return empty list of tools
with patch.object(
global_mcp_server_manager,
"_create_mcp_client",
side_effect=mock_create_mcp_client,
) as mock_create_client, patch.object(
global_mcp_server_manager,
"_fetch_tools_with_timeout",
side_effect=mock_fetch_tools_with_timeout,
), patch(
"litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers",
AsyncMock(return_value=[oauth2_server]),
with (
patch.object(
global_mcp_server_manager,
"_create_mcp_client",
side_effect=mock_create_mcp_client,
) as mock_create_client,
patch.object(
global_mcp_server_manager,
"_fetch_tools_with_timeout",
side_effect=mock_fetch_tools_with_timeout,
),
patch(
"litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers",
AsyncMock(return_value=[oauth2_server]),
),
):
# Call _get_tools_from_mcp_servers which should eventually call _create_mcp_client
await _get_tools_from_mcp_servers(
@@ -1138,7 +1179,10 @@ async def test_list_tools_single_server_unprefixed_names():
mock_manager.get_allowed_mcp_servers = AsyncMock(return_value=["server1"])
mock_manager.get_mcp_server_by_id = MagicMock(return_value=server)
# Mock filter_server_ids_by_ip to return server_ids unchanged (no IP filtering)
mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (server_ids, 0)
mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (
server_ids,
0,
)
async def mock_get_tools_from_server(
server,
@@ -1216,7 +1260,10 @@ async def test_list_tools_multiple_servers_prefixed_names():
server1 if server_id == "server1" else server2
)
# Mock filter_server_ids_by_ip to return server_ids unchanged (no IP filtering)
mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (server_ids, 0)
mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (
server_ids,
0,
)
async def mock_get_tools_from_server(
server,
@@ -1270,12 +1317,15 @@ async def test_mcp_manager_allows_public_servers_without_permissions():
)
manager.registry = {public_server.server_id: public_server}
with patch(
"litellm.proxy.management_endpoints.common_utils._user_has_admin_view",
return_value=False,
), patch(
"litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPRequestHandler.get_allowed_mcp_servers",
AsyncMock(return_value=[]),
with (
patch(
"litellm.proxy.management_endpoints.common_utils._user_has_admin_view",
return_value=False,
),
patch(
"litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPRequestHandler.get_allowed_mcp_servers",
AsyncMock(return_value=[]),
),
):
allowed = await manager.get_allowed_mcp_servers(UserAPIKeyAuth())
@@ -1302,12 +1352,15 @@ async def test_mcp_manager_returns_public_when_permission_lookup_fails():
)
manager.registry = {public_server.server_id: public_server}
with patch(
"litellm.proxy.management_endpoints.common_utils._user_has_admin_view",
return_value=False,
), patch(
"litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPRequestHandler.get_allowed_mcp_servers",
AsyncMock(side_effect=Exception("boom")),
with (
patch(
"litellm.proxy.management_endpoints.common_utils._user_has_admin_view",
return_value=False,
),
patch(
"litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPRequestHandler.get_allowed_mcp_servers",
AsyncMock(side_effect=Exception("boom")),
),
):
allowed = await manager.get_allowed_mcp_servers(UserAPIKeyAuth())
@@ -1342,12 +1395,15 @@ async def test_mcp_manager_merges_public_and_restricted_servers():
scoped_server.server_id: scoped_server,
}
with patch(
"litellm.proxy.management_endpoints.common_utils._user_has_admin_view",
return_value=False,
), patch(
"litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPRequestHandler.get_allowed_mcp_servers",
AsyncMock(return_value=["restricted"]),
with (
patch(
"litellm.proxy.management_endpoints.common_utils._user_has_admin_view",
return_value=False,
),
patch(
"litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPRequestHandler.get_allowed_mcp_servers",
AsyncMock(return_value=["restricted"]),
),
):
allowed = await manager.get_allowed_mcp_servers(UserAPIKeyAuth())
@@ -1399,12 +1455,15 @@ async def test_call_mcp_tool_user_unauthorized_access():
return another_server_obj
return None
with patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.MCPRequestHandler.get_allowed_mcp_servers",
AsyncMock(return_value=["allowed_server", "another_server"]),
), patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager.get_mcp_server_by_id",
side_effect=mock_get_server_by_id,
with (
patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.MCPRequestHandler.get_allowed_mcp_servers",
AsyncMock(return_value=["allowed_server", "another_server"]),
),
patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager.get_mcp_server_by_id",
side_effect=mock_get_server_by_id,
),
):
# Try to call a tool from "restricted_server" - should raise HTTPException with 403 status
with pytest.raises(HTTPException) as exc_info:
@@ -1467,7 +1526,10 @@ async def test_list_tools_filters_by_key_team_permissions():
mock_manager.get_allowed_mcp_servers = AsyncMock(return_value=["server1"])
mock_manager.get_mcp_server_by_id = lambda server_id: server
# Mock filter_server_ids_by_ip to return server_ids unchanged (no IP filtering)
mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (server_ids, 0)
mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (
server_ids,
0,
)
async def mock_get_tools_from_server(
server,
@@ -1573,7 +1635,10 @@ async def test_list_tools_with_team_tool_permissions_inheritance():
mock_manager.get_allowed_mcp_servers = AsyncMock(return_value=["server1"])
mock_manager.get_mcp_server_by_id = lambda server_id: server
# Mock filter_server_ids_by_ip to return server_ids unchanged (no IP filtering)
mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (server_ids, 0)
mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (
server_ids,
0,
)
async def mock_get_tools_from_server(
server,
@@ -1665,7 +1730,10 @@ async def test_list_tools_with_no_tool_permissions_shows_all():
mock_manager.get_allowed_mcp_servers = AsyncMock(return_value=["server1"])
mock_manager.get_mcp_server_by_id = lambda server_id: server
# Mock filter_server_ids_by_ip to return server_ids unchanged (no IP filtering)
mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (server_ids, 0)
mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (
server_ids,
0,
)
async def mock_get_tools_from_server(
server,
@@ -1760,7 +1828,10 @@ async def test_list_tools_strips_prefix_when_matching_permissions():
mock_manager.get_allowed_mcp_servers = AsyncMock(return_value=["gitmcp_server"])
mock_manager.get_mcp_server_by_id = MagicMock(return_value=server)
# Mock filter_server_ids_by_ip to return server_ids unchanged (no IP filtering)
mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (server_ids, 0)
mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (
server_ids,
0,
)
async def mock_get_tools_from_server(
server,
@@ -2002,12 +2073,15 @@ class TestMCPServerManagerReload:
mock_prisma.db.litellm_mcpservertable.find_many = AsyncMock(
return_value=[db_row]
)
with patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw",
return_value=mock_prisma,
), patch.object(
manager, "build_mcp_server_from_table", AsyncMock()
) as mock_build:
with (
patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw",
return_value=mock_prisma,
),
patch.object(
manager, "build_mcp_server_from_table", AsyncMock()
) as mock_build,
):
await manager.reload_servers_from_database()
mock_build.assert_not_awaited()
@@ -2045,14 +2119,17 @@ class TestMCPServerManagerReload:
mock_prisma.db.litellm_mcpservertable.find_many = AsyncMock(
return_value=[db_row]
)
with patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw",
return_value=mock_prisma,
), patch.object(
manager,
"build_mcp_server_from_table",
AsyncMock(return_value=rebuilt_server),
) as mock_build:
with (
patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw",
return_value=mock_prisma,
),
patch.object(
manager,
"build_mcp_server_from_table",
AsyncMock(return_value=rebuilt_server),
) as mock_build,
):
await manager.reload_servers_from_database()
mock_build.assert_awaited_once_with(db_row)
@@ -2090,26 +2167,32 @@ async def test_call_mcp_tool_logs_failure_via_post_call_failure_hook():
user_auth = UserAPIKeyAuth(api_key="test-key", user_id="test-user")
with patch.object(
global_mcp_server_manager,
"get_allowed_mcp_servers",
new_callable=AsyncMock,
return_value=[mock_server.server_id],
), patch.object(
global_mcp_server_manager,
"get_mcp_server_by_id",
return_value=mock_server,
), patch(
"litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers_from_mcp_server_names",
new_callable=AsyncMock,
return_value=[mock_server],
), patch(
"litellm.proxy._experimental.mcp_server.server.execute_mcp_tool",
new_callable=AsyncMock,
side_effect=Exception("boom"),
), patch(
"litellm.proxy.proxy_server.proxy_logging_obj",
proxy_logging_mock,
with (
patch.object(
global_mcp_server_manager,
"get_allowed_mcp_servers",
new_callable=AsyncMock,
return_value=[mock_server.server_id],
),
patch.object(
global_mcp_server_manager,
"get_mcp_server_by_id",
return_value=mock_server,
),
patch(
"litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers_from_mcp_server_names",
new_callable=AsyncMock,
return_value=[mock_server],
),
patch(
"litellm.proxy._experimental.mcp_server.server.execute_mcp_tool",
new_callable=AsyncMock,
side_effect=Exception("boom"),
),
patch(
"litellm.proxy.proxy_server.proxy_logging_obj",
proxy_logging_mock,
),
):
with pytest.raises(Exception):
await call_mcp_tool(
@@ -2157,23 +2240,30 @@ async def test_get_tools_from_mcp_servers_logs_list_tools_to_spendlogs_when_enab
dummy_logging_obj.model_call_details = {"metadata": {"spend_logs_metadata": {}}}
dummy_logging_obj.async_success_handler = AsyncMock()
with patch(
"litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers",
new=AsyncMock(return_value=[server_a]),
), patch(
"litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers",
return_value=(None, None),
), patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
) as mock_manager, patch(
"litellm.proxy._experimental.mcp_server.server.filter_tools_by_allowed_tools",
side_effect=lambda tools, _server: tools,
), patch(
"litellm.proxy._experimental.mcp_server.server.filter_tools_by_key_team_permissions",
new=AsyncMock(side_effect=lambda tools, **_: tools),
), patch(
"litellm.proxy._experimental.mcp_server.server.function_setup",
return_value=(dummy_logging_obj, None),
with (
patch(
"litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers",
new=AsyncMock(return_value=[server_a]),
),
patch(
"litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers",
return_value=(None, None),
),
patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
) as mock_manager,
patch(
"litellm.proxy._experimental.mcp_server.server.filter_tools_by_allowed_tools",
side_effect=lambda tools, _server: tools,
),
patch(
"litellm.proxy._experimental.mcp_server.server.filter_tools_by_key_team_permissions",
new=AsyncMock(side_effect=lambda tools, **_: tools),
),
patch(
"litellm.proxy._experimental.mcp_server.server.function_setup",
return_value=(dummy_logging_obj, None),
),
):
mock_manager._get_tools_from_server = AsyncMock(return_value=[tool_1])
@@ -2188,7 +2278,9 @@ async def test_get_tools_from_mcp_servers_logs_list_tools_to_spendlogs_when_enab
assert tools == [tool_1]
dummy_logging_obj.async_success_handler.assert_awaited_once()
assert dummy_logging_obj.async_success_handler.await_args.kwargs["result"] == [tool_1]
assert dummy_logging_obj.async_success_handler.await_args.kwargs["result"] == [
tool_1
]
spend_meta = dummy_logging_obj.model_call_details["metadata"]["spend_logs_metadata"]
assert spend_meta["tool_count_total"] == 1
@@ -2381,26 +2473,34 @@ async def test_get_tools_from_mcp_servers_injects_stored_oauth2_token():
oauth2_server.extra_headers = None
# Simulate the DB returning a valid credential for this user+server
prefetched_creds = {SERVER_ID: {"access_token": STORED_TOKEN, "server_id": SERVER_ID}}
prefetched_creds = {
SERVER_ID: {"access_token": STORED_TOKEN, "server_id": SERVER_ID}
}
tool_1 = MagicMock()
tool_1.name = "atlassian_test-search"
with patch(
"litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers",
new=AsyncMock(return_value=[oauth2_server]),
), patch(
# Patch the bulk prefetch so no real DB connection is needed
"litellm.proxy._experimental.mcp_server.server._prefetch_oauth_creds_for_user",
new=AsyncMock(return_value=prefetched_creds),
) as mock_prefetch, patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
) as mock_manager, patch(
"litellm.proxy._experimental.mcp_server.server.filter_tools_by_allowed_tools",
side_effect=lambda tools, _server: tools,
), patch(
"litellm.proxy._experimental.mcp_server.server.filter_tools_by_key_team_permissions",
new=AsyncMock(side_effect=lambda tools, **_: tools),
with (
patch(
"litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers",
new=AsyncMock(return_value=[oauth2_server]),
),
patch(
# Patch the bulk prefetch so no real DB connection is needed
"litellm.proxy._experimental.mcp_server.server._prefetch_oauth_creds_for_user",
new=AsyncMock(return_value=prefetched_creds),
) as mock_prefetch,
patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
) as mock_manager,
patch(
"litellm.proxy._experimental.mcp_server.server.filter_tools_by_allowed_tools",
side_effect=lambda tools, _server: tools,
),
patch(
"litellm.proxy._experimental.mcp_server.server.filter_tools_by_key_team_permissions",
new=AsyncMock(side_effect=lambda tools, **_: tools),
),
):
mock_manager._get_tools_from_server = AsyncMock(return_value=[tool_1])
@@ -2421,3 +2521,201 @@ async def test_get_tools_from_mcp_servers_injects_stored_oauth2_token():
assert call_kwargs["extra_headers"] == {"Authorization": f"Bearer {STORED_TOKEN}"}
assert tools == [tool_1]
# ---------------------------------------------------------------------------
# _merge_gateway_initialize_instructions + ContextVar / InitializationOptions
# ---------------------------------------------------------------------------
def _make_instruction_server(
server_id="s1",
name="s1",
*,
alias=None,
server_name=None,
instructions=None,
spec_path=None,
url="https://example.com",
):
return MCPServer(
server_id=server_id,
name=name,
alias=alias,
server_name=server_name,
url=url,
transport=MCPTransport.http,
instructions=instructions,
spec_path=spec_path,
)
class TestMergeGatewayInitializeInstructions:
"""Tests for _merge_gateway_initialize_instructions."""
def _merge(self, servers):
try:
from litellm.proxy._experimental.mcp_server.server import (
_merge_gateway_initialize_instructions,
)
except ImportError:
pytest.skip("MCP server not available")
return _merge_gateway_initialize_instructions(servers)
def test_empty_server_list_returns_none(self):
"""No servers yields no instructions."""
assert self._merge([]) is None
def test_single_server_yaml_instructions(self):
"""A single server with YAML instructions returns them verbatim."""
s = _make_instruction_server(instructions="Use add() for sums.")
assert self._merge([s]) == "Use add() for sums."
def test_yaml_instructions_strips_whitespace(self):
"""Leading/trailing whitespace is stripped."""
s = _make_instruction_server(instructions=" padded \n")
assert self._merge([s]) == "padded"
def test_yaml_override_beats_upstream_cache(self):
"""YAML/DB instructions take precedence over upstream cache."""
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
global_mcp_server_manager,
)
global_mcp_server_manager._upstream_initialize_instructions_by_server_id[
"s1"
] = "upstream"
try:
s = _make_instruction_server(instructions="yaml wins")
assert self._merge([s]) == "yaml wins"
finally:
global_mcp_server_manager._upstream_initialize_instructions_by_server_id.pop(
"s1", None
)
def test_upstream_cache_used_when_no_yaml(self):
"""Upstream cached instructions are used when no YAML override is set."""
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
global_mcp_server_manager,
)
global_mcp_server_manager._upstream_initialize_instructions_by_server_id[
"s1"
] = "from upstream"
try:
s = _make_instruction_server(instructions=None)
assert self._merge([s]) == "from upstream"
finally:
global_mcp_server_manager._upstream_initialize_instructions_by_server_id.pop(
"s1", None
)
def test_spec_path_servers_skipped(self):
"""OpenAPI (spec_path) servers do not contribute instructions."""
s = _make_instruction_server(spec_path="/openapi.json", url=None)
assert self._merge([s]) is None
def test_no_instructions_no_cache_returns_none(self):
"""Server with no instructions and no cache yields None."""
s = _make_instruction_server()
assert self._merge([s]) is None
def test_multiple_servers_merged_with_labels(self):
"""Multiple servers get label-prefixed and separator-joined."""
s1 = _make_instruction_server(
server_id="a", name="a", alias="Alpha", instructions="instr A"
)
s2 = _make_instruction_server(
server_id="b", name="b", alias="Beta", instructions="instr B"
)
result = self._merge([s1, s2])
assert result is not None
assert "[Alpha]" in result and "[Beta]" in result
assert "instr A" in result and "instr B" in result
assert "---" in result
def test_single_server_no_label_wrapping(self):
"""A single server's instructions are not wrapped with a label."""
s = _make_instruction_server(alias="MyServer", instructions="single")
result = self._merge([s])
assert result == "single"
assert "[MyServer]" not in result
def test_mixed_yaml_cache_specpath(self):
"""YAML, upstream-cache, and spec_path servers are handled correctly together."""
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
global_mcp_server_manager,
)
global_mcp_server_manager._upstream_initialize_instructions_by_server_id[
"c"
] = "cached C"
try:
s_yaml = _make_instruction_server(
server_id="a", name="a", alias="A", instructions="yaml A"
)
s_spec = _make_instruction_server(
server_id="b", name="b", alias="B", spec_path="/spec.json", url=None
)
s_cached = _make_instruction_server(server_id="c", name="c", alias="C")
result = self._merge([s_yaml, s_spec, s_cached])
assert "yaml A" in result
assert "cached C" in result
assert "[B]" not in result
finally:
global_mcp_server_manager._upstream_initialize_instructions_by_server_id.pop(
"c", None
)
class TestGatewayCreateInitializationOptions:
"""Tests for the patched server.create_initialization_options via ContextVar."""
def test_no_contextvar_returns_default_options(self):
"""When ContextVar is None, instructions are absent."""
try:
from litellm.proxy._experimental.mcp_server.mcp_context import (
_mcp_gateway_initialize_instructions,
)
from litellm.proxy._experimental.mcp_server.server import server
except ImportError:
pytest.skip("MCP server not available")
tok = _mcp_gateway_initialize_instructions.set(None)
try:
opts = server.create_initialization_options()
assert getattr(opts, "instructions", None) is None
finally:
_mcp_gateway_initialize_instructions.reset(tok)
def test_contextvar_set_injects_instructions(self):
"""When ContextVar has a value, it appears in InitializationOptions."""
try:
from litellm.proxy._experimental.mcp_server.mcp_context import (
_mcp_gateway_initialize_instructions,
)
from litellm.proxy._experimental.mcp_server.server import server
except ImportError:
pytest.skip("MCP server not available")
tok = _mcp_gateway_initialize_instructions.set("hello from merge")
try:
opts = server.create_initialization_options()
assert opts.instructions == "hello from merge"
finally:
_mcp_gateway_initialize_instructions.reset(tok)
def test_contextvar_reset_removes_instructions(self):
"""After resetting the ContextVar, instructions disappear."""
try:
from litellm.proxy._experimental.mcp_server.mcp_context import (
_mcp_gateway_initialize_instructions,
)
from litellm.proxy._experimental.mcp_server.server import server
except ImportError:
pytest.skip("MCP server not available")
tok = _mcp_gateway_initialize_instructions.set("temporary")
_mcp_gateway_initialize_instructions.reset(tok)
opts = server.create_initialization_options()
assert getattr(opts, "instructions", None) is None
@@ -43,10 +43,10 @@ def _reload_mcp_manager_module():
# After reload, server.py still holds a stale reference to the old
# global_mcp_server_manager. Update it so tests that exercise server.py
# functions (e.g. _get_tools_from_mcp_servers) use the fresh instance.
server_module = sys.modules.get(
"litellm.proxy._experimental.mcp_server.server"
)
if server_module is not None and hasattr(server_module, "global_mcp_server_manager"):
server_module = sys.modules.get("litellm.proxy._experimental.mcp_server.server")
if server_module is not None and hasattr(
server_module, "global_mcp_server_manager"
):
server_module.global_mcp_server_manager = reloaded.global_mcp_server_manager
return reloaded
@@ -223,9 +223,7 @@ class TestMCPServerManager:
with caplog.at_level(logging.WARNING, logger="LiteLLM"):
await manager.load_servers_from_config(config)
assert any(
"invalid alias 'bad/name'" in message for message in caplog.messages
)
assert any("invalid alias 'bad/name'" in message for message in caplog.messages)
@pytest.mark.asyncio
async def test_load_servers_from_config_accepts_valid_alias(self, caplog):
@@ -492,7 +490,12 @@ class TestMCPServerManager:
mock_client = AsyncMock()
mock_client.list_prompts = AsyncMock(return_value=[mock_prompt])
with patch.object(manager, "_create_mcp_client", new_callable=AsyncMock, return_value=mock_client):
with patch.object(
manager,
"_create_mcp_client",
new_callable=AsyncMock,
return_value=mock_client,
):
prompts = await manager.get_prompts_from_server(server, add_prefix=True)
mock_client.list_prompts.assert_awaited_once()
@@ -520,7 +523,12 @@ class TestMCPServerManager:
mock_client = AsyncMock()
mock_client.get_prompt = AsyncMock(return_value=mock_result)
with patch.object(manager, "_create_mcp_client", new_callable=AsyncMock, return_value=mock_client):
with patch.object(
manager,
"_create_mcp_client",
new_callable=AsyncMock,
return_value=mock_client,
):
result = await manager.get_prompt_from_server(
server=server,
prompt_name="hello",
@@ -551,13 +559,23 @@ class TestMCPServerManager:
mock_client = AsyncMock()
mock_resources = [Resource(name="file", uri="https://example.com/file")]
mock_client.list_resources = AsyncMock(return_value=mock_resources)
prefixed_resources = [Resource(name="alias-server-file", uri="https://example.com/file")]
prefixed_resources = [
Resource(name="alias-server-file", uri="https://example.com/file")
]
with patch.object(manager, "_create_mcp_client", new_callable=AsyncMock, return_value=mock_client) as mock_create_client, patch.object(
manager,
"_create_prefixed_resources",
return_value=prefixed_resources,
) as mock_prefix:
with (
patch.object(
manager,
"_create_mcp_client",
new_callable=AsyncMock,
return_value=mock_client,
) as mock_create_client,
patch.object(
manager,
"_create_prefixed_resources",
return_value=prefixed_resources,
) as mock_prefix,
):
result = await manager.get_resources_from_server(
server=server,
mcp_auth_header="auth",
@@ -602,11 +620,19 @@ class TestMCPServerManager:
)
]
with patch.object(manager, "_create_mcp_client", new_callable=AsyncMock, return_value=mock_client) as mock_create_client, patch.object(
manager,
"_create_prefixed_resource_templates",
return_value=prefixed_templates,
) as mock_prefix:
with (
patch.object(
manager,
"_create_mcp_client",
new_callable=AsyncMock,
return_value=mock_client,
) as mock_create_client,
patch.object(
manager,
"_create_prefixed_resource_templates",
return_value=prefixed_templates,
) as mock_prefix,
):
result = await manager.get_resource_templates_from_server(
server=server,
mcp_auth_header="auth",
@@ -650,7 +676,12 @@ class TestMCPServerManager:
)
mock_client.read_resource = AsyncMock(return_value=read_result)
with patch.object(manager, "_create_mcp_client", new_callable=AsyncMock, return_value=mock_client) as mock_create_client:
with patch.object(
manager,
"_create_mcp_client",
new_callable=AsyncMock,
return_value=mock_client,
) as mock_create_client:
result = await manager.read_resource_from_server(
server=server,
url="https://example.com/resource",
@@ -661,7 +692,9 @@ class TestMCPServerManager:
mock_create_client.assert_called_once()
called_kwargs = mock_create_client.call_args.kwargs
assert called_kwargs["extra_headers"] == {"X-Test": "1", "X-Static": "1"}
mock_client.read_resource.assert_awaited_once_with("https://example.com/resource")
mock_client.read_resource.assert_awaited_once_with(
"https://example.com/resource"
)
assert result is read_result
@pytest.mark.asyncio
@@ -724,22 +757,27 @@ class TestMCPServerManager:
registration_url=None,
)
with patch(
"litellm.proxy._experimental.mcp_server.mcp_server_manager.get_async_httpx_client",
return_value=mock_client,
), patch.object(
manager,
"_fetch_oauth_metadata_from_resource",
AsyncMock(return_value=([], None)),
), patch.object(
manager,
"_attempt_well_known_discovery",
AsyncMock(return_value=([], None)),
), patch.object(
manager,
"_fetch_authorization_server_metadata",
AsyncMock(return_value=mock_metadata),
) as mock_fetch_auth:
with (
patch(
"litellm.proxy._experimental.mcp_server.mcp_server_manager.get_async_httpx_client",
return_value=mock_client,
),
patch.object(
manager,
"_fetch_oauth_metadata_from_resource",
AsyncMock(return_value=([], None)),
),
patch.object(
manager,
"_attempt_well_known_discovery",
AsyncMock(return_value=([], None)),
),
patch.object(
manager,
"_fetch_authorization_server_metadata",
AsyncMock(return_value=mock_metadata),
) as mock_fetch_auth,
):
result = await manager._descovery_metadata(server_url)
mock_fetch_auth.assert_awaited_once_with(["https://example.com"])
@@ -779,9 +817,8 @@ class TestMCPServerManager:
assert server.scopes == ["config"] # config overrides discovery
assert server.authorization_url == "https://config.example.com/auth"
assert server.token_url == "https://discovered.example.com/token"
assert (
server.registration_url == "https://discovered.example.com/register"
)
assert server.registration_url == "https://discovered.example.com/register"
@pytest.mark.asyncio
async def test_config_oauth_initialize_tool_name_to_mcp_server_name_mapping(self):
manager = MCPServerManager()
@@ -801,7 +838,7 @@ class TestMCPServerManager:
# Initialize the tool mapping
await manager._initialize_tool_name_to_mcp_server_name_mapping()
assert manager.tool_name_to_mcp_server_name_mapping == {}
@pytest.mark.asyncio
async def test_list_tools_handles_missing_server_alias(self):
"""Test that list_tools handles servers without alias gracefully"""
@@ -1017,7 +1054,9 @@ class TestMCPServerManager:
# Capture the extra_headers passed to _create_mcp_client
captured_extra_headers = None
async def capture_create_mcp_client(server, mcp_auth_header, extra_headers, stdio_env):
async def capture_create_mcp_client(
server, mcp_auth_header, extra_headers, stdio_env
):
nonlocal captured_extra_headers
captured_extra_headers = extra_headers
return mock_client
@@ -1314,15 +1353,19 @@ class TestMCPServerManager:
return tool_func
with patch(
"litellm.proxy._experimental.mcp_server.openapi_to_mcp_generator.create_tool_function",
side_effect=fake_create_tool_function,
), patch(
"litellm.proxy._experimental.mcp_server.openapi_to_mcp_generator.build_input_schema",
return_value={"type": "object", "properties": {}, "required": []},
), patch(
"litellm.proxy._experimental.mcp_server.tool_registry.global_mcp_tool_registry.register_tool",
return_value=None,
with (
patch(
"litellm.proxy._experimental.mcp_server.openapi_to_mcp_generator.create_tool_function",
side_effect=fake_create_tool_function,
),
patch(
"litellm.proxy._experimental.mcp_server.openapi_to_mcp_generator.build_input_schema",
return_value={"type": "object", "properties": {}, "required": []},
),
patch(
"litellm.proxy._experimental.mcp_server.tool_registry.global_mcp_tool_registry.register_tool",
return_value=None,
),
):
await manager._register_openapi_tools(
spec_path=str(spec_path),
@@ -2161,7 +2204,9 @@ class TestMCPServerManager:
# Register the server and map a tool to it
manager.registry = {"test-server": server}
manager.tool_name_to_mcp_server_name_mapping["test_tool"] = "test-server"
manager.tool_name_to_mcp_server_name_mapping["test-server-test_tool"] = "test-server"
manager.tool_name_to_mcp_server_name_mapping["test-server-test_tool"] = (
"test-server"
)
# Create mock client that tracks call_tool usage
mock_client = AsyncMock()
@@ -2252,11 +2297,16 @@ class TestMCPServerManager:
# Verify MCPRequestHandler.get_allowed_mcp_servers was called with user_api_key_auth
mock_get_allowed.assert_called_once()
call_args = mock_get_allowed.call_args
assert call_args[0][0] is user_api_key_auth # First positional arg should be user_api_key_auth
assert (
call_args[0][0] is user_api_key_auth
) # First positional arg should be user_api_key_auth
assert call_args[0][0].user_id == "user-123"
assert call_args[0][0].object_permission_id == "perm_123"
assert call_args[0][0].object_permission is not None
assert call_args[0][0].object_permission.mcp_servers == ["test_server_1", "test_server_2"]
assert call_args[0][0].object_permission.mcp_servers == [
"test_server_1",
"test_server_2",
]
# Verify result contains the expected servers
assert "test_server_1" in result
@@ -2483,5 +2533,82 @@ class TestHasClientCredentialsOAuth2Flow:
assert server.needs_user_oauth_token is False
# ---------------------------------------------------------------------------
# Upstream initialize-instructions cache
# ---------------------------------------------------------------------------
class TestMCPServerManagerUpstreamInstructionsCache:
"""Tests for the upstream initialize-instructions cache."""
def test_get_returns_none_when_empty(self):
"""Empty cache returns None for any key."""
manager = MCPServerManager()
assert (
manager._upstream_initialize_instructions_by_server_id.get("nonexistent")
is None
)
def test_remember_stores_stripped_value(self):
"""_remember_upstream_initialize_instructions stores a stripped string."""
manager = MCPServerManager()
fake_server = MagicMock(server_id="srv")
fake_client = MagicMock(_last_initialize_instructions=" hello \n")
manager._remember_upstream_initialize_instructions(fake_server, fake_client)
assert (
manager._upstream_initialize_instructions_by_server_id.get("srv") == "hello"
)
def test_remember_ignores_empty_string(self):
"""Whitespace-only instructions are not stored."""
manager = MCPServerManager()
fake_server = MagicMock(server_id="srv")
fake_client = MagicMock(_last_initialize_instructions=" ")
manager._remember_upstream_initialize_instructions(fake_server, fake_client)
assert manager._upstream_initialize_instructions_by_server_id.get("srv") is None
def test_remember_ignores_none(self):
"""None instructions are not stored."""
manager = MCPServerManager()
fake_server = MagicMock(server_id="srv")
fake_client = MagicMock(_last_initialize_instructions=None)
manager._remember_upstream_initialize_instructions(fake_server, fake_client)
assert manager._upstream_initialize_instructions_by_server_id.get("srv") is None
@pytest.mark.asyncio
async def test_load_servers_from_config_clears_cache(self):
"""Reloading config clears any previously cached upstream instructions."""
manager = MCPServerManager()
manager._upstream_initialize_instructions_by_server_id["old"] = "stale"
await manager.load_servers_from_config(
mcp_servers_config={
"fresh_srv": {
"url": "https://example.com",
"instructions": "from yaml",
}
}
)
assert manager._upstream_initialize_instructions_by_server_id.get("old") is None
@pytest.mark.asyncio
async def test_load_servers_reads_instructions_from_config(self):
"""instructions field from YAML config is persisted on the MCPServer."""
manager = MCPServerManager()
await manager.load_servers_from_config(
mcp_servers_config={
"srv_a": {
"url": "https://a.example.com",
"instructions": "A instructions",
},
"srv_b": {
"url": "https://b.example.com",
},
}
)
by_name = {s.server_name: s for s in manager.config_mcp_servers.values()}
assert "srv_a" in by_name and by_name["srv_a"].instructions == "A instructions"
assert "srv_b" in by_name and by_name["srv_b"].instructions is None
if __name__ == "__main__":
pytest.main([__file__])
@@ -399,7 +399,9 @@ class TestMCPServerManagerSigV4:
server = next(iter(manager.config_mcp_servers.values()))
assert server.auth_type == MCPAuth.aws_sigv4
assert server.aws_access_key_id == "AKIAIOSFODNN7EXAMPLE"
assert server.aws_secret_access_key == "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
assert (
server.aws_secret_access_key == "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
)
assert server.aws_region_name == "us-east-1"
assert server.aws_service_name == "bedrock-agentcore"
@@ -529,7 +531,9 @@ class TestMCPServerManagerSigV4:
"aws_session_name": "my-session",
}
result = manager._extract_aws_credentials(creds, credentials_are_encrypted=False)
result = manager._extract_aws_credentials(
creds, credentials_are_encrypted=False
)
assert result["aws_role_name"] == "arn:aws:iam::123456789012:role/TestRole"
assert result["aws_session_name"] == "my-session"
@@ -615,12 +619,15 @@ class TestCredentialMergeOnUpdate:
credentials={"aws_region_name": "eu-west-1"},
)
with patch(
"litellm.proxy._experimental.mcp_server.db._get_salt_key",
return_value=None,
), patch(
"litellm.proxy._experimental.mcp_server.db.encrypt_value_helper",
side_effect=lambda value, new_encryption_key: value,
with (
patch(
"litellm.proxy._experimental.mcp_server.db._get_salt_key",
return_value=None,
),
patch(
"litellm.proxy._experimental.mcp_server.db.encrypt_value_helper",
side_effect=lambda value, new_encryption_key: value,
),
):
await update_mcp_server(mock_prisma, data, "test-user")
@@ -685,12 +692,15 @@ class TestCredentialMergeOnUpdate:
credentials={"aws_region_name": "us-east-1"},
)
with patch(
"litellm.proxy._experimental.mcp_server.db._get_salt_key",
return_value=None,
), patch(
"litellm.proxy._experimental.mcp_server.db.encrypt_value_helper",
side_effect=lambda value, new_encryption_key: value,
with (
patch(
"litellm.proxy._experimental.mcp_server.db._get_salt_key",
return_value=None,
),
patch(
"litellm.proxy._experimental.mcp_server.db.encrypt_value_helper",
side_effect=lambda value, new_encryption_key: value,
),
):
await update_mcp_server(mock_prisma, data, "test-user")
@@ -728,12 +738,15 @@ class TestCredentialMergeOnUpdate:
credentials={"auth_value": "my-key"},
)
with patch(
"litellm.proxy._experimental.mcp_server.db._get_salt_key",
return_value=None,
), patch(
"litellm.proxy._experimental.mcp_server.db.encrypt_value_helper",
side_effect=lambda value, new_encryption_key: f"enc:{value}",
with (
patch(
"litellm.proxy._experimental.mcp_server.db._get_salt_key",
return_value=None,
),
patch(
"litellm.proxy._experimental.mcp_server.db.encrypt_value_helper",
side_effect=lambda value, new_encryption_key: f"enc:{value}",
),
):
await update_mcp_server(mock_prisma, data, "test-user")
@@ -772,12 +785,15 @@ class TestCredentialMergeOnUpdate:
credentials={"scopes": ["read", "write"]},
)
with patch(
"litellm.proxy._experimental.mcp_server.db._get_salt_key",
return_value=None,
), patch(
"litellm.proxy._experimental.mcp_server.db.encrypt_value_helper",
side_effect=lambda value, new_encryption_key: value,
with (
patch(
"litellm.proxy._experimental.mcp_server.db._get_salt_key",
return_value=None,
),
patch(
"litellm.proxy._experimental.mcp_server.db.encrypt_value_helper",
side_effect=lambda value, new_encryption_key: value,
),
):
await update_mcp_server(mock_prisma, data, "test-user")
@@ -803,7 +819,9 @@ class TestSigV4BuildFromTable:
table_record.server_name = "sigv4_server"
table_record.alias = None
table_record.description = None
table_record.url = "https://bedrock-agentcore.us-east-1.amazonaws.com/invocations"
table_record.url = (
"https://bedrock-agentcore.us-east-1.amazonaws.com/invocations"
)
table_record.spec_path = None
table_record.transport = "http"
table_record.auth_type = "aws_sigv4"
@@ -838,6 +856,7 @@ class TestSigV4BuildFromTable:
table_record.tool_name_to_description = None
table_record.byok_api_key_help_url = None
table_record.oauth2_flow = None
table_record.instructions = None
manager = MCPServerManager()
@@ -895,6 +914,7 @@ class TestSigV4BuildFromTable:
table_record.tool_name_to_description = None
table_record.byok_api_key_help_url = None
table_record.oauth2_flow = None
table_record.instructions = None
manager = MCPServerManager()
@@ -934,7 +954,9 @@ class TestDecryptCredentials:
with patch(
"litellm.proxy._experimental.mcp_server.db.decrypt_value_helper",
side_effect=lambda value, key, exception_type="error", return_original_value=False: value.replace("enc:", ""),
side_effect=lambda value, key, exception_type="error", return_original_value=False: value.replace(
"enc:", ""
),
):
result = decrypt_credentials(credentials=creds)
@@ -956,7 +978,9 @@ class TestDecryptCredentials:
with patch(
"litellm.proxy._experimental.mcp_server.db.decrypt_value_helper",
side_effect=lambda value, key, exception_type="error", return_original_value=False: value.replace("enc:", ""),
side_effect=lambda value, key, exception_type="error", return_original_value=False: value.replace(
"enc:", ""
),
):
result = decrypt_credentials(credentials=creds)
@@ -988,15 +1012,21 @@ class TestRotateCredentials:
)
mock_prisma.db.litellm_mcpservertable.update = AsyncMock()
with patch(
"litellm.proxy._experimental.mcp_server.db._get_salt_key",
return_value="old-key",
), patch(
"litellm.proxy._experimental.mcp_server.db.decrypt_value_helper",
side_effect=lambda value, key, exception_type="error", return_original_value=False: value.replace("enc_old:", ""),
), patch(
"litellm.proxy._experimental.mcp_server.db.encrypt_value_helper",
side_effect=lambda value, new_encryption_key: f"enc_new:{value}",
with (
patch(
"litellm.proxy._experimental.mcp_server.db._get_salt_key",
return_value="old-key",
),
patch(
"litellm.proxy._experimental.mcp_server.db.decrypt_value_helper",
side_effect=lambda value, key, exception_type="error", return_original_value=False: value.replace(
"enc_old:", ""
),
),
patch(
"litellm.proxy._experimental.mcp_server.db.encrypt_value_helper",
side_effect=lambda value, new_encryption_key: f"enc_new:{value}",
),
):
await rotate_mcp_server_credentials_master_key(
mock_prisma, "admin", "new-key"
@@ -307,3 +307,83 @@ async def test_lock_takeover_race_condition(mock_redis):
cronjob_id="test_job",
)
assert result2 == False
@pytest.mark.asyncio
async def test_release_lock_uses_atomic_compare_delete_script_when_available(
pod_lock_manager, mock_redis
):
"""
Test that release_lock prefers atomic compare-and-delete Lua script when
redis cache exposes script registration.
"""
script_callable = AsyncMock(return_value=1)
mock_redis.async_register_script = MagicMock(return_value=script_callable)
await pod_lock_manager.release_lock(cronjob_id="test_job")
lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job")
mock_redis.async_register_script.assert_called_once_with(
PodLockManager._COMPARE_AND_DELETE_LOCK_SCRIPT
)
script_callable.assert_called_once_with(
keys=[lock_key], args=[pod_lock_manager.pod_id]
)
mock_redis.async_get_cache.assert_not_called()
mock_redis.async_delete_cache.assert_not_called()
@pytest.mark.asyncio
async def test_release_lock_reuses_registered_script(pod_lock_manager, mock_redis):
"""
Test script registration is cached on manager instance and reused.
"""
script_callable = AsyncMock(return_value=0)
mock_redis.async_register_script = MagicMock(return_value=script_callable)
await pod_lock_manager.release_lock(cronjob_id="test_job")
await pod_lock_manager.release_lock(cronjob_id="test_job")
assert mock_redis.async_register_script.call_count == 1
@pytest.mark.asyncio
async def test_release_lock_lua_path_emits_released_event(
pod_lock_manager, mock_redis
):
"""
Test that _emit_released_lock_event is called when the Lua path returns 1
(successful release).
"""
script_callable = AsyncMock(return_value=1)
mock_redis.async_register_script = MagicMock(return_value=script_callable)
with patch.object(pod_lock_manager, "_emit_released_lock_event") as mock_emit:
await pod_lock_manager.release_lock(cronjob_id="test_job")
mock_emit.assert_called_once_with(
cronjob_id="test_job", pod_id=pod_lock_manager.pod_id
)
@pytest.mark.asyncio
async def test_release_lock_falls_back_to_get_del_when_lua_execution_fails(
pod_lock_manager, mock_redis
):
"""
Test that release_lock falls back to GET+DEL when Lua script execution
raises (e.g. Redis restart cleared loaded scripts).
"""
script_callable = AsyncMock(side_effect=Exception("NOSCRIPT"))
mock_redis.async_register_script = MagicMock(return_value=script_callable)
mock_redis.async_get_cache.return_value = pod_lock_manager.pod_id
mock_redis.async_delete_cache.return_value = 1
await pod_lock_manager.release_lock(cronjob_id="test_job")
# Lua failed — should have fallen back to GET+DEL
lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job")
mock_redis.async_get_cache.assert_called_once_with(lock_key)
mock_redis.async_delete_cache.assert_called_once_with(lock_key)
# Cached script handle should be reset so next call re-registers
assert pod_lock_manager._release_lock_script is None
@@ -0,0 +1,155 @@
"""
Tests for create_missing_views exception handling fix.
Verifies that real DB errors (auth failures, connection errors, etc.)
are re-raised instead of being silently swallowed, while genuine
"view not found" errors still trigger view creation.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, call
@pytest.mark.asyncio
async def test_create_views_reraises_connection_error():
"""should re-raise exceptions that are NOT 'does not exist' errors (e.g. connection errors)."""
from litellm.proxy.db.create_views import create_missing_views
mock_db = MagicMock()
mock_db.query_raw = AsyncMock(
side_effect=Exception("connection refused: unable to connect to database")
)
mock_db.execute_raw = AsyncMock()
with pytest.raises(Exception, match="connection refused"):
await create_missing_views(mock_db)
mock_db.execute_raw.assert_not_called()
@pytest.mark.asyncio
async def test_create_views_reraises_permission_error():
"""should re-raise permission denied errors, not treat them as missing views."""
from litellm.proxy.db.create_views import create_missing_views
mock_db = MagicMock()
mock_db.query_raw = AsyncMock(
side_effect=Exception(
"permission denied for table LiteLLM_VerificationTokenView"
)
)
mock_db.execute_raw = AsyncMock()
with pytest.raises(Exception, match="permission denied"):
await create_missing_views(mock_db)
mock_db.execute_raw.assert_not_called()
@pytest.mark.asyncio
async def test_create_views_creates_view_on_does_not_exist():
"""should call execute_raw to create view when error contains 'does not exist'."""
from litellm.proxy.db.create_views import create_missing_views
mock_db = MagicMock()
mock_db.query_raw = AsyncMock(
side_effect=[
Exception('relation "LiteLLM_VerificationTokenView" does not exist'),
None, # MonthlyGlobalSpend exists
None, # Last30dKeysBySpend exists
None, # Last30dModelsBySpend exists
None, # MonthlyGlobalSpendPerKey exists
None, # MonthlyGlobalSpendPerUserPerKey exists
None, # DailyTagSpend exists
None, # Last30dTopEndUsersSpend exists
]
)
mock_db.execute_raw = AsyncMock(return_value=None)
await create_missing_views(mock_db)
mock_db.execute_raw.assert_called_once()
created_sql = mock_db.execute_raw.call_args[0][0]
assert 'CREATE VIEW "LiteLLM_VerificationTokenView"' in created_sql
@pytest.mark.asyncio
async def test_create_views_creates_view_on_undefined_error():
"""should treat 'undefined' errors as 'view not found' and attempt creation."""
from litellm.proxy.db.create_views import create_missing_views
mock_db = MagicMock()
mock_db.query_raw = AsyncMock(
side_effect=[
Exception("undefined table LiteLLM_VerificationTokenView"),
None,
None,
None,
None,
None,
None,
None,
]
)
mock_db.execute_raw = AsyncMock(return_value=None)
await create_missing_views(mock_db)
mock_db.execute_raw.assert_called_once()
@pytest.mark.asyncio
async def test_create_views_skips_creation_when_view_exists():
"""should not call execute_raw when all views already exist."""
from litellm.proxy.db.create_views import create_missing_views
mock_db = MagicMock()
mock_db.query_raw = AsyncMock(return_value=[{"?column?": 1}])
mock_db.execute_raw = AsyncMock()
await create_missing_views(mock_db)
mock_db.execute_raw.assert_not_called()
@pytest.mark.asyncio
async def test_create_views_reraises_undefined_function_error():
"""should re-raise 'undefined function' errors — bare 'undefined' is too broad
and would previously misclassify DB function errors as missing-view signals."""
from litellm.proxy.db.create_views import create_missing_views
mock_db = MagicMock()
mock_db.query_raw = AsyncMock(
side_effect=Exception("ERROR: undefined function pg_get_viewdef()")
)
mock_db.execute_raw = AsyncMock()
with pytest.raises(Exception, match="undefined function"):
await create_missing_views(mock_db)
mock_db.execute_raw.assert_not_called()
@pytest.mark.asyncio
async def test_create_views_creates_view_on_undefined_table_error():
"""should treat 'undefined table' as a missing-view signal and attempt creation."""
from litellm.proxy.db.create_views import create_missing_views
mock_db = MagicMock()
mock_db.query_raw = AsyncMock(
side_effect=[
Exception('undefined table "LiteLLM_VerificationTokenView"'),
None,
None,
None,
None,
None,
None,
None,
]
)
mock_db.execute_raw = AsyncMock(return_value=None)
await create_missing_views(mock_db)
mock_db.execute_raw.assert_called_once()
@@ -1190,6 +1190,483 @@ async def test_bedrock_guardrail_blocked_content_with_masking_enabled():
print("✅ BLOCKED content with masking enabled raises exception correctly")
# ──────────────────────────────────────────────────────────────────────────────
# Null-safety tests for Bedrock guardrail responses
#
# The Bedrock ApplyGuardrail API can return explicit null/None for list fields
# such as "regexes", "piiEntities", "topics", "filters", "customWords", and
# "managedWordLists" when a particular policy category is present in the
# assessment but has no matches.
#
# Python's dict.get("key", []) returns None (NOT []) when the key exists with
# a None value. The `or []` fallback ensures we always iterate over a list.
#
# Without the fix, iterating over None raises:
# TypeError: 'NoneType' object is not iterable
# which surfaces to callers as:
# openai.InternalServerError: Error code: 500
# {'error': {'message': "Bedrock guardrail failed: 'NoneType' object is not iterable", ...}}
# ──────────────────────────────────────────────────────────────────────────────
class TestRedactPiiMatchesNullSafety:
"""Tests for _redact_pii_matches handling of null/None list fields from Bedrock API."""
@pytest.mark.asyncio
async def test_should_handle_null_regexes_in_sensitive_info_policy(self):
"""Bedrock can return regexes: null while piiEntities has data.
Real-world scenario: guardrail detects PII (e.g. EMAIL) but has no
custom regex patterns configured, so the API returns regexes: null.
"""
response = {
"action": "NONE",
"actionReason": "No action.",
"assessments": [
{
"sensitiveInformationPolicy": {
"piiEntities": [
{
"action": "NONE",
"detected": True,
"match": "joebloggs@gmail.com",
"type": "EMAIL",
}
],
"regexes": None, # Explicit null from Bedrock API
},
}
],
}
# Should not raise TypeError: 'NoneType' object is not iterable
redacted = _redact_pii_matches(response)
# PII match should be redacted
pii = redacted["assessments"][0]["sensitiveInformationPolicy"]["piiEntities"]
assert pii[0]["match"] == "[REDACTED]"
assert pii[0]["type"] == "EMAIL"
@pytest.mark.asyncio
async def test_should_handle_null_pii_entities_in_sensitive_info_policy(self):
"""Bedrock can return piiEntities: null while regexes has data."""
response = {
"action": "NONE",
"assessments": [
{
"sensitiveInformationPolicy": {
"piiEntities": None, # null from Bedrock API
"regexes": [
{
"name": "CUSTOM_PATTERN",
"match": "secret-abc-123",
"action": "BLOCKED",
}
],
},
}
],
}
redacted = _redact_pii_matches(response)
regexes = redacted["assessments"][0]["sensitiveInformationPolicy"]["regexes"]
assert regexes[0]["match"] == "[REDACTED]"
@pytest.mark.asyncio
async def test_should_handle_null_custom_words_and_managed_words(self):
"""Bedrock can return null for customWords and managedWordLists in wordPolicy."""
response = {
"action": "NONE",
"assessments": [
{
"wordPolicy": {
"customWords": None, # null from Bedrock API
"managedWordLists": None, # null from Bedrock API
},
}
],
}
# Should not raise TypeError
redacted = _redact_pii_matches(response)
# Values should remain None (no crash)
assert redacted["assessments"][0]["wordPolicy"]["customWords"] is None
assert redacted["assessments"][0]["wordPolicy"]["managedWordLists"] is None
@pytest.mark.asyncio
async def test_should_handle_null_assessments_list(self):
"""Bedrock can return assessments: null."""
response = {
"action": "NONE",
"assessments": None, # null from Bedrock API
}
# Should not raise TypeError
redacted = _redact_pii_matches(response)
assert redacted["assessments"] is None
@pytest.mark.asyncio
async def test_should_handle_all_null_policy_sub_lists_together(self):
"""All sub-list fields are null at the same time — worst-case scenario."""
response = {
"action": "GUARDRAIL_INTERVENED",
"assessments": [
{
"sensitiveInformationPolicy": {
"piiEntities": None,
"regexes": None,
},
"wordPolicy": {
"customWords": None,
"managedWordLists": None,
},
"topicPolicy": None,
"contentPolicy": None,
"contextualGroundingPolicy": None,
}
],
}
# Should not raise any exception
redacted = _redact_pii_matches(response)
assert redacted is not None
class TestShouldRaiseGuardrailBlockedExceptionNullSafety:
"""Tests for _should_raise_guardrail_blocked_exception handling of null list fields."""
def _create_guardrail(self) -> BedrockGuardrail:
return BedrockGuardrail(
guardrailIdentifier="test-guardrail", guardrailVersion="DRAFT"
)
@pytest.mark.asyncio
async def test_should_handle_all_null_policy_sub_lists(self):
"""All policy sub-lists are null — should not crash, should return False."""
guardrail = self._create_guardrail()
response = {
"action": "GUARDRAIL_INTERVENED",
"assessments": [
{
"topicPolicy": {
"topics": None, # null from Bedrock API
},
"contentPolicy": {
"filters": None, # null
},
"wordPolicy": {
"customWords": None, # null
"managedWordLists": None, # null
},
"sensitiveInformationPolicy": {
"piiEntities": None, # null
"regexes": None, # null
},
"contextualGroundingPolicy": {
"filters": None, # null
},
}
],
}
# No BLOCKED actions found (all lists null) → should return False
result = guardrail._should_raise_guardrail_blocked_exception(response)
assert result is False
@pytest.mark.asyncio
async def test_should_detect_blocked_despite_other_null_lists(self):
"""A mix of null lists and a real BLOCKED action — should still detect it."""
guardrail = self._create_guardrail()
response = {
"action": "GUARDRAIL_INTERVENED",
"assessments": [
{
"topicPolicy": {
"topics": None, # null — should not crash
},
"contentPolicy": {
"filters": [
{
"type": "HATE",
"confidence": "HIGH",
"action": "BLOCKED",
}
],
},
"wordPolicy": {
"customWords": None, # null
"managedWordLists": None, # null
},
"sensitiveInformationPolicy": {
"piiEntities": None, # null
"regexes": None, # null
},
"contextualGroundingPolicy": None, # entire policy is null
}
],
}
# Should return True because contentPolicy has a BLOCKED filter
result = guardrail._should_raise_guardrail_blocked_exception(response)
assert result is True
@pytest.mark.asyncio
async def test_should_handle_null_assessments_list(self):
"""assessments itself is null — should return False."""
guardrail = self._create_guardrail()
response = {
"action": "GUARDRAIL_INTERVENED",
"assessments": None, # null from Bedrock API
}
result = guardrail._should_raise_guardrail_blocked_exception(response)
assert result is False
@pytest.mark.asyncio
async def test_should_handle_null_topics_with_blocked_word_policy(self):
"""topics is null but wordPolicy has a BLOCKED customWord."""
guardrail = self._create_guardrail()
response = {
"action": "GUARDRAIL_INTERVENED",
"assessments": [
{
"topicPolicy": {
"topics": None,
},
"wordPolicy": {
"customWords": [
{"match": "badword", "action": "BLOCKED"}
],
"managedWordLists": None,
},
}
],
}
result = guardrail._should_raise_guardrail_blocked_exception(response)
assert result is True
@pytest.mark.asyncio
async def test_should_handle_null_pii_with_blocked_regex(self):
"""piiEntities is null but regexes has a BLOCKED match."""
guardrail = self._create_guardrail()
response = {
"action": "GUARDRAIL_INTERVENED",
"assessments": [
{
"sensitiveInformationPolicy": {
"piiEntities": None,
"regexes": [
{"name": "SSN", "match": "123-45-6789", "action": "BLOCKED"}
],
},
}
],
}
result = guardrail._should_raise_guardrail_blocked_exception(response)
assert result is True
@pytest.mark.asyncio
async def test_should_handle_null_grounding_filters(self):
"""contextualGroundingPolicy.filters is null — should not crash."""
guardrail = self._create_guardrail()
response = {
"action": "GUARDRAIL_INTERVENED",
"assessments": [
{
"contextualGroundingPolicy": {
"filters": None,
},
}
],
}
result = guardrail._should_raise_guardrail_blocked_exception(response)
assert result is False
@pytest.mark.asyncio
async def test_should_not_crash_when_action_is_not_intervened(self):
"""If action != GUARDRAIL_INTERVENED, null lists should never be reached."""
guardrail = self._create_guardrail()
response = {
"action": "NONE",
"assessments": [
{
"sensitiveInformationPolicy": {
"piiEntities": None,
"regexes": None,
},
}
],
}
result = guardrail._should_raise_guardrail_blocked_exception(response)
assert result is False
class TestApplyGuardrailNullSafety:
"""Tests for apply_guardrail handling of null/None texts input."""
@pytest.mark.asyncio
async def test_should_handle_none_texts_in_inputs(self):
"""inputs[\"texts\"] is explicitly None — should not crash."""
guardrail = BedrockGuardrail(
guardrailIdentifier="test-guardrail", guardrailVersion="DRAFT"
)
inputs = {"texts": None} # Explicit None
mock_credentials = MagicMock()
with patch.object(
guardrail.async_handler, "post", new_callable=AsyncMock
) as mock_post, patch.object(
guardrail, "_load_credentials", return_value=(mock_credentials, "us-east-1")
), patch.object(
guardrail, "_prepare_request", return_value=MagicMock()
):
# With empty texts (from None → []), no Bedrock API call should be made
result = await guardrail.apply_guardrail(
inputs=inputs,
request_data={},
input_type="request",
)
# Should return empty texts without crashing
assert result.get("texts") == []
# No Bedrock API call should be made for empty input
mock_post.assert_not_called()
@pytest.mark.asyncio
async def test_should_handle_missing_texts_key(self):
"""inputs has no \"texts\" key at all — should not crash."""
guardrail = BedrockGuardrail(
guardrailIdentifier="test-guardrail", guardrailVersion="DRAFT"
)
inputs = {} # No "texts" key
mock_credentials = MagicMock()
with patch.object(
guardrail.async_handler, "post", new_callable=AsyncMock
) as mock_post, patch.object(
guardrail, "_load_credentials", return_value=(mock_credentials, "us-east-1")
), patch.object(
guardrail, "_prepare_request", return_value=MagicMock()
):
result = await guardrail.apply_guardrail(
inputs=inputs,
request_data={},
input_type="request",
)
assert result.get("texts") == []
mock_post.assert_not_called()
@pytest.mark.asyncio
async def test_bedrock_guardrail_blocked_vs_anonymized_actions():
"""Test that BLOCKED actions raise exceptions but ANONYMIZED actions do not"""
guardrail = BedrockGuardrail(
guardrailIdentifier="test-guardrail", guardrailVersion="DRAFT"
)
# Test 1: ANONYMIZED action should NOT raise exception
anonymized_response = {
"action": "GUARDRAIL_INTERVENED",
"outputs": [{"text": "Hello, my phone number is {PHONE}"}],
"assessments": [
{
"sensitiveInformationPolicy": {
"piiEntities": [
{
"type": "PHONE",
"match": "+1 412 555 1212",
"action": "ANONYMIZED",
}
]
}
}
],
}
should_raise = guardrail._should_raise_guardrail_blocked_exception(
anonymized_response
)
assert should_raise is False, "ANONYMIZED actions should not raise exceptions"
# Test 2: BLOCKED action should raise exception
blocked_response = {
"action": "GUARDRAIL_INTERVENED",
"outputs": [{"text": "I can't provide that information."}],
"assessments": [
{
"topicPolicy": {
"topics": [
{"name": "Sensitive Topic", "type": "DENY", "action": "BLOCKED"}
]
}
}
],
}
should_raise = guardrail._should_raise_guardrail_blocked_exception(blocked_response)
assert should_raise is True, "BLOCKED actions should raise exceptions"
# Test 3: Mixed actions - should raise if ANY action is BLOCKED
mixed_response = {
"action": "GUARDRAIL_INTERVENED",
"outputs": [{"text": "I can't provide that information."}],
"assessments": [
{
"sensitiveInformationPolicy": {
"piiEntities": [
{
"type": "PHONE",
"match": "+1 412 555 1212",
"action": "ANONYMIZED",
}
]
},
"topicPolicy": {
"topics": [
{"name": "Blocked Topic", "type": "DENY", "action": "BLOCKED"}
]
},
}
],
}
should_raise = guardrail._should_raise_guardrail_blocked_exception(mixed_response)
assert (
should_raise is True
), "Mixed actions with any BLOCKED should raise exceptions"
# Test 4: NONE action should not raise exception
none_response = {
"action": "NONE",
"outputs": [],
"assessments": [],
}
should_raise = guardrail._should_raise_guardrail_blocked_exception(none_response)
assert should_raise is False, "NONE action should not raise exceptions"
print("\u2705 BLOCKED vs ANONYMIZED actions test passed")
# ---------------------------------------------------------------------------
# L3: _extract_blocked_assessments + _get_http_exception_for_blocked_guardrail
# Regression coverage for case 2026-04-10-internal-bedrock-guardrail-streaming-error.
@@ -1338,4 +1815,3 @@ def test_get_http_exception_no_blocked_assessments_omits_field():
assert isinstance(exc, HTTPException)
assert "assessments" not in exc.detail
assert exc.detail["guardrailIdentifier"] == "amgllac6xf3r"
@@ -4,6 +4,7 @@ import pytest
from fastapi import HTTPException
from litellm.integrations.custom_guardrail import ModifyResponseException
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_hooks.grayswan.grayswan import (
GraySwanGuardrail,
GraySwanGuardrailAPIError,
@@ -435,7 +436,10 @@ def test_format_violation_message() -> None:
assert "Gray Swan Cygnal Guardrail" in message
assert "the input query has a violation score of 0.85" in message
assert "violating the rule(s): 1, 3, 5" in message
assert "Mutation effort to make the harmful intention disguised was DETECTED" in message
assert (
"Mutation effort to make the harmful intention disguised was DETECTED"
in message
)
# IPI should not be in message since it's False
assert "Indirect Prompt Injection was DETECTED" not in message
@@ -446,4 +450,58 @@ def test_format_violation_message() -> None:
assert "Gray Swan Cygnal Guardrail" in message
assert "the model response has a violation score of 0.85" in message
assert "violating the rule(s): 1, 3, 5" in message
assert "Mutation effort to make the harmful intention disguised was DETECTED" in message
assert (
"Mutation effort to make the harmful intention disguised was DETECTED"
in message
)
def test_prepare_payload_includes_litellm_metadata(
grayswan_guardrail: GraySwanGuardrail,
) -> None:
"""Verify _prepare_payload forwards litellm_metadata from request_data."""
messages = [{"role": "user", "content": "hello"}]
request_data = {
"litellm_metadata": {
"user_api_key_user_id": "user-123",
"user_api_key_team_id": "team-456",
"user_api_key_spend": 0,
}
}
payload = grayswan_guardrail._prepare_payload(messages, {}, request_data)
assert payload is not None
assert "litellm_metadata" in payload
assert payload["litellm_metadata"]["user_api_key_user_id"] == "user-123"
assert payload["litellm_metadata"]["user_api_key_team_id"] == "team-456"
def test_ensure_litellm_metadata_populates_from_user_api_key_dict() -> None:
"""Verify _ensure_litellm_metadata populates litellm_metadata."""
from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import (
_ensure_litellm_metadata,
)
user_auth = UserAPIKeyAuth(user_id="u1", team_id="t1", api_key="sk-test-hashed")
data: dict = {}
_ensure_litellm_metadata(data, user_auth)
assert "litellm_metadata" in data
assert data["litellm_metadata"]["user_api_key_user_id"] == "u1"
assert data["litellm_metadata"]["user_api_key_team_id"] == "t1"
def test_ensure_litellm_metadata_noop_when_already_present() -> None:
"""Verify _ensure_litellm_metadata does not overwrite existing litellm_metadata."""
from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import (
_ensure_litellm_metadata,
)
user_auth = UserAPIKeyAuth(user_id="should-not-appear")
data: dict = {"litellm_metadata": {"existing": "value"}}
_ensure_litellm_metadata(data, user_auth)
assert data["litellm_metadata"] == {"existing": "value"}
@@ -150,13 +150,19 @@ class TestNomaV2Configuration:
application_id="dynamic-app",
)
payload["request_data"]["metadata"]["headers"]["x-noma-application-id"] = "mutated-value"
payload["request_data"]["metadata"]["headers"][
"x-noma-application-id"
] = "mutated-value"
payload["request_data"]["messages"][0]["content"] = "changed-content"
assert request_data["metadata"]["headers"]["x-noma-application-id"] == "header-app"
assert (
request_data["metadata"]["headers"]["x-noma-application-id"] == "header-app"
)
assert request_data["messages"][0]["content"] == "hello"
def test_build_scan_payload_passes_model_call_details_as_is(self, noma_v2_guardrail):
def test_build_scan_payload_passes_model_call_details_as_is(
self, noma_v2_guardrail
):
class _LoggingObj:
def __init__(self) -> None:
self.model_call_details = {
@@ -193,7 +199,9 @@ class TestNomaV2Configuration:
assert request_data["litellm_logging_obj"] == "<Logging object>"
@pytest.mark.asyncio
async def test_call_noma_scan_sanitizes_response_model_dump_object(self, noma_v2_guardrail):
async def test_call_noma_scan_sanitizes_response_model_dump_object(
self, noma_v2_guardrail
):
import json
class _FakeModelResponse:
@@ -221,7 +229,9 @@ class TestNomaV2Configuration:
json.dumps(sent_payload)
assert sent_payload["request_data"]["response"]["id"] == "resp-1"
def test_sanitize_payload_for_transport_falls_back_to_safe_dumps(self, noma_v2_guardrail):
def test_sanitize_payload_for_transport_falls_back_to_safe_dumps(
self, noma_v2_guardrail
):
with patch(
"litellm.proxy.guardrails.guardrail_hooks.noma.noma_v2.json.dumps",
side_effect=TypeError("cannot serialize"),
@@ -230,12 +240,16 @@ class TestNomaV2Configuration:
"litellm.proxy.guardrails.guardrail_hooks.noma.noma_v2.safe_dumps",
return_value='{"fallback": true}',
) as mock_safe_dumps:
sanitized = noma_v2_guardrail._sanitize_payload_for_transport({"inputs": {"texts": ["hello"]}})
sanitized = noma_v2_guardrail._sanitize_payload_for_transport(
{"inputs": {"texts": ["hello"]}}
)
mock_safe_dumps.assert_called_once()
assert sanitized == {"fallback": True}
def test_sanitize_payload_for_transport_logs_warning_when_payload_becomes_empty(self, noma_v2_guardrail):
def test_sanitize_payload_for_transport_logs_warning_when_payload_becomes_empty(
self, noma_v2_guardrail
):
with patch(
"litellm.proxy.guardrails.guardrail_hooks.noma.noma_v2.safe_json_loads",
return_value={},
@@ -243,14 +257,18 @@ class TestNomaV2Configuration:
with patch(
"litellm.proxy.guardrails.guardrail_hooks.noma.noma_v2.verbose_proxy_logger.warning"
) as mock_warning:
sanitized = noma_v2_guardrail._sanitize_payload_for_transport({"inputs": {"texts": ["hello"]}})
sanitized = noma_v2_guardrail._sanitize_payload_for_transport(
{"inputs": {"texts": ["hello"]}}
)
assert sanitized == {}
mock_warning.assert_called_once_with(
"Noma v2 guardrail: payload serialization failed, falling back to empty payload"
)
def test_sanitize_payload_for_transport_logs_warning_on_non_dict_output(self, noma_v2_guardrail):
def test_sanitize_payload_for_transport_logs_warning_on_non_dict_output(
self, noma_v2_guardrail
):
with patch(
"litellm.proxy.guardrails.guardrail_hooks.noma.noma_v2.safe_json_loads",
return_value=["not-a-dict"],
@@ -258,7 +276,9 @@ class TestNomaV2Configuration:
with patch(
"litellm.proxy.guardrails.guardrail_hooks.noma.noma_v2.verbose_proxy_logger.warning"
) as mock_warning:
sanitized = noma_v2_guardrail._sanitize_payload_for_transport({"inputs": {"texts": ["hello"]}})
sanitized = noma_v2_guardrail._sanitize_payload_for_transport(
{"inputs": {"texts": ["hello"]}}
)
assert sanitized == {}
mock_warning.assert_called_once_with(
@@ -271,7 +291,9 @@ class TestNomaV2Configuration:
class TestNomaV2ActionBehavior:
def test_resolve_action_from_response_raises_on_unknown_action(self, noma_v2_guardrail):
def test_resolve_action_from_response_raises_on_unknown_action(
self, noma_v2_guardrail
):
with pytest.raises(ValueError, match="missing valid action"):
noma_v2_guardrail._resolve_action_from_response({"action": "INVALID"})
@@ -296,7 +318,9 @@ class TestNomaV2ActionBehavior:
assert result == inputs
@pytest.mark.asyncio
async def test_native_action_guardrail_intervened_updates_supported_fields(self, noma_v2_guardrail):
async def test_native_action_guardrail_intervened_updates_supported_fields(
self, noma_v2_guardrail
):
inputs = {
"texts": ["Name: Jane"],
"images": ["https://old.example/image.png"],
@@ -322,7 +346,10 @@ class TestNomaV2ActionBehavior:
{
"id": "call_1",
"type": "function",
"function": {"name": "new_tool", "arguments": '{"safe":"true"}'},
"function": {
"name": "new_tool",
"arguments": '{"safe":"true"}',
},
}
],
}
@@ -336,7 +363,9 @@ class TestNomaV2ActionBehavior:
assert result["texts"] == ["Name: *******"]
assert result["images"] == ["https://new.example/image.png"]
assert result["tools"] == [{"type": "function", "function": {"name": "new_tool"}}]
assert result["tools"] == [
{"type": "function", "function": {"name": "new_tool"}}
]
assert result["tool_calls"] == [
{
"id": "call_1",
@@ -367,7 +396,9 @@ class TestNomaV2ActionBehavior:
assert exc_info.value.detail["details"]["blocked_reason"] == "blocked by policy"
@pytest.mark.asyncio
async def test_intervened_without_modifications_returns_original_inputs(self, noma_v2_guardrail):
async def test_intervened_without_modifications_returns_original_inputs(
self, noma_v2_guardrail
):
inputs = {"texts": ["Name: Jane"]}
with patch.object(
noma_v2_guardrail,
@@ -464,7 +495,9 @@ class TestNomaV2ApplicationIdResolution:
assert payload["application_id"] == "dynamic-app"
@pytest.mark.asyncio
async def test_apply_guardrail_uses_configured_application_id(self, noma_v2_guardrail):
async def test_apply_guardrail_uses_configured_application_id(
self, noma_v2_guardrail
):
call_mock = AsyncMock(return_value={"action": "NONE"})
with patch.object(
noma_v2_guardrail,
@@ -482,7 +515,86 @@ class TestNomaV2ApplicationIdResolution:
assert payload["application_id"] == "test-app"
@pytest.mark.asyncio
async def test_apply_guardrail_omits_application_id_when_not_explicit(self):
async def test_apply_guardrail_falls_back_to_key_alias_from_litellm_metadata(
self, noma_v2_guardrail
):
"""When no explicit application_id is set, fall back to user_api_key_alias
so that each API key gets its own application entry in the Noma dashboard."""
noma_v2_guardrail.application_id = None
call_mock = AsyncMock(return_value={"action": "NONE"})
request_data = {
"metadata": {},
"litellm_metadata": {"user_api_key_alias": "test-key-alias"},
}
with patch.object(
noma_v2_guardrail,
"get_guardrail_dynamic_request_body_params",
return_value={},
):
with patch.object(noma_v2_guardrail, "_call_noma_scan", call_mock):
await noma_v2_guardrail.apply_guardrail(
inputs={"texts": ["hello"]},
request_data=request_data,
input_type="request",
)
payload = call_mock.call_args.kwargs["payload"]
assert payload["application_id"] == "test-key-alias"
@pytest.mark.asyncio
async def test_apply_guardrail_falls_back_to_key_alias_from_metadata(
self, noma_v2_guardrail
):
"""user_api_key_alias in metadata (set by proxy_server.py) is also resolved."""
noma_v2_guardrail.application_id = None
call_mock = AsyncMock(return_value={"action": "NONE"})
request_data = {
"metadata": {"user_api_key_alias": "test-service-key"},
}
with patch.object(
noma_v2_guardrail,
"get_guardrail_dynamic_request_body_params",
return_value={},
):
with patch.object(noma_v2_guardrail, "_call_noma_scan", call_mock):
await noma_v2_guardrail.apply_guardrail(
inputs={"texts": ["hello"]},
request_data=request_data,
input_type="request",
)
payload = call_mock.call_args.kwargs["payload"]
assert payload["application_id"] == "test-service-key"
@pytest.mark.asyncio
async def test_apply_guardrail_configured_application_id_takes_precedence_over_key_alias(
self, noma_v2_guardrail
):
"""Explicit application_id (config/env) wins over key_alias fallback."""
call_mock = AsyncMock(return_value={"action": "NONE"})
request_data = {
"metadata": {"user_api_key_alias": "should-not-be-used"},
}
with patch.object(
noma_v2_guardrail,
"get_guardrail_dynamic_request_body_params",
return_value={},
):
with patch.object(noma_v2_guardrail, "_call_noma_scan", call_mock):
await noma_v2_guardrail.apply_guardrail(
inputs={"texts": ["hello"]},
request_data=request_data,
input_type="request",
)
payload = call_mock.call_args.kwargs["payload"]
assert payload["application_id"] == "test-app"
@pytest.mark.asyncio
async def test_apply_guardrail_omits_application_id_when_no_fallback_available(
self,
):
"""When nothing is set — no config, no dynamic params, no key alias — omit entirely."""
guardrail_no_config = NomaV2Guardrail(
api_key="test-api-key",
application_id=None,
@@ -490,7 +602,6 @@ class TestNomaV2ApplicationIdResolution:
event_hook="pre_call",
default_on=True,
)
call_mock = AsyncMock(return_value={"action": "NONE"})
with patch.object(
guardrail_no_config,
@@ -506,26 +617,3 @@ class TestNomaV2ApplicationIdResolution:
payload = call_mock.call_args.kwargs["payload"]
assert "application_id" not in payload
@pytest.mark.asyncio
async def test_apply_guardrail_ignores_request_metadata_application_id(self, noma_v2_guardrail):
noma_v2_guardrail.application_id = None
call_mock = AsyncMock(return_value={"action": "NONE"})
request_data = {
"metadata": {"headers": {"x-noma-application-id": "header-app"}},
"litellm_metadata": {"user_api_key_alias": "alias-app"},
}
with patch.object(
noma_v2_guardrail,
"get_guardrail_dynamic_request_body_params",
return_value={},
):
with patch.object(noma_v2_guardrail, "_call_noma_scan", call_mock):
await noma_v2_guardrail.apply_guardrail(
inputs={"texts": ["hello"]},
request_data=request_data,
input_type="request",
)
payload = call_mock.call_args.kwargs["payload"]
assert "application_id" not in payload
@@ -8,7 +8,7 @@ sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from litellm.proxy._types import KeyManagementRoutes, Member
from litellm.proxy._types import KeyManagementRoutes, Member, ProxyException
from litellm.proxy.management_helpers.team_member_permission_checks import (
BASELINE_TEAM_MEMBER_PERMISSIONS,
TeamMemberPermissionChecks,
@@ -188,3 +188,74 @@ class TestGetDefaultTeamParam:
assert _get_default_team_param("budget_duration") == "7d"
assert _get_default_team_param("tpm_limit") == 1000
assert _get_default_team_param("rpm_limit") == 100
class TestCanTeamMemberExecuteKeyManagementEndpoint:
@pytest.mark.asyncio
async def test_raises_when_user_not_in_keys_team(self, monkeypatch):
"""Non-members should be blocked from team-scoped key management endpoints."""
from litellm.proxy.management_endpoints import key_management_endpoints
from litellm.proxy.management_helpers import team_member_permission_checks as module
async def _mock_get_team_object(**kwargs):
team = MagicMock()
team.team_id = "team-b"
team.team_member_permissions = ["/key/update"]
return team
monkeypatch.setattr(module, "get_team_object", _mock_get_team_object)
monkeypatch.setattr(key_management_endpoints, "_get_user_in_team", lambda **kwargs: None)
user_api_key_dict = MagicMock()
user_api_key_dict.user_role = "internal_user"
user_api_key_dict.user_id = "user-a"
user_api_key_dict.parent_otel_span = None
existing_key_row = MagicMock()
existing_key_row.team_id = "team-b"
with pytest.raises(ProxyException) as exc:
await TeamMemberPermissionChecks.can_team_member_execute_key_management_endpoint(
user_api_key_dict=user_api_key_dict,
route=KeyManagementRoutes.KEY_UPDATE,
prisma_client=MagicMock(),
user_api_key_cache=MagicMock(),
existing_key_row=existing_key_row,
)
assert str(exc.value.code) == "401"
assert exc.value.type == "team_member_permission_error"
@pytest.mark.asyncio
async def test_allows_team_admin_in_keys_team(self, monkeypatch):
"""Team admins of the key's team should be allowed."""
from litellm.proxy.management_endpoints import key_management_endpoints
from litellm.proxy.management_helpers import team_member_permission_checks as module
async def _mock_get_team_object(**kwargs):
team = MagicMock()
team.team_id = "team-a"
team.team_member_permissions = ["/key/update"]
return team
monkeypatch.setattr(module, "get_team_object", _mock_get_team_object)
monkeypatch.setattr(
key_management_endpoints,
"_get_user_in_team",
lambda **kwargs: Member(role="admin", user_id="user-a"),
)
user_api_key_dict = MagicMock()
user_api_key_dict.user_role = "internal_user"
user_api_key_dict.user_id = "user-a"
user_api_key_dict.parent_otel_span = None
existing_key_row = MagicMock()
existing_key_row.team_id = "team-a"
await TeamMemberPermissionChecks.can_team_member_execute_key_management_endpoint(
user_api_key_dict=user_api_key_dict,
route=KeyManagementRoutes.KEY_UPDATE,
prisma_client=MagicMock(),
user_api_key_cache=MagicMock(),
existing_key_row=existing_key_row,
)
@@ -0,0 +1,141 @@
"""
Tests for CORS configuration security fix.
All tests import _get_cors_config directly from proxy_server so they exercise
real production code rather than a local mirror.
"""
import pytest
def test_cors_wildcard_disables_credentials():
"""should disable credentials when LITELLM_CORS_ORIGINS is not set (defaults to wildcard)."""
from litellm.proxy.proxy_server import _get_cors_config
origins, allow_credentials = _get_cors_config(cors_origins_env="")
assert origins == ["*"]
assert allow_credentials is False
def test_cors_empty_string_disables_credentials():
"""should disable credentials when LITELLM_CORS_ORIGINS is empty or whitespace."""
from litellm.proxy.proxy_server import _get_cors_config
for empty in ("", " ", "\t"):
origins, allow_credentials = _get_cors_config(cors_origins_env=empty)
assert origins == ["*"], f"Expected wildcard for input {repr(empty)}"
assert (
allow_credentials is False
), f"Expected no credentials for input {repr(empty)}"
def test_cors_single_specific_origin_enables_credentials():
"""should enable credentials when a single explicit origin is configured."""
from litellm.proxy.proxy_server import _get_cors_config
origins, allow_credentials = _get_cors_config(
cors_origins_env="https://admin.example.com"
)
assert origins == ["https://admin.example.com"]
assert allow_credentials is True
def test_cors_multiple_specific_origins_enables_credentials():
"""should enable credentials and correctly parse comma-separated origins."""
from litellm.proxy.proxy_server import _get_cors_config
origins, allow_credentials = _get_cors_config(
cors_origins_env="https://app.example.com, https://admin.example.com, https://api.example.com"
)
assert origins == [
"https://app.example.com",
"https://admin.example.com",
"https://api.example.com",
]
assert allow_credentials is True
def test_cors_wildcard_string_in_env_disables_credentials():
"""should disable credentials when LITELLM_CORS_ORIGINS is explicitly set to '*'."""
from litellm.proxy.proxy_server import _get_cors_config
origins, allow_credentials = _get_cors_config(cors_origins_env="*")
assert "*" in origins
assert allow_credentials is False
def test_cors_origins_strips_whitespace():
"""should strip surrounding whitespace from each origin entry."""
from litellm.proxy.proxy_server import _get_cors_config
origins, _ = _get_cors_config(
cors_origins_env=" https://a.com , https://b.com "
)
assert origins == ["https://a.com", "https://b.com"]
def test_cors_origins_skips_blank_entries():
"""should skip blank entries caused by trailing/double commas."""
from litellm.proxy.proxy_server import _get_cors_config
origins, allow_credentials = _get_cors_config(
cors_origins_env="https://a.com,,https://b.com,"
)
assert origins == ["https://a.com", "https://b.com"]
assert allow_credentials is True
def test_cors_explicit_credentials_true_overrides_wildcard():
"""should enable credentials when LITELLM_CORS_ALLOW_CREDENTIALS=true even
if wildcard origins are in use (opt-in for existing deployments)."""
from litellm.proxy.proxy_server import _get_cors_config
origins, allow_credentials = _get_cors_config(
cors_origins_env="",
cors_credentials_env="true",
)
assert "*" in origins
assert allow_credentials is True
def test_cors_explicit_credentials_false_overrides_specific_origins():
"""should disable credentials when LITELLM_CORS_ALLOW_CREDENTIALS=false even
if specific origins are configured."""
from litellm.proxy.proxy_server import _get_cors_config
origins, allow_credentials = _get_cors_config(
cors_origins_env="https://admin.example.com",
cors_credentials_env="false",
)
assert origins == ["https://admin.example.com"]
assert allow_credentials is False
def test_cors_explicit_credentials_case_insensitive():
"""should accept TRUE/FALSE case-insensitively for LITELLM_CORS_ALLOW_CREDENTIALS."""
from litellm.proxy.proxy_server import _get_cors_config
_, allow_true = _get_cors_config(cors_origins_env="", cors_credentials_env="TRUE")
_, allow_false = _get_cors_config(
cors_origins_env="https://x.com", cors_credentials_env="FALSE"
)
assert allow_true is True
assert allow_false is False
def test_proxy_server_cors_invariant():
"""should verify that proxy_server module-level origins and allow_cors_credentials
are consistent catches any future drift in the module-level call to _get_cors_config.
"""
import os
import litellm.proxy.proxy_server as proxy_server
if os.getenv("LITELLM_CORS_ALLOW_CREDENTIALS") is None:
assert proxy_server.allow_cors_credentials == (
"*" not in proxy_server.origins
), (
f"Invariant broken: allow_cors_credentials={proxy_server.allow_cors_credentials} "
f"but origins={proxy_server.origins}. "
"When origins contains '*', allow_credentials must be False."
)
@@ -406,12 +406,6 @@ async def test_save_background_health_checks_to_db_exception_handling():
@pytest.mark.asyncio
async def test_get_all_latest_health_checks_with_model_id(mock_prisma):
"""Test get_all_latest_health_checks properly groups by model_id"""
# Create mock checks with same model_name but different model_id
mock_check1 = MagicMock()
mock_check1.model_id = "model-123"
mock_check1.model_name = "gpt-3.5-turbo"
mock_check1.checked_at = datetime.now(timezone.utc) - timedelta(minutes=10)
mock_check2 = MagicMock()
mock_check2.model_id = "model-456"
mock_check2.model_name = "gpt-3.5-turbo"
@@ -424,7 +418,7 @@ async def test_get_all_latest_health_checks_with_model_id(mock_prisma):
# Order by checked_at desc
mock_prisma.db.litellm_healthchecktable.find_many = AsyncMock(
return_value=[mock_check3, mock_check2, mock_check1]
return_value=[mock_check3, mock_check2]
)
result = await mock_prisma.get_all_latest_health_checks()
@@ -445,18 +439,13 @@ async def test_get_all_latest_health_checks_with_model_id(mock_prisma):
@pytest.mark.asyncio
async def test_get_all_latest_health_checks_without_model_id(mock_prisma):
"""Test get_all_latest_health_checks groups by model_name when model_id is None"""
mock_check1 = MagicMock()
mock_check1.model_id = None
mock_check1.model_name = "gpt-3.5-turbo"
mock_check1.checked_at = datetime.now(timezone.utc) - timedelta(minutes=10)
mock_check2 = MagicMock()
mock_check2.model_id = None
mock_check2.model_name = "gpt-3.5-turbo"
mock_check2.checked_at = datetime.now(timezone.utc) - timedelta(minutes=1) # Latest
mock_prisma.db.litellm_healthchecktable.find_many = AsyncMock(
return_value=[mock_check2, mock_check1]
return_value=[mock_check2]
)
result = await mock_prisma.get_all_latest_health_checks()
@@ -467,6 +456,41 @@ async def test_get_all_latest_health_checks_without_model_id(mock_prisma):
assert result[0].checked_at == mock_check2.checked_at # Latest
@pytest.mark.asyncio
async def test_get_all_latest_health_checks_same_name_with_and_without_model_id(mock_prisma):
"""
Same model_name can appear twice after DISTINCT ON: once keyed by (model_id, name)
and once by (NULL, name) different Postgres groups than a single row with id.
"""
now = datetime.now(timezone.utc)
with_id = MagicMock()
with_id.model_id = "deployment-abc"
with_id.model_name = "gpt-4"
with_id.checked_at = now - timedelta(minutes=2)
without_id = MagicMock()
without_id.model_id = None
without_id.model_name = "gpt-4"
without_id.checked_at = now - timedelta(minutes=1)
mock_prisma.db.litellm_healthchecktable.find_many = AsyncMock(
return_value=[without_id, with_id]
)
result = await mock_prisma.get_all_latest_health_checks()
assert len(result) == 2
names = {r.model_name for r in result}
assert names == {"gpt-4"}
ids = {r.model_id for r in result}
assert "deployment-abc" in ids
assert None in ids
by_key = {(r.model_id, r.model_name): r for r in result}
assert by_key[("deployment-abc", "gpt-4")].checked_at == with_id.checked_at
assert by_key[(None, "gpt-4")].checked_at == without_id.checked_at
@pytest.mark.asyncio
async def test_perform_health_check_and_save_passes_model_id_to_perform_health_check():
"""Test that _perform_health_check_and_save passes model_id to perform_health_check so health checks run by model id."""
@@ -1,7 +1,10 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from litellm.proxy.health_check import _update_litellm_params_for_health_check
from litellm.litellm_core_utils.health_check_helpers import HealthCheckHelpers
from unittest.mock import AsyncMock, patch, MagicMock
from litellm.proxy import health_check as hc_module
from litellm.proxy.health_check import _update_litellm_params_for_health_check
@pytest.mark.asyncio
@@ -50,10 +53,13 @@ async def test_ahealth_check_wildcard_models_respects_max_tokens():
Test that ahealth_check_wildcard_models respects max_tokens if passed,
otherwise defaults to 10.
"""
with patch(
"litellm.litellm_core_utils.llm_request_utils.pick_cheapest_chat_models_from_llm_provider",
return_value=["gpt-4o-mini"],
), patch("litellm.acompletion", new_callable=AsyncMock):
with (
patch(
"litellm.litellm_core_utils.llm_request_utils.pick_cheapest_chat_models_from_llm_provider",
return_value=["gpt-4o-mini"],
),
patch("litellm.acompletion", new_callable=AsyncMock),
):
# Test Case 1: No max_tokens passed, should default to 10
model_params = {}
await HealthCheckHelpers.ahealth_check_wildcard_models(
@@ -73,3 +79,50 @@ async def test_ahealth_check_wildcard_models_respects_max_tokens():
litellm_logging_obj=MagicMock(),
)
assert model_params["max_tokens"] == 3
@pytest.mark.asyncio
async def test_background_health_check_max_tokens_env_var(monkeypatch):
"""
Test that BACKGROUND_HEALTH_CHECK_MAX_TOKENS env var is used as global default
for explicit (non-wildcard) models.
"""
monkeypatch.setattr(hc_module, "BACKGROUND_HEALTH_CHECK_MAX_TOKENS", 10)
model_info = {}
litellm_params = {"model": "azure/gpt-4"}
updated_params = _update_litellm_params_for_health_check(model_info, litellm_params)
assert updated_params["max_tokens"] == 10
@pytest.mark.asyncio
async def test_per_model_overrides_global_env_var(monkeypatch):
"""
Test that per-model health_check_max_tokens takes priority over
BACKGROUND_HEALTH_CHECK_MAX_TOKENS env var.
"""
monkeypatch.setattr(hc_module, "BACKGROUND_HEALTH_CHECK_MAX_TOKENS", 10)
model_info = {"health_check_max_tokens": 5}
litellm_params = {"model": "azure/gpt-4"}
updated_params = _update_litellm_params_for_health_check(model_info, litellm_params)
assert updated_params["max_tokens"] == 5
@pytest.mark.asyncio
async def test_global_env_var_applies_to_wildcard_models(monkeypatch):
"""
Test that BACKGROUND_HEALTH_CHECK_MAX_TOKENS env var also applies to wildcard models.
"""
monkeypatch.setattr(hc_module, "BACKGROUND_HEALTH_CHECK_MAX_TOKENS", 15)
model_info = {}
litellm_params = {"model": "openai/*"}
updated_params = _update_litellm_params_for_health_check(model_info, litellm_params)
assert updated_params["max_tokens"] == 15
@@ -287,9 +287,48 @@ def test_string_retention_still_works():
general_settings={"maximum_spend_logs_retention_period": setting}
)
assert cleaner._should_delete_spend_logs() is True, f"Failed for {setting}"
assert cleaner.retention_seconds == expected_seconds, (
f"Expected {expected_seconds} for {setting}, got {cleaner.retention_seconds}"
)
assert (
cleaner.retention_seconds == expected_seconds
), f"Expected {expected_seconds} for {setting}, got {cleaner.retention_seconds}"
@pytest.mark.asyncio
async def test_delete_old_logs_aborts_on_non_int_execute_raw_return():
"""should abort deletion loop immediately when execute_raw returns a non-int
(e.g. None or dict), preventing an infinite loop."""
mock_prisma_client = MagicMock()
mock_db = MagicMock()
mock_db.execute_raw = AsyncMock(return_value=None)
mock_prisma_client.db = mock_db
cleaner = SpendLogCleanup(
general_settings={"maximum_spend_logs_retention_period": "7d"}
)
cutoff_date = datetime.now(timezone.utc) - timedelta(days=7)
total_deleted = await cleaner._delete_old_logs(mock_prisma_client, cutoff_date)
assert mock_db.execute_raw.call_count == 1
assert total_deleted == 0
@pytest.mark.asyncio
async def test_delete_old_logs_continues_on_valid_int_return():
"""should continue deletion loop across batches when execute_raw returns valid int counts."""
mock_prisma_client = MagicMock()
mock_db = MagicMock()
mock_db.execute_raw = AsyncMock(side_effect=[500, 300, 0])
mock_prisma_client.db = mock_db
cleaner = SpendLogCleanup(
general_settings={"maximum_spend_logs_retention_period": "7d"}
)
cutoff_date = datetime.now(timezone.utc) - timedelta(days=7)
total_deleted = await cleaner._delete_old_logs(mock_prisma_client, cutoff_date)
assert mock_db.execute_raw.call_count == 3
assert total_deleted == 800
def test_cleanup_batch_size_env_var(monkeypatch):
@@ -0,0 +1,145 @@
"""Unit tests for litellm.litellm_core_utils.completion_timeout.CompletionTimeout."""
import os
import sys
import httpx
sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
)
from litellm.litellm_core_utils.completion_timeout import CompletionTimeout
from litellm.utils import supports_httpx_timeout
def test_explicit_timeout_wins():
assert (
CompletionTimeout.resolve(
12.5,
{"timeout": 99.0, "request_timeout": 88.0},
"openai",
global_timeout=None,
supports_httpx_timeout=supports_httpx_timeout,
)
== 12.5
)
def test_kwargs_timeout_when_param_none():
assert (
CompletionTimeout.resolve(
None,
{"timeout": 21.0},
"azure_ai",
global_timeout=None,
supports_httpx_timeout=supports_httpx_timeout,
)
== 21.0
)
def test_request_timeout_alias_in_kwargs():
assert (
CompletionTimeout.resolve(
None,
{"request_timeout": 33.0},
"bedrock",
global_timeout=None,
supports_httpx_timeout=supports_httpx_timeout,
)
== 33.0
)
def test_global_timeout_from_litellm_settings():
assert (
CompletionTimeout.resolve(
None,
{},
"vertex_ai",
global_timeout=360.0,
supports_httpx_timeout=supports_httpx_timeout,
)
== 360.0
)
def test_global_timeout_package_default_coerced_to_600_for_completion():
"""Package default 6000s → 600s for completion-only path."""
assert (
CompletionTimeout.resolve(
None,
{},
"openai",
global_timeout=6000.0,
supports_httpx_timeout=supports_httpx_timeout,
)
== 600.0
)
def test_explicit_request_timeout_6000_preserved():
"""Explicit deployment/request timeout must not be truncated by the package sentinel."""
assert (
CompletionTimeout.resolve(
None,
{"request_timeout": 6000.0},
"openai",
global_timeout=None,
supports_httpx_timeout=supports_httpx_timeout,
)
== 6000.0
)
def test_explicit_model_timeout_6000_preserved():
assert (
CompletionTimeout.resolve(
6000.0,
{"timeout": 1.0, "request_timeout": 2.0},
"openai",
global_timeout=None,
supports_httpx_timeout=supports_httpx_timeout,
)
== 6000.0
)
def test_fallback_600_when_no_global_timeout():
assert (
CompletionTimeout.resolve(
None,
{},
"azure_ai",
global_timeout=None,
supports_httpx_timeout=supports_httpx_timeout,
)
== 600.0
)
def test_httpx_timeout_coerced_for_provider_without_httpx_timeout_support():
t = httpx.Timeout(50.0, connect=2.0)
out = CompletionTimeout.resolve(
t,
{},
"azure_ai",
global_timeout=None,
supports_httpx_timeout=supports_httpx_timeout,
)
assert out == 50.0
assert not isinstance(out, httpx.Timeout)
def test_httpx_timeout_preserved_for_openai():
t = httpx.Timeout(40.0, connect=5.0)
out = CompletionTimeout.resolve(
t,
{},
"openai",
global_timeout=None,
supports_httpx_timeout=supports_httpx_timeout,
)
assert out is t
assert isinstance(out, httpx.Timeout)
+25 -1
View File
@@ -1876,7 +1876,7 @@ def test_gemini_without_cache_tokens_details():
"promptTokensDetails": [
{"modality": "TEXT", "tokenCount": 6},
{"modality": "IMAGE", "tokenCount": 258},
]
],
# No cacheTokensDetails
}
}
@@ -2014,3 +2014,27 @@ def test_additional_costs_only_for_azure_ai():
completion_tokens=50,
)
assert result is None, "Vertex AI should have no additional costs"
def test_openrouter_gemini_3_1_flash_lite_preview_pricing():
"""
Test that openrouter/google/gemini-3.1-flash-lite-preview has a pricing entry.
Regression test for https://github.com/BerriAI/litellm/issues/25604
The model exists and is callable via OpenRouter, but was missing from
model_prices_and_context_window.json when other Gemini 3.x variants were present.
This caused ValueError: This model isn't mapped yet during router pre-call checks.
"""
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
model_name = "openrouter/google/gemini-3.1-flash-lite-preview"
model_info = litellm.model_cost.get(model_name)
assert model_info is not None, f"Missing model pricing entry: {model_name}"
assert model_info["litellm_provider"] == "openrouter"
assert model_info["input_cost_per_token"] == 2.5e-07
assert model_info["output_cost_per_token"] == 1.5e-06
assert model_info["max_input_tokens"] == 1048576
assert model_info["max_output_tokens"] == 65536
@@ -31,3 +31,32 @@ def test_get_standard_logging_model_parameters_excludes_prompt_content():
assert "prompt" not in result
assert "input" not in result
assert result == {"temperature": 0.5}
def test_get_all_llm_api_params_includes_responses_api():
"""
Regression guard for the Responses API cache-key bug:
Responses-API-only kwargs must be present in the cache-key allow-list,
otherwise Cache.get_cache_key() silently drops them and two requests
that differ only in (e.g.) `instructions` collide on the same key.
"""
all_params = ModelParamHelper._get_all_llm_api_params()
responses_only_params = {
"instructions",
"previous_response_id",
"reasoning",
"include",
"store",
"background",
"max_output_tokens",
"max_tool_calls",
"prompt_cache_key",
"prompt_cache_retention",
"context_management",
"conversation",
"safety_identifier",
}
missing = responses_only_params - all_params
assert (
missing == set()
), f"Responses-API kwargs missing from cache-key allow-list: {sorted(missing)}"
@@ -3,7 +3,7 @@ import { Organization, organizationInfoCall, organizationListCall } from "@/comp
import { useQuery, useQueryClient, UseQueryResult } from "@tanstack/react-query";
import { createQueryKeys } from "../common/queryKeysFactory";
const organizationKeys = createQueryKeys("organizations");
export const organizationKeys = createQueryKeys("organizations");
export const useOrganizations = (): UseQueryResult<Organization[]> => {
const { accessToken, userId, userRole } = useAuthorized();
return useQuery<Organization[]>({
@@ -1,4 +1,6 @@
import React, { useState, useEffect } from "react";
import { useQueryClient } from "@tanstack/react-query";
import { organizationKeys } from "@/app/(dashboard)/hooks/organizations/useOrganizations";
import { teamDeleteCall, Organization } from "@/components/networking";
import { fetchTeams } from "@/components/common_components/fetch_teams";
import { Form } from "antd";
@@ -54,6 +56,7 @@ const TeamsView: React.FC<TeamProps> = ({
organizations,
premiumUser = false,
}) => {
const queryClient = useQueryClient();
const [currentOrg, setCurrentOrg] = useState<Organization | null>(null);
const [showFilters, setShowFilters] = useState(false);
const [filters, setFilters] = useState<FilterState>({
@@ -138,6 +141,7 @@ const TeamsView: React.FC<TeamProps> = ({
try {
await teamDeleteCall(accessToken, teamToDelete);
queryClient.invalidateQueries({ queryKey: organizationKeys.all });
// Successfully completed the deletion. Update the state to trigger a rerender.
fetchTeams(accessToken, userID, userRole, currentOrg, setTeams);
} catch (error) {
@@ -13,9 +13,11 @@ import AgentSelector from "@/components/agent_management/AgentSelector";
import PremiumLoggingSettings from "@/components/common_components/PremiumLoggingSettings";
import ModelAliasManager from "@/components/common_components/ModelAliasManager";
import React, { useEffect, useState } from "react";
import { useQueryClient } from "@tanstack/react-query";
import NotificationsManager from "@/components/molecules/notifications_manager";
import { fetchMCPAccessGroups, getGuardrailsList, getPoliciesList, Organization, Team, teamCreateCall } from "@/components/networking";
import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized";
import { organizationKeys } from "@/app/(dashboard)/hooks/organizations/useOrganizations";
import MCPToolPermissions from "@/components/mcp_server_management/MCPToolPermissions";
interface ModelAliases {
@@ -71,6 +73,7 @@ const CreateTeamModal = ({
setIsTeamModalVisible,
}: CreateTeamModalProps) => {
const { userId: userID, userRole, accessToken, premiumUser } = useAuthorized();
const queryClient = useQueryClient();
const [form] = Form.useForm();
const [userModels, setUserModels] = useState<string[]>([]);
const [currentOrgForCreateTeam, setCurrentOrgForCreateTeam] = useState<Organization | null>(null);
@@ -273,6 +276,7 @@ const CreateTeamModal = ({
}
const response: any = await teamCreateCall(accessToken, formValues);
queryClient.invalidateQueries({ queryKey: organizationKeys.all });
if (teams !== null) {
setTeams([...teams, response]);
} else {
@@ -192,8 +192,8 @@ const GuardrailOptionalParams: React.FC<GuardrailOptionalParamsProps> = ({
</Select>
) : field.type === "bool" || field.type === "boolean" ? (
<Select placeholder={field.description}>
<Select.Option value="true">True</Select.Option>
<Select.Option value="false">False</Select.Option>
<Select.Option value={true}>True</Select.Option>
<Select.Option value={false}>False</Select.Option>
</Select>
) : field.type === "number" ? (
<NumericalInput step={1} width={400} placeholder={field.description} />
@@ -89,7 +89,7 @@ describe("ToolTestPanel defaults", () => {
expect(screen.getByLabelText("message")).toHaveValue("");
expect(screen.getByLabelText("attempts")).toHaveValue(0);
expect(screen.getByLabelText("ratio")).toHaveValue(0.4);
expect(screen.getByDisplayValue("True")).toBeInTheDocument();
expect(screen.getByTitle("True")).toBeInTheDocument();
const keywordsTextarea = screen.getByTestId("textarea-keywords");
expect(JSON.parse(keywordsTextarea.value)).toEqual([""]);
@@ -1,7 +1,7 @@
import React from "react";
import { Button, TextInput } from "@tremor/react";
import { MCPTool, InputSchema, InputSchemaProperty } from "./types";
import { Form, Tooltip } from "antd";
import { Form, Select, Tooltip } from "antd";
import { InfoCircleOutlined } from "@ant-design/icons";
import NotificationsManager from "../molecules/notifications_manager";
@@ -480,14 +480,14 @@ export function ToolTestPanel({
)}
{prop.type === "boolean" && (
<select
className="w-full px-3 py-2 border border-gray-300 rounded-lg shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm transition-colors"
defaultValue={(initialValue ?? false).toString()}
<Select
placeholder={`Select ${key}`}
allowClear={!actualSchema.required?.includes(key)}
className="w-full"
>
{!actualSchema.required?.includes(key) && <option value="">Select {key}</option>}
<option value="true">True</option>
<option value="false">False</option>
</select>
<Select.Option value={true}>True</Select.Option>
<Select.Option value={false}>False</Select.Option>
</Select>
)}
{(prop.type === "object" || prop.type === "array") && (
@@ -1,14 +1,15 @@
import React from "react";
import { render, screen, waitFor } from "@testing-library/react";
import { screen, waitFor } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import { vi, test, expect } from "vitest";
import { vi, test, expect, beforeEach } from "vitest";
import { renderWithProviders } from "../../../tests/test-utils";
import OrganizationInfoView from "./organization_view";
import { useOrganization } from "@/app/(dashboard)/hooks/organizations/useOrganizations";
// Mock networking calls used by the component
// Mock networking calls used by the component's mutation handlers
vi.mock("../networking", () => {
return {
__esModule: true,
organizationInfoCall: vi.fn(),
organizationMemberAddCall: vi.fn(),
organizationMemberUpdateCall: vi.fn(),
organizationMemberDeleteCall: vi.fn(),
@@ -16,6 +17,20 @@ vi.mock("../networking", () => {
};
});
// Mock the React Query hook the component now reads org data from. The component
// also imports organizationKeys (used inside mutation handlers for invalidation),
// so provide a stub shape here too.
vi.mock("@/app/(dashboard)/hooks/organizations/useOrganizations", () => ({
useOrganization: vi.fn(),
organizationKeys: {
all: ["organizations"],
list: () => ["organizations", "list", { params: {} }],
detail: (id: string) => ["organizations", "detail", id],
},
}));
const mockUseOrganization = vi.mocked(useOrganization);
// Mock noisy/heavy child components to keep this test focused on render
vi.mock("../object_permissions_view", () => ({
__esModule: true,
@@ -81,11 +96,14 @@ const mockOrg = {
metadata: null,
};
test("renders organization view after loading data", async () => {
const { organizationInfoCall } = await import("../networking");
(organizationInfoCall as unknown as ReturnType<typeof vi.fn>).mockResolvedValueOnce(mockOrg);
beforeEach(() => {
mockUseOrganization.mockReset();
});
const { findAllByText } = render(
test("renders organization view after loading data", async () => {
mockUseOrganization.mockReturnValue({ data: mockOrg, isLoading: false } as any);
const { findAllByText } = renderWithProviders(
<OrganizationInfoView
organizationId="org_123"
onClose={() => {}}
@@ -103,11 +121,10 @@ test("renders organization view after loading data", async () => {
});
test("should display empty state when organization has no members", async () => {
const { organizationInfoCall } = await import("../networking");
(organizationInfoCall as unknown as ReturnType<typeof vi.fn>).mockResolvedValueOnce(mockOrg);
mockUseOrganization.mockReturnValue({ data: mockOrg, isLoading: false } as any);
const user = userEvent.setup();
render(
renderWithProviders(
<OrganizationInfoView
organizationId="org_123"
onClose={() => {}}
@@ -131,14 +148,13 @@ test("should display empty state when organization has no members", async () =>
});
test("should display team aliases when teams are available", async () => {
const { organizationInfoCall } = await import("../networking");
const orgWithTeams = {
...mockOrg,
teams: [{ team_id: "team_123" }, { team_id: "team_456" }],
};
(organizationInfoCall as unknown as ReturnType<typeof vi.fn>).mockResolvedValueOnce(orgWithTeams);
mockUseOrganization.mockReturnValue({ data: orgWithTeams, isLoading: false } as any);
render(
renderWithProviders(
<OrganizationInfoView
organizationId="org_123"
onClose={() => {}}
@@ -157,7 +173,6 @@ test("should display team aliases when teams are available", async () => {
});
test("should display team ID as fallback when alias is not found", async () => {
const { organizationInfoCall } = await import("../networking");
mockUseTeams.mockReturnValueOnce({
data: [
{
@@ -171,9 +186,9 @@ test("should display team ID as fallback when alias is not found", async () => {
...mockOrg,
teams: [{ team_id: "team_999" }],
};
(organizationInfoCall as unknown as ReturnType<typeof vi.fn>).mockResolvedValueOnce(orgWithUnknownTeam);
mockUseOrganization.mockReturnValue({ data: orgWithUnknownTeam, isLoading: false } as any);
render(
renderWithProviders(
<OrganizationInfoView
organizationId="org_123"
onClose={() => {}}
@@ -1,4 +1,6 @@
import { useTeams } from "@/app/(dashboard)/hooks/teams/useTeams";
import { organizationKeys, useOrganization } from "@/app/(dashboard)/hooks/organizations/useOrganizations";
import { useQueryClient } from "@tanstack/react-query";
import { formatNumberWithCommas, copyToClipboard as utilCopyToClipboard } from "@/utils/dataUtils";
import { createTeamAliasMap } from "@/utils/teamUtils";
import { ArrowLeftIcon } from "@heroicons/react/outline";
@@ -14,7 +16,7 @@ import {
import { Button, Form, Input, Select, Tabs, Typography } from "antd";
import type { ColumnsType } from "antd/es/table";
import { CheckIcon, CopyIcon } from "lucide-react";
import React, { useEffect, useMemo, useState } from "react";
import React, { useMemo, useState } from "react";
import MemberTable from "../common_components/MemberTable";
import UserSearchModal from "../common_components/user_search_modal";
import MCPServerSelector from "../mcp_server_management/MCPServerSelector";
@@ -23,7 +25,6 @@ import NotificationsManager from "../molecules/notifications_manager";
import {
Member,
Organization,
organizationInfoCall,
organizationMemberAddCall,
organizationMemberDeleteCall,
organizationMemberUpdateCall,
@@ -53,8 +54,8 @@ const OrganizationInfoView: React.FC<OrganizationInfoProps> = ({
userModels,
editOrg,
}) => {
const [orgData, setOrgData] = useState<Organization | null>(null);
const [loading, setLoading] = useState(true);
const queryClient = useQueryClient();
const { data: orgData, isLoading: loading } = useOrganization(organizationId);
const [form] = Form.useForm();
const [isEditing, setIsEditing] = useState(false);
const [isAddMemberModalVisible, setIsAddMemberModalVisible] = useState(false);
@@ -67,24 +68,6 @@ const OrganizationInfoView: React.FC<OrganizationInfoProps> = ({
const teamAliasMap = useMemo(() => createTeamAliasMap(teams), [teams]);
const fetchOrgInfo = async () => {
try {
setLoading(true);
if (!accessToken) return;
const response = await organizationInfoCall(accessToken, organizationId);
setOrgData(response);
} catch (error) {
NotificationsManager.fromBackend("Failed to load organization information");
console.error("Error fetching organization info:", error);
} finally {
setLoading(false);
}
};
useEffect(() => {
fetchOrgInfo();
}, [organizationId, accessToken]);
const handleMemberAdd = async (values: any) => {
try {
if (accessToken == null) {
@@ -101,7 +84,7 @@ const OrganizationInfoView: React.FC<OrganizationInfoProps> = ({
NotificationsManager.success("Organization member added successfully");
setIsAddMemberModalVisible(false);
form.resetFields();
fetchOrgInfo();
queryClient.invalidateQueries({ queryKey: organizationKeys.all });
} catch (error) {
NotificationsManager.fromBackend("Failed to add organization member");
console.error("Error adding organization member:", error);
@@ -122,7 +105,7 @@ const OrganizationInfoView: React.FC<OrganizationInfoProps> = ({
NotificationsManager.success("Organization member updated successfully");
setIsEditMemberModalVisible(false);
form.resetFields();
fetchOrgInfo();
queryClient.invalidateQueries({ queryKey: organizationKeys.all });
} catch (error) {
NotificationsManager.fromBackend("Failed to update organization member");
console.error("Error updating organization member:", error);
@@ -137,7 +120,7 @@ const OrganizationInfoView: React.FC<OrganizationInfoProps> = ({
NotificationsManager.success("Organization member deleted successfully");
setIsEditMemberModalVisible(false);
form.resetFields();
fetchOrgInfo();
queryClient.invalidateQueries({ queryKey: organizationKeys.all });
} catch (error) {
NotificationsManager.fromBackend("Failed to delete organization member");
console.error("Error deleting organization member:", error);
@@ -187,7 +170,7 @@ const OrganizationInfoView: React.FC<OrganizationInfoProps> = ({
NotificationsManager.success("Organization settings updated successfully");
setIsEditing(false);
fetchOrgInfo();
queryClient.invalidateQueries({ queryKey: organizationKeys.all });
} catch (error) {
NotificationsManager.fromBackend("Failed to update organization settings");
console.error("Error updating organization:", error);
@@ -0,0 +1,220 @@
import React from "react";
import { screen, waitFor, within } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import { renderWithProviders } from "../../../tests/test-utils";
import { beforeEach, describe, expect, it, vi } from "vitest";
import PoliciesPanel from "./index";
/**
* Ant Design's static Modal.confirm often does not run onOk in the real app (React 18+).
* In jsdom it may still run; we mock confirm as a no-op so the test fails until the panel
* uses a controlled DeleteResourceModal instead of Modal.confirm.
*/
vi.mock("antd", async (importOriginal) => {
const mod = await importOriginal<typeof import("antd")>();
return {
...mod,
Modal: Object.assign(mod.Modal, {
confirm: vi.fn(),
}),
};
});
const EXPECTED_ATTACHMENT_ID = "att-11111111-2222-3333-4444-555555555555" as const;
const networkingMocks = vi.hoisted(() => ({
deletePolicyAttachmentCall: vi.fn().mockResolvedValue(undefined),
getPoliciesList: vi.fn().mockResolvedValue({ policies: [] }),
getPolicyAttachmentsList: vi.fn().mockResolvedValue({
attachments: [
{
attachment_id: "att-11111111-2222-3333-4444-555555555555",
policy_name: "test-policy",
scope: null,
teams: [],
keys: [],
models: [],
tags: [],
},
],
}),
getGuardrailsList: vi.fn().mockResolvedValue({ guardrails: [] }),
getPolicyInfo: vi.fn().mockResolvedValue({}),
deletePolicyCall: vi.fn().mockResolvedValue(undefined),
createPolicyCall: vi.fn(),
updatePolicyCall: vi.fn(),
createPolicyAttachmentCall: vi.fn(),
createGuardrailCall: vi.fn(),
enrichPolicyTemplate: vi.fn(),
}));
vi.mock("../networking", () => ({
...networkingMocks,
}));
vi.mock("./impact_popover", () => ({
default: () => <button type="button" aria-label="View blast radius" />,
}));
vi.mock("@heroicons/react/outline", () => ({
TrashIcon: function TrashIcon() {
return null;
},
SwitchVerticalIcon: function SwitchVerticalIcon() {
return null;
},
ChevronUpIcon: function ChevronUpIcon() {
return null;
},
ChevronDownIcon: function ChevronDownIcon() {
return null;
},
}));
vi.mock("@tremor/react", async (importOriginal) => {
const actual = await importOriginal<typeof import("@tremor/react")>();
return {
...actual,
Button: React.forwardRef<HTMLButtonElement, any>(({ children, ...props }, ref) =>
React.createElement("button", { ...props, ref }, children),
),
Tooltip: ({ children }: { children?: React.ReactNode }) =>
React.createElement(React.Fragment, null, children),
Switch: ({
checked,
onChange,
className,
}: {
checked?: boolean;
onChange?: (v: boolean) => void;
className?: string;
}) =>
React.createElement("input", {
type: "checkbox",
role: "switch",
checked,
onChange: (e: React.ChangeEvent<HTMLInputElement>) => onChange?.(e.target.checked),
className,
}),
Icon: ({ icon: _IconComp, onClick, className }: any) =>
React.createElement(
"button",
{ type: "button", onClick, className },
"TrashIcon",
),
};
});
vi.mock("./policy_templates", () => ({
__esModule: true,
default: () => <div data-testid="policy-templates-stub" />,
}));
vi.mock("./pipeline_flow_builder", () => ({
FlowBuilderPage: () => null,
}));
vi.mock("./policy_info", () => ({
__esModule: true,
default: () => null,
}));
vi.mock("./add_policy_form", () => ({
__esModule: true,
default: () => null,
}));
vi.mock("./guardrail_selection_modal", () => ({
__esModule: true,
default: () => null,
}));
vi.mock("./template_parameter_modal", () => ({
__esModule: true,
default: () => null,
}));
vi.mock("./ai_suggestion_modal", () => ({
__esModule: true,
default: () => null,
}));
vi.mock("./policy_test_panel", () => ({
__esModule: true,
default: () => null,
}));
vi.mock("./add_attachment_form", () => ({
__esModule: true,
default: () => null,
}));
describe("PoliciesPanel attachment delete", () => {
beforeEach(() => {
vi.clearAllMocks();
});
it("should call deletePolicyAttachmentCall after the user confirms delete in the attachment modal", async () => {
const user = userEvent.setup();
renderWithProviders(<PoliciesPanel accessToken="test-token" userRole="Admin" />);
await waitFor(() => {
expect(networkingMocks.getPolicyAttachmentsList).toHaveBeenCalled();
});
await user.click(screen.getByRole("tab", { name: /^attachments$/i }));
await waitFor(() => {
expect(screen.getByText("test-policy")).toBeInTheDocument();
});
await user.click(screen.getByRole("button", { name: /TrashIcon/i }));
const dialog = await screen.findByRole("dialog", {}, { timeout: 5000 });
expect(
within(dialog).getByText(/Are you sure you want to delete this attachment/i),
).toBeInTheDocument();
await user.click(within(dialog).getByRole("button", { name: /^delete$/i }));
await waitFor(() => {
expect(networkingMocks.deletePolicyAttachmentCall).toHaveBeenCalledTimes(1);
expect(networkingMocks.deletePolicyAttachmentCall).toHaveBeenCalledWith("test-token", EXPECTED_ATTACHMENT_ID);
});
});
it("should show mutation pending state while attachment delete is in flight", async () => {
let resolveDelete: (() => void) | undefined;
const deletePromise = new Promise<void>((resolve) => {
resolveDelete = resolve;
});
networkingMocks.deletePolicyAttachmentCall.mockImplementationOnce(() => deletePromise);
const user = userEvent.setup();
renderWithProviders(<PoliciesPanel accessToken="test-token" userRole="Admin" />);
await waitFor(() => {
expect(networkingMocks.getPolicyAttachmentsList).toHaveBeenCalled();
});
await user.click(screen.getByRole("tab", { name: /^attachments$/i }));
await waitFor(() => {
expect(screen.getByText("test-policy")).toBeInTheDocument();
});
await user.click(screen.getByRole("button", { name: /TrashIcon/i }));
const dialog = await screen.findByRole("dialog", {}, { timeout: 5000 });
const deleteButton = within(dialog).getByRole("button", { name: /^delete$/i });
await user.click(deleteButton);
await waitFor(() => {
expect(within(dialog).getByRole("button", { name: /deleting/i })).toBeDisabled();
});
resolveDelete?.();
await waitFor(() => {
expect(screen.queryByRole("dialog")).not.toBeInTheDocument();
});
});
});
@@ -1,8 +1,9 @@
import React, { useState, useEffect, useCallback } from "react";
import { Button, TabGroup, TabList, Tab, TabPanels, TabPanel } from "@tremor/react";
import { Modal, Alert } from "antd";
import { Alert } from "antd";
import MessageManager from "@/components/molecules/message_manager";
import { ExclamationCircleOutlined, InfoCircleOutlined } from "@ant-design/icons";
import { InfoCircleOutlined } from "@ant-design/icons";
import { isAdminRole } from "@/utils/roles";
import PolicyTable from "./policy_table";
import PolicyInfoView from "./policy_info";
@@ -15,11 +16,11 @@ import PolicyTemplates from "./policy_templates";
import GuardrailSelectionModal from "./guardrail_selection_modal";
import TemplateParameterModal from "./template_parameter_modal";
import AiSuggestionModal from "./ai_suggestion_modal";
import { useDeletePolicyAttachment } from "@/hooks/policies/useDeletePolicyAttachment";
import {
getPoliciesList,
deletePolicyCall,
getPolicyAttachmentsList,
deletePolicyAttachmentCall,
getGuardrailsList,
getPolicyInfo,
createPolicyCall,
@@ -57,6 +58,8 @@ const PoliciesPanel: React.FC<PoliciesPanelProps> = ({
const [isDeleting, setIsDeleting] = useState(false);
const [policyToDelete, setPolicyToDelete] = useState<Policy | null>(null);
const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false);
const [attachmentToDelete, setAttachmentToDelete] = useState<PolicyAttachment | null>(null);
const [isDeleteAttachmentModalOpen, setIsDeleteAttachmentModalOpen] = useState(false);
const [isGuardrailSelectionModalOpen, setIsGuardrailSelectionModalOpen] = useState(false);
const [selectedTemplate, setSelectedTemplate] = useState<any>(null);
const [existingGuardrailNames, setExistingGuardrailNames] = useState<Set<string>>(new Set());
@@ -166,24 +169,28 @@ const PoliciesPanel: React.FC<PoliciesPanelProps> = ({
setPolicyToDelete(null);
};
const handleDeleteAttachment = (attachmentId: string) => {
Modal.confirm({
title: "Delete Attachment",
icon: <ExclamationCircleOutlined />,
content: "Are you sure you want to delete this attachment? This action cannot be undone.",
okText: "Delete",
okType: "danger",
cancelText: "Cancel",
onOk: async () => {
if (!accessToken) return;
try {
await deletePolicyAttachmentCall(accessToken, attachmentId);
MessageManager.success("Attachment deleted successfully");
fetchAttachments();
} catch (error) {
console.error("Error deleting attachment:", error);
MessageManager.error("Failed to delete attachment");
}
const deleteAttachmentMutation = useDeletePolicyAttachment({
accessToken,
onSuccess: fetchAttachments,
});
const handleDeleteAttachmentClick = (attachmentId: string) => {
const attachment = attachmentsList.find((a) => a.attachment_id === attachmentId) || null;
setAttachmentToDelete(attachment);
setIsDeleteAttachmentModalOpen(true);
};
const handleAttachmentDeleteCancel = () => {
setIsDeleteAttachmentModalOpen(false);
setAttachmentToDelete(null);
};
const handleAttachmentDeleteConfirm = () => {
if (!attachmentToDelete) return;
deleteAttachmentMutation.mutate(attachmentToDelete.attachment_id, {
onSettled: () => {
setIsDeleteAttachmentModalOpen(false);
setAttachmentToDelete(null);
},
});
};
@@ -579,7 +586,7 @@ const PoliciesPanel: React.FC<PoliciesPanelProps> = ({
<AttachmentTable
attachments={attachmentsList}
isLoading={isAttachmentsLoading}
onDeleteClick={handleDeleteAttachment}
onDeleteClick={handleDeleteAttachmentClick}
isAdmin={isAdmin}
accessToken={accessToken}
/>
@@ -600,6 +607,21 @@ const PoliciesPanel: React.FC<PoliciesPanelProps> = ({
</TabPanels>
</TabGroup>
<DeleteResourceModal
isOpen={isDeleteAttachmentModalOpen}
title="Delete Attachment"
message="Are you sure you want to delete this attachment? This action cannot be undone."
resourceInformationTitle="Attachment Information"
resourceInformation={[
{ label: "Attachment ID", value: attachmentToDelete?.attachment_id, code: true },
{ label: "Policy", value: attachmentToDelete?.policy_name ?? "-" },
{ label: "Scope", value: attachmentToDelete?.scope ?? "-" },
]}
onCancel={handleAttachmentDeleteCancel}
onOk={handleAttachmentDeleteConfirm}
confirmLoading={deleteAttachmentMutation.isPending}
/>
<AiSuggestionModal
visible={isAiSuggestionModalOpen}
onSelectTemplates={(selectedTemplates) => {

Some files were not shown because too many files have changed in this diff Show More