diff --git a/.circleci/config.yml b/.circleci/config.yml index 3949200471..52ab848cf6 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -2916,90 +2916,6 @@ jobs: - codecov/upload: file: ./coverage.xml - publish_proxy_extras: - docker: - - image: cimg/python:3.12 - working_directory: ~/project/litellm-proxy-extras - environment: - TWINE_USERNAME: __token__ - - steps: - - checkout: - path: ~/project - - - run: - name: Check if litellm-proxy-extras dir or pyproject.toml was modified - command: | - curl -LsSf -o /tmp/uv-install.sh https://astral.sh/uv/0.10.9/install.sh - echo "7fc46e39cb97290b57169c0c813a17970585ac519139f19006453c99b5f2f45f /tmp/uv-install.sh" | sha256sum -c - - env UV_NO_MODIFY_PATH=1 sh /tmp/uv-install.sh - rm -f /tmp/uv-install.sh - echo 'export PATH="$HOME/.local/bin:$PATH"' >> "$BASH_ENV" - echo 'export PATH="$HOME/.local/bin:$PATH"' >> "$BASH_ENV" - export PATH="$HOME/.local/bin:$PATH" - # Get current version from pyproject.toml - CURRENT_VERSION=$(python -c 'import tomllib; from pathlib import Path; data = tomllib.loads(Path("pyproject.toml").read_text()); print(data["project"]["version"])') - - # Get last published version from PyPI - LAST_VERSION=$(curl -s https://pypi.org/pypi/litellm-proxy-extras/json | python -c "import json, sys; print(json.load(sys.stdin)['info']['version'])") - - echo "Current version: $CURRENT_VERSION" - echo "Last published version: $LAST_VERSION" - - # Compare versions using Python's packaging.version - VERSION_COMPARE=$(uv run --with 'packaging==25.0' python -c "from packaging import version; print(1 if version.parse('$CURRENT_VERSION') < version.parse('$LAST_VERSION') else 0)") - - echo "Version compare: $VERSION_COMPARE" - if [ "$VERSION_COMPARE" = "1" ]; then - echo "Error: Current version ($CURRENT_VERSION) is less than last published version ($LAST_VERSION)" - exit 1 - fi - - # If versions are equal or current is greater, compare against the published package contents. - EXTRACTED_DIR=$(uv run --with "litellm-proxy-extras==$LAST_VERSION" python -c 'import importlib.util; from pathlib import Path; spec = importlib.util.find_spec("litellm_proxy_extras"); assert spec is not None and spec.origin is not None, "litellm_proxy_extras not found in uv-run environment"; print(Path(spec.origin).resolve().parent)') - - # Compare contents - if ! diff -r "$EXTRACTED_DIR" ./litellm_proxy_extras; then - if [ "$CURRENT_VERSION" = "$LAST_VERSION" ]; then - echo "Error: Changes detected in litellm-proxy-extras but version was not bumped" - echo "Current version: $CURRENT_VERSION" - echo "Last published version: $LAST_VERSION" - echo "Changes:" - diff -r "$EXTRACTED_DIR" ./litellm_proxy_extras - exit 1 - fi - else - echo "No changes detected in litellm-proxy-extras. Skipping PyPI publish." - circleci step halt - fi - - - run: - name: Get new version - command: | - NEW_VERSION=$(python -c 'import tomllib; from pathlib import Path; data = tomllib.loads(Path("pyproject.toml").read_text()); print(data["project"]["version"])') - echo "export NEW_VERSION=$NEW_VERSION" >> $BASH_ENV - - - run: - name: Check if versions match - command: | - cd ~/project - # Check pyproject.toml - CURRENT_VERSION=$(uv run --with 'packaging==25.0' python -c 'import tomllib; from packaging.requirements import Requirement; from pathlib import Path; data = tomllib.loads(Path("pyproject.toml").read_text()); matches = [spec.version for requirement in data["project"]["optional-dependencies"]["proxy"] for parsed in [Requirement(requirement)] if parsed.name == "litellm-proxy-extras" and parsed.specifier for spec in parsed.specifier if spec.operator == "=="]; print(matches[0] if matches else (_ for _ in ()).throw(SystemExit("Could not find exact litellm-proxy-extras pin in project.optional-dependencies.proxy")))') - if [ "$CURRENT_VERSION" != "$NEW_VERSION" ]; then - echo "Error: Version in pyproject.toml ($CURRENT_VERSION) doesn't match new version ($NEW_VERSION)" - exit 1 - fi - - - run: - name: Publish to PyPI - command: | - echo -e "[pypi]\nusername = $PYPI_PUBLISH_USERNAME\npassword = $PYPI_PUBLISH_PASSWORD" > ~/.pypirc - echo 'export PATH="$HOME/.local/bin:$PATH"' >> "$BASH_ENV" - export PATH="$HOME/.local/bin:$PATH" - rm -rf build dist - uv build - uv tool run --from 'twine==6.2.0' twine upload --verbose dist/* - ui_build: docker: - image: cimg/node:20.19 @@ -3214,60 +3130,6 @@ jobs: - litellm-docker-database.tar.zst - prisma_schema_sync: - machine: - image: ubuntu-2204:2023.10.1 - resource_class: medium - working_directory: ~/project - steps: - - checkout - - setup_google_dns - - attach_workspace: - at: ~/project - - run: - name: Start PostgreSQL Database - command: | - docker run -d \ - --name postgres-db \ - -e POSTGRES_USER=postgres \ - -e POSTGRES_PASSWORD=postgres \ - -e POSTGRES_DB=litellm_schema_sync \ - -p 5432:5432 \ - postgres:14 - - wait_for_service: - url: tcp://localhost:5432 - timeout: "60" - - run: - name: Load Docker Database Image - command: | - zstd -d litellm-docker-database.tar.zst --stdout | docker load - docker images | grep litellm-docker-database - - run: - name: Run schema sync via prisma db push - command: | - docker run -d \ - -p 4000:4000 \ - -e DATABASE_URL="postgresql://postgres:postgres@host.docker.internal:5432/litellm_schema_sync" \ - -e LITELLM_MASTER_KEY="sk-1234" \ - --name schema-sync \ - --add-host=host.docker.internal:host-gateway \ - -v $(pwd)/litellm/proxy/example_config_yaml/simple_config.yaml:/app/config.yaml \ - litellm-docker-database:ci \ - --config /app/config.yaml \ - --port 4000 \ - --use_prisma_db_push - - run: - name: Start outputting logs - command: docker logs -f schema-sync - background: true - - wait_for_service: - url: http://localhost:4000 - timeout: "300" - - run: - name: Stop schema sync container - command: docker stop schema-sync - - test_bad_database_url: machine: image: ubuntu-2204:2023.10.1 @@ -3421,14 +3283,6 @@ workflows: only: - main - /litellm_.*/ - - prisma_schema_sync: - requires: - - build_docker_database_image - filters: - branches: - only: - - main - - /litellm_.*/ - e2e_ui_testing: filters: branches: @@ -3688,9 +3542,3 @@ workflows: only: - main - /litellm_.*/ - - publish_proxy_extras: - filters: - branches: - only: - - main - - /litellm_release_day_.*/ diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index 544ace9063..7a4c63b951 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -487,6 +487,7 @@ router_settings: | AZURE_STORAGE_CLIENT_ID | The Application Client ID to use for Authentication to Azure Blob Storage logging | AZURE_STORAGE_CLIENT_SECRET | The Application Client Secret to use for Authentication to Azure Blob Storage logging | AZURE_VECTOR_STORE_COST_PER_GB_PER_DAY | Cost per GB per day for Azure Vector Store service +| BACKGROUND_HEALTH_CHECK_MAX_TOKENS | Optional global default for `max_tokens` on proxy background health checks when a model has no `health_check_max_tokens`. If unset, non-wildcard models default to 1. Applies to wildcard routes when set. Default is unset | BATCH_STATUS_POLL_INTERVAL_SECONDS | Interval in seconds for polling batch status. Default is 3600 (1 hour) | BATCH_STATUS_POLL_MAX_ATTEMPTS | Maximum number of attempts for polling batch status. Default is 24 (for 24 hours) | BEDROCK_MAX_POLICY_SIZE | Maximum size for Bedrock policy. Default is 75 @@ -804,6 +805,8 @@ router_settings: | LITELLM_ASSETS_PATH | Path to directory for UI assets and logos. Used when running with read-only filesystem (e.g., Kubernetes). Default is `/var/lib/litellm/assets` in Docker. | LITELLM_BLOG_POSTS_URL | Custom URL for fetching LiteLLM blog posts JSON. Default is the GitHub main branch URL | LITELLM_CLI_JWT_EXPIRATION_HOURS | Expiration time in hours for CLI-generated JWT tokens. Default is 24 hours +| LITELLM_CORS_ALLOW_CREDENTIALS | Set to `true` to explicitly allow credentials in CORS responses. When not set, credentials are disabled automatically if `LITELLM_CORS_ORIGINS` is `*` (wildcard) to prevent the browser security misconfiguration of reflecting any origin with credentials +| LITELLM_CORS_ORIGINS | Comma-separated list of allowed CORS origins (e.g. `https://app.example.com,https://admin.example.com`). Defaults to `*` (all origins) when not set | LITELLM_DD_AGENT_HOST | Hostname or IP of DataDog agent for LiteLLM-specific logging. When set, logs are sent to agent instead of direct API | LITELLM_DEPLOYMENT_ENVIRONMENT | Environment name for the deployment (e.g., "production", "staging"). Used as a fallback when OTEL_ENVIRONMENT_NAME is not set. Sets the `environment` tag in telemetry data | LITELLM_DETAILED_TIMING | When true, adds detailed per-phase timing headers to responses (`x-litellm-timing-{pre-processing,llm-api,post-processing,message-copy}-ms`). Default is false. See [latency overhead docs](../troubleshoot/latency_overhead.md) @@ -925,6 +928,7 @@ router_settings: | OPENAI_CHATGPT_API_BASE | Alternative to CHATGPT_API_BASE. Base URL for ChatGPT API | OPENAI_FILE_SEARCH_COST_PER_1K_CALLS | Cost per 1000 calls for OpenAI file search. Default is 0.0025 | OPENAI_ORGANIZATION | Organization identifier for OpenAI +| OPENAPI_URL | The path to the OpenAPI JSON endpoint. **By default this is "/openapi.json"** | OPENID_BASE_URL | Base URL for OpenID Connect services | OPENID_CLIENT_ID | Client ID for OpenID Connect authentication | OPENID_CLIENT_SECRET | Client secret for OpenID Connect authentication diff --git a/docs/my-website/docs/proxy/cost_tracking.md b/docs/my-website/docs/proxy/cost_tracking.md index f9e22cfecd..9c43aed1db 100644 --- a/docs/my-website/docs/proxy/cost_tracking.md +++ b/docs/my-website/docs/proxy/cost_tracking.md @@ -14,6 +14,10 @@ Provider-specific cost tracking (e.g., [Vertex AI PayGo / priority pricing](../p [Sync model pricing data from GitHub](./sync_models_github.md) to ensure accurate cost tracking. ::: +:::info Cost does not match your provider bill? +Use the step-by-step workflow in [Debugging a cost discrepancy](../troubleshoot/cost_discrepancy): align time ranges, compare token categories (including cache), then decide whether the gap is ingestion, formula, or model-map pricing. +::: + ### How to Track Spend with LiteLLM **Step 1** diff --git a/docs/my-website/docs/proxy/ui_team_soft_budget_alerts.md b/docs/my-website/docs/proxy/ui_team_soft_budget_alerts.md index 17c42e57c9..413457ccb8 100644 --- a/docs/my-website/docs/proxy/ui_team_soft_budget_alerts.md +++ b/docs/my-website/docs/proxy/ui_team_soft_budget_alerts.md @@ -2,6 +2,16 @@ import Image from '@theme/IdealImage'; # Team Soft Budget Alerts +:::info + +✨ This is an Enterprise feature. Email budget alerts require an enterprise license. + +[Enterprise Pricing](https://www.litellm.ai/#pricing) + +[Get free 7-day trial key](https://www.litellm.ai/enterprise#trial) + +::: + Set a soft budget on a team and get email alerts when spending crosses the threshold — without blocking any requests. ## Overview diff --git a/docs/my-website/docs/troubleshoot/cost_discrepancy.md b/docs/my-website/docs/troubleshoot/cost_discrepancy.md new file mode 100644 index 0000000000..f674ac12ee --- /dev/null +++ b/docs/my-website/docs/troubleshoot/cost_discrepancy.md @@ -0,0 +1,205 @@ +# Debugging a cost discrepancy + +Cost discrepancies between LiteLLM and your provider bill usually come from one of three areas: token ingestion, the cost formula LiteLLM applies, or stale or incorrect pricing in the model map. This page walks through how to tell which case you are in. + +## Step 1: Pick a time range + +Lock down a specific window where the discrepancy is visible. + +- Use at least 7 days of data when you can. +- Prefer a window with stable usage so one-off spikes do not dominate the comparison. +- Set the **same start and end time** on both your provider dashboard and the LiteLLM UI. + +![LiteLLM dashboard date range picker](/img/cost-discrepancy-debug/date-range-picker.png) + +## Step 2: Confirm traffic only goes through LiteLLM + +If any requests hit the provider directly (bypassing LiteLLM), the provider will show higher usage. That is expected, not a LiteLLM bug. + +Before continuing, confirm: + +- All clients use your LiteLLM proxy base URL. +- No SDK or script uses provider API keys against the provider directly for the models you are comparing. +- During the selected period, the models in question are only called via LiteLLM. + +If you are unsure, filter the provider dashboard by the API key or IAM principal LiteLLM uses, rather than comparing to your whole account. + +## Step 3: Compare token categories + +In the LiteLLM UI, open **Model activity** (under Usage analytics) so you can inspect spend and tokens per model. + +![Navigate to Model activity in the LiteLLM UI](/img/cost-discrepancy-debug/go-to-model-activity.png) + +Scroll the **Model** list and select the model you are reconciling with your provider bill. + +![Scroll to your model in the Model activity table](/img/cost-discrepancy-debug/scroll-to-model.png) + +With the same time range on both sides, fill in: + +| Category | LiteLLM | Provider | Delta | +| --- | --- | --- | --- | +| Total requests | — | — | — | +| Input tokens | — | — | — | +| Output tokens | — | — | — | +| Cache read tokens | — | — | — | +| Cache write tokens | — | — | — | + +LiteLLM surfaces per-category token usage for the selected model—for example prompt, completion, and cache-related tokens. + +![LiteLLM usage breakdown by token category](/img/cost-discrepancy-debug/token-categories.png) + +Compare these figures with your provider’s usage view (for example AWS billing tools, Azure Monitor, or the OpenAI usage dashboard) for the same period. + +### Cache token reporting + +- **OpenAI:** Cache read tokens are typically included inside the reported input token count. +- **Anthropic:** Cache read tokens are often reported separately from non-cached input tokens. + +Compare the correct columns on each side so you are not treating “input” differently between dashboards. + +### Why use a 10% threshold? + +Provider dashboards and LiteLLM do not bucket requests on identical timestamps. A call at 11:59 PM can land in different daily totals on each side. Token counts can also differ slightly due to rounding across SDKs and APIs. A delta **under ~10%** is often explained by boundary effects and rounding. A delta **over ~10%** usually means something is miscounted, dropped, or categorized differently. + +## Step 4: Follow the right path + + + Cost discrepancy debugging flowchart + Flowchart branching into Path A (token ingestion) or Path B which splits further into B1 (formula issue) and B2 (model map issue). + + + + + + + + Compare provider vs LiteLLM + + + + + Any category off by > 10%? + requests, input, output, cache tokens + + + YES + + + NO + + + Path A + Token ingestion issue + + + Path B + Quantities match, cost differs + + + + + + + + B1 + B2 + + + Report to LiteLLM team + endpoints + model + screenshots + + + B1 + Fix formula + + + B2 + Fix model map + + + + + if neither path resolves it, + Open a github issue backing up with all your data + + +## Path A: Token quantity mismatch + +If any category is off by more than about 10%, LiteLLM may not be ingesting that category correctly (or the provider dashboard is categorizing tokens differently—recheck Step 3 first). + +**What to send the LiteLLM team:** + +1. Screenshots of both dashboards with the date range visible. +2. Which category is off (input, output, cache reads, cache writes, or request count). +3. Endpoints used (for example `/chat/completions`, `/responses`, `/embeddings`). +4. Model names as sent in the request (for example `anthropic.claude-opus-4-5`, `gpt-4o`). + +### For maintainers debugging ingestion + +1. Start the proxy with verbose logging, for example: + ```bash + litellm --config config.yaml --detailed_debug + ``` +2. Reproduce a single request with the reported endpoint and model. +3. Inspect the raw `usage` object in each streamed chunk (if streaming) or in the final response body. +4. Compare that to the standard logging object (or the UI request log for that call). +5. Any gap between raw provider usage and what LiteLLM logs or aggregates is where ingestion may be wrong. + +## Path B: Quantities match but cost is wrong + +If token and request counts agree within ~10% but dollar amounts differ, focus on how cost is computed. + +### B1: Formula issue + +Manually compute expected cost using the provider’s token breakdown and published rates (per million tokens or per token). + +Add other billed dimensions your provider applies (for example cache creation, audio, or tier surcharges). If your hand calculation matches the provider bill but not LiteLLM, the implementation in LiteLLM for that provider or modality may be wrong. + +### B2: Model map issue + +If the formula structure matches how the provider bills, the values in LiteLLM’s model map may be stale or incorrect. Cross-check: + +- [`model_prices_and_context_window.json`](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json) +- The provider’s current public pricing + +Inspect `input_cost_per_token`, `output_cost_per_token`, and any cache-related pricing fields for your exact model id (including provider prefix). + +### For maintainers + +1. Take authoritative token quantities from the user’s provider report. +2. Derive the formula that reproduces the provider’s line item. +3. Diff that against LiteLLM’s cost path for the same provider and response shape. +4. If the formula matches but numbers differ, update pricing in `model_prices_and_context_window.json` (and follow the project’s sync / backup rules for that file). +5. If the formula in code is wrong, fix the calculation and add a regression test using the user’s token breakdown. + +## Still stuck? + +1. Open a GitHub issue on [BerriAI/litellm](https://github.com/BerriAI/litellm) with your Step 3 comparison table, endpoints, and model names. + + +On the issue, it helps to clarify: + +- Reproducible on demand or intermittent? +- Single model or many? +- Steady over time, or starting from a specific release date or config change? + +### For LiteLLM maintainers + +If Path A and Path B do not close the case after triage, **you** should reach out and **schedule a call with the customer** (support or engineering), with the Step 3 table and screenshots—before treating the issue. + +## Checklist + +``` +□ Same time range on both dashboards +□ Confirmed no direct-to-provider traffic for those models +□ Compared: requests, input tokens, output tokens, cache tokens +□ Noted cache reporting differences (OpenAI vs Anthropic, and so on) +□ If > ~10% delta on quantities → Path A: report with screenshots, endpoints, model names +□ If quantities match → Path B: verify formula (B1) and model map pricing (B2) +□ If neither path fits → open a GitHub issue. +``` + +## See also + +- [Spend tracking](../proxy/cost_tracking) +- [Sync model pricing from GitHub](../proxy/sync_models_github) diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 46e392037a..6b97330d40 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -1149,6 +1149,7 @@ const sidebars = { label: "Troubleshooting", items: [ "troubleshoot/ui_issues", + "troubleshoot/cost_discrepancy", "mcp_troubleshoot", { type: "category", diff --git a/docs/my-website/static/img/cost-discrepancy-debug/date-range-picker.png b/docs/my-website/static/img/cost-discrepancy-debug/date-range-picker.png new file mode 100644 index 0000000000..542facc538 Binary files /dev/null and b/docs/my-website/static/img/cost-discrepancy-debug/date-range-picker.png differ diff --git a/docs/my-website/static/img/cost-discrepancy-debug/go-to-model-activity.png b/docs/my-website/static/img/cost-discrepancy-debug/go-to-model-activity.png new file mode 100644 index 0000000000..aae83549d6 Binary files /dev/null and b/docs/my-website/static/img/cost-discrepancy-debug/go-to-model-activity.png differ diff --git a/docs/my-website/static/img/cost-discrepancy-debug/scroll-to-model.png b/docs/my-website/static/img/cost-discrepancy-debug/scroll-to-model.png new file mode 100644 index 0000000000..31693b4083 Binary files /dev/null and b/docs/my-website/static/img/cost-discrepancy-debug/scroll-to-model.png differ diff --git a/docs/my-website/static/img/cost-discrepancy-debug/token-categories.png b/docs/my-website/static/img/cost-discrepancy-debug/token-categories.png new file mode 100644 index 0000000000..bcb843322f Binary files /dev/null and b/docs/my-website/static/img/cost-discrepancy-debug/token-categories.png differ diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260414140000_add_mcp_server_instructions/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260414140000_add_mcp_server_instructions/migration.sql new file mode 100644 index 0000000000..531024c519 --- /dev/null +++ b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260414140000_add_mcp_server_instructions/migration.sql @@ -0,0 +1,2 @@ +-- AlterTable +ALTER TABLE "LiteLLM_MCPServerTable" ADD COLUMN IF NOT EXISTS "instructions" TEXT; diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260415120000_health_check_latest_per_model_index/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260415120000_health_check_latest_per_model_index/migration.sql new file mode 100644 index 0000000000..189191b901 --- /dev/null +++ b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260415120000_health_check_latest_per_model_index/migration.sql @@ -0,0 +1,12 @@ +-- CreateIndex (CONCURRENTLY) +-- +-- Disclaimer: +-- - CREATE INDEX CONCURRENTLY cannot run inside a transaction. This migration must stay a +-- single statement so Prisma Migrate on PostgreSQL can apply it outside a transaction. +-- - Builds are slower and use more I/O than a blocking CREATE INDEX; if the build is +-- interrupted, Postgres may leave an INVALID index that must be dropped and recreated. +-- - Do not edit this file after it has been applied to any database: Prisma checksums +-- migrations; add a new migration instead. +-- - Requires PostgreSQL that supports CONCURRENTLY with IF NOT EXISTS (use a new migration +-- without IF NOT EXISTS if you must support older versions). +CREATE INDEX CONCURRENTLY IF NOT EXISTS "LiteLLM_HealthCheckTable_model_id_model_name_checked_at_idx" ON "LiteLLM_HealthCheckTable"("model_id", "model_name", "checked_at" DESC); diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index fce95465b5..ce3f5f131f 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -289,6 +289,7 @@ model LiteLLM_MCPServerTable { server_name String? alias String? description String? + instructions String? url String? spec_path String? transport String @default("sse") @@ -1045,6 +1046,7 @@ model LiteLLM_HealthCheckTable { @@index([model_name]) @@index([checked_at]) @@index([status]) + @@index([model_id, model_name, checked_at(sort: Desc)], map: "LiteLLM_HealthCheckTable_model_id_model_name_checked_at_idx") } // Search Tools table for storing search tool configurations diff --git a/litellm/constants.py b/litellm/constants.py index d0596bed68..348c405506 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -1,6 +1,6 @@ import os import sys -from typing import List, Literal +from typing import List, Literal, Optional from litellm.litellm_core_utils.env_utils import get_env_int @@ -413,7 +413,20 @@ MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB = int( ) DEFAULT_MAX_TOKENS_FOR_TRITON = int(os.getenv("DEFAULT_MAX_TOKENS_FOR_TRITON", 2000)) #### Networking settings #### -request_timeout: float = float(os.getenv("REQUEST_TIMEOUT", 6000)) # time in seconds +# Sentinel used when `REQUEST_TIMEOUT` is unset: `litellm.request_timeout` keeps this +# value so longer-running surfaces (Router `timeout or litellm.request_timeout`, +# speech/TTS, responses, vector stores, etc.) get a long HTTP deadline. Chat +# `completion()` maps this sentinel down to 600s when the caller did not set a +# per-request/model timeout—see ``CompletionTimeout.resolve`` in completion_timeout.py. MCP uses +# dedicated timeouts (e.g. `MCP_CLIENT_TIMEOUT`), not `request_timeout`. +DEFAULT_REQUEST_TIMEOUT_SECONDS: float = 6000.0 +# Pair used for default httpx clients when no custom timeout is passed: read/write +# deadline and connect handshake (see ``http_handler`` cached handler paths). +COMPLETION_HTTP_FALLBACK_SECONDS: float = 600.0 +HTTP_HANDLER_CONNECT_TIMEOUT_SECONDS: float = 5.0 +request_timeout: float = float( + os.getenv("REQUEST_TIMEOUT", str(int(DEFAULT_REQUEST_TIMEOUT_SECONDS))) +) DEFAULT_A2A_AGENT_TIMEOUT: float = float( os.getenv("DEFAULT_A2A_AGENT_TIMEOUT", 6000) ) # 10 minutes @@ -1330,6 +1343,22 @@ BATCH_STATUS_POLL_MAX_ATTEMPTS = int( HEALTH_CHECK_TIMEOUT_SECONDS = int( os.getenv("HEALTH_CHECK_TIMEOUT_SECONDS", 60) ) # 60 seconds +_background_health_check_max_tokens_env = os.getenv( + "BACKGROUND_HEALTH_CHECK_MAX_TOKENS" +) +try: + _raw_background_health_check_max_tokens = ( + _background_health_check_max_tokens_env.strip() + if _background_health_check_max_tokens_env is not None + else "" + ) + BACKGROUND_HEALTH_CHECK_MAX_TOKENS: Optional[int] = ( + int(_raw_background_health_check_max_tokens) + if _raw_background_health_check_max_tokens + else None + ) +except (ValueError, TypeError): + BACKGROUND_HEALTH_CHECK_MAX_TOKENS = None LITTELM_INTERNAL_HEALTH_SERVICE_ACCOUNT_NAME = "litellm-internal-health-check" LITTELM_CLI_SERVICE_ACCOUNT_NAME = "litellm-cli" LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME = "litellm_internal_jobs" diff --git a/litellm/experimental_mcp_client/client.py b/litellm/experimental_mcp_client/client.py index 1423617cac..e703a3956b 100644 --- a/litellm/experimental_mcp_client/client.py +++ b/litellm/experimental_mcp_client/client.py @@ -221,6 +221,7 @@ class MCPClient: self.extra_headers: Optional[Dict[str, str]] = extra_headers self.ssl_verify: Optional[VerifyTypes] = ssl_verify self._aws_auth: Optional[httpx.Auth] = aws_auth + self._last_initialize_instructions: Optional[str] = None # handle the basic auth value if provided if auth_value: self.update_auth_value(auth_value) @@ -296,7 +297,12 @@ class MCPClient: session_ctx = ClientSession(read_stream, write_stream) session = await session_ctx.__aenter__() try: - await session.initialize() + init_result = await session.initialize() + self._last_initialize_instructions = None + if init_result is not None: + ins = getattr(init_result, "instructions", None) + if isinstance(ins, str) and ins.strip(): + self._last_initialize_instructions = ins.strip() return await operation(session) finally: try: @@ -315,6 +321,7 @@ class MCPClient: """Open a session, run the provided coroutine, and clean up.""" http_client: Optional[httpx.AsyncClient] = None try: + self._last_initialize_instructions = None transport_ctx, http_client = self._create_transport_context() return await self._execute_session_operation(transport_ctx, operation) except Exception: diff --git a/litellm/litellm_core_utils/completion_timeout.py b/litellm/litellm_core_utils/completion_timeout.py new file mode 100644 index 0000000000..5350d88e59 --- /dev/null +++ b/litellm/litellm_core_utils/completion_timeout.py @@ -0,0 +1,83 @@ +"""Completion HTTP timeout resolution (kept out of ``main.py`` to limit import cycles).""" + +from __future__ import annotations + +from typing import Callable, Optional, Union + +import httpx + +from litellm.constants import ( + COMPLETION_HTTP_FALLBACK_SECONDS, + DEFAULT_REQUEST_TIMEOUT_SECONDS, +) + + +class CompletionTimeout: + """Resolves HTTP timeout for ``completion()`` from model vs global settings.""" + + @staticmethod + def _fallback_when_no_explicit_timeout( + global_timeout: Optional[Union[float, str]], + ) -> float: + """ + Used when ``model_timeout`` and kwargs timeouts are all unset. + + ``global_timeout`` is :attr:`litellm.request_timeout` (numeric / string), not + :class:`httpx.Timeout`. + + If it equals :data:`~litellm.constants.DEFAULT_REQUEST_TIMEOUT_SECONDS` (6000), + return :data:`~litellm.constants.COMPLETION_HTTP_FALLBACK_SECONDS`. Same if + ``None``. Otherwise return ``float(global_timeout)``. + """ + if global_timeout is None: + return COMPLETION_HTTP_FALLBACK_SECONDS + if float(global_timeout) == float(DEFAULT_REQUEST_TIMEOUT_SECONDS): + return COMPLETION_HTTP_FALLBACK_SECONDS + return float(global_timeout) + + @staticmethod + def resolve( + model_timeout: Optional[Union[float, str, httpx.Timeout]], + kwargs: dict, + custom_llm_provider: str, + *, + global_timeout: Optional[Union[float, str]], + supports_httpx_timeout: Callable[[str], bool], + ) -> Union[float, httpx.Timeout]: + """ + Resolution order (first non-None wins): + + 1. ``model_timeout`` (call argument / merged ``litellm_params``) + 2. ``kwargs["timeout"]`` + 3. ``kwargs["request_timeout"]`` + 4. Fallback from ``global_timeout`` (:attr:`litellm.request_timeout`) — if it is + the package default (6000), use 600 instead. + + Coerce :class:`httpx.Timeout` when the provider does not support it. + Explicit ``6000`` on the model or in kwargs is kept as ``6000``. + """ + resolved: Union[float, str, httpx.Timeout] + if model_timeout is not None: + resolved = model_timeout + elif kwargs.get("timeout") is not None: + resolved = kwargs["timeout"] + elif kwargs.get("request_timeout") is not None: + resolved = kwargs["request_timeout"] + else: + resolved = CompletionTimeout._fallback_when_no_explicit_timeout( + global_timeout + ) + + if isinstance(resolved, httpx.Timeout) and not supports_httpx_timeout( + custom_llm_provider + ): + read_timeout = resolved.read + resolved = ( + float(read_timeout) + if read_timeout is not None + else COMPLETION_HTTP_FALLBACK_SECONDS + ) # default 10 min timeout + elif not isinstance(resolved, httpx.Timeout): + resolved = float(resolved) # type: ignore + + return resolved diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 87df171d65..fd14f55add 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -354,9 +354,9 @@ class Logging(LiteLLMLoggingBaseClass): ) self.function_id = function_id self.streaming_chunks: List[Any] = [] # for generating complete stream response - self.sync_streaming_chunks: List[ - Any - ] = [] # for generating complete stream response + self.sync_streaming_chunks: List[Any] = ( + [] + ) # for generating complete stream response self.log_raw_request_response = log_raw_request_response # Initialize dynamic callbacks @@ -811,9 +811,9 @@ class Logging(LiteLLMLoggingBaseClass): prompt_spec=prompt_spec, dynamic_callback_params=dynamic_callback_params, ): - self.model_call_details[ - "prompt_integration" - ] = logger.__class__.__name__ + self.model_call_details["prompt_integration"] = ( + logger.__class__.__name__ + ) return logger except Exception: # If check fails, continue to next logger @@ -881,9 +881,9 @@ class Logging(LiteLLMLoggingBaseClass): if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook( non_default_params ): - self.model_call_details[ - "prompt_integration" - ] = anthropic_cache_control_logger.__class__.__name__ + self.model_call_details["prompt_integration"] = ( + anthropic_cache_control_logger.__class__.__name__ + ) return anthropic_cache_control_logger ######################################################### @@ -895,9 +895,9 @@ class Logging(LiteLLMLoggingBaseClass): internal_usage_cache=None, llm_router=None, ) - self.model_call_details[ - "prompt_integration" - ] = vector_store_custom_logger.__class__.__name__ + self.model_call_details["prompt_integration"] = ( + vector_store_custom_logger.__class__.__name__ + ) # Add to global callbacks so post-call hooks are invoked if ( vector_store_custom_logger @@ -957,9 +957,9 @@ class Logging(LiteLLMLoggingBaseClass): model ): # if model name was changes pre-call, overwrite the initial model call name with the new one self.model_call_details["model"] = model - self.model_call_details["litellm_params"][ - "api_base" - ] = self._get_masked_api_base(additional_args.get("api_base", "")) + self.model_call_details["litellm_params"]["api_base"] = ( + self._get_masked_api_base(additional_args.get("api_base", "")) + ) def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915 # Log the exact input to the LLM API @@ -988,10 +988,10 @@ class Logging(LiteLLMLoggingBaseClass): try: # [Non-blocking Extra Debug Information in metadata] if turn_off_message_logging is True: - _metadata[ - "raw_request" - ] = "redacted by litellm. \ + _metadata["raw_request"] = ( + "redacted by litellm. \ 'litellm.turn_off_message_logging=True'" + ) else: curl_command = self._get_request_curl_command( api_base=additional_args.get("api_base", ""), @@ -1002,34 +1002,34 @@ class Logging(LiteLLMLoggingBaseClass): _metadata["raw_request"] = str(curl_command) # split up, so it's easier to parse in the UI - self.model_call_details[ - "raw_request_typed_dict" - ] = RawRequestTypedDict( - raw_request_api_base=str( - additional_args.get("api_base") or "" - ), - raw_request_body=self._get_raw_request_body( - additional_args.get("complete_input_dict", {}) - ), - # NOTE: setting ignore_sensitive_headers to True will cause - # the Authorization header to be leaked when calls to the health - # endpoint are made and fail. - raw_request_headers=self._get_masked_headers( - additional_args.get("headers", {}) or {}, - ), - error=None, + self.model_call_details["raw_request_typed_dict"] = ( + RawRequestTypedDict( + raw_request_api_base=str( + additional_args.get("api_base") or "" + ), + raw_request_body=self._get_raw_request_body( + additional_args.get("complete_input_dict", {}) + ), + # NOTE: setting ignore_sensitive_headers to True will cause + # the Authorization header to be leaked when calls to the health + # endpoint are made and fail. + raw_request_headers=self._get_masked_headers( + additional_args.get("headers", {}) or {}, + ), + error=None, + ) ) except Exception as e: - self.model_call_details[ - "raw_request_typed_dict" - ] = RawRequestTypedDict( - error=str(e), + self.model_call_details["raw_request_typed_dict"] = ( + RawRequestTypedDict( + error=str(e), + ) ) - _metadata[ - "raw_request" - ] = "Unable to Log \ + _metadata["raw_request"] = ( + "Unable to Log \ raw request: {}".format( - str(e) + str(e) + ) ) if getattr(self, "logger_fn", None) and callable(self.logger_fn): try: @@ -1330,13 +1330,13 @@ class Logging(LiteLLMLoggingBaseClass): for callback in callbacks: try: if isinstance(callback, CustomLogger): - response: Optional[ - MCPPostCallResponseObject - ] = await callback.async_post_mcp_tool_call_hook( - kwargs=kwargs, - response_obj=post_mcp_tool_call_response_obj, - start_time=start_time, - end_time=end_time, + response: Optional[MCPPostCallResponseObject] = ( + await callback.async_post_mcp_tool_call_hook( + kwargs=kwargs, + response_obj=post_mcp_tool_call_response_obj, + start_time=start_time, + end_time=end_time, + ) ) ###################################################################### # if any of the callbacks modify the response, use the modified response @@ -1543,9 +1543,9 @@ class Logging(LiteLLMLoggingBaseClass): verbose_logger.debug( f"response_cost_failure_debug_information: {debug_info}" ) - self.model_call_details[ - "response_cost_failure_debug_information" - ] = debug_info + self.model_call_details["response_cost_failure_debug_information"] = ( + debug_info + ) return None try: @@ -1571,9 +1571,9 @@ class Logging(LiteLLMLoggingBaseClass): verbose_logger.debug( f"response_cost_failure_debug_information: {debug_info}" ) - self.model_call_details[ - "response_cost_failure_debug_information" - ] = debug_info + self.model_call_details["response_cost_failure_debug_information"] = ( + debug_info + ) return None @@ -1722,9 +1722,9 @@ class Logging(LiteLLMLoggingBaseClass): self.model_call_details["litellm_params"].setdefault("metadata", {}) if self.model_call_details["litellm_params"]["metadata"] is None: self.model_call_details["litellm_params"]["metadata"] = {} - self.model_call_details["litellm_params"]["metadata"][ - "hidden_params" - ] = getattr(logging_result, "_hidden_params", {}) + self.model_call_details["litellm_params"]["metadata"]["hidden_params"] = ( + getattr(logging_result, "_hidden_params", {}) + ) def _process_hidden_params_and_response_cost( self, @@ -1753,9 +1753,9 @@ class Logging(LiteLLMLoggingBaseClass): result=logging_result ) - self.model_call_details[ - "standard_logging_object" - ] = self._build_standard_logging_payload(logging_result, start_time, end_time) + self.model_call_details["standard_logging_object"] = ( + self._build_standard_logging_payload(logging_result, start_time, end_time) + ) if ( standard_logging_payload := self.model_call_details.get( @@ -1833,9 +1833,9 @@ class Logging(LiteLLMLoggingBaseClass): end_time = datetime.datetime.now() if self.completion_start_time is None: self.completion_start_time = end_time - self.model_call_details[ - "completion_start_time" - ] = self.completion_start_time + self.model_call_details["completion_start_time"] = ( + self.completion_start_time + ) self.model_call_details["log_event_type"] = "successful_api_call" self.model_call_details["end_time"] = end_time @@ -1872,10 +1872,10 @@ class Logging(LiteLLMLoggingBaseClass): end_time=end_time, ) elif isinstance(result, dict) or isinstance(result, list): - self.model_call_details[ - "standard_logging_object" - ] = self._build_standard_logging_payload( - result, start_time, end_time + self.model_call_details["standard_logging_object"] = ( + self._build_standard_logging_payload( + result, start_time, end_time + ) ) if ( standard_logging_payload := self.model_call_details.get( @@ -1884,9 +1884,9 @@ class Logging(LiteLLMLoggingBaseClass): ) is not None: emit_standard_logging_payload(standard_logging_payload) elif standard_logging_object is not None: - self.model_call_details[ - "standard_logging_object" - ] = standard_logging_object + self.model_call_details["standard_logging_object"] = ( + standard_logging_object + ) else: self.model_call_details["response_cost"] = None @@ -2044,20 +2044,20 @@ class Logging(LiteLLMLoggingBaseClass): verbose_logger.debug( "Logging Details LiteLLM-Success Call streaming complete" ) - self.model_call_details[ - "complete_streaming_response" - ] = complete_streaming_response - self.model_call_details[ - "response_cost" - ] = self._response_cost_calculator(result=complete_streaming_response) + self.model_call_details["complete_streaming_response"] = ( + complete_streaming_response + ) + self.model_call_details["response_cost"] = ( + self._response_cost_calculator(result=complete_streaming_response) + ) self._merge_hidden_params_from_response_into_metadata( complete_streaming_response ) ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details[ - "standard_logging_object" - ] = self._build_standard_logging_payload( - complete_streaming_response, start_time, end_time + self.model_call_details["standard_logging_object"] = ( + self._build_standard_logging_payload( + complete_streaming_response, start_time, end_time + ) ) if ( standard_logging_payload := self.model_call_details.get( @@ -2391,10 +2391,10 @@ class Logging(LiteLLMLoggingBaseClass): ) else: if self.stream and complete_streaming_response: - self.model_call_details[ - "complete_response" - ] = self.model_call_details.get( - "complete_streaming_response", {} + self.model_call_details["complete_response"] = ( + self.model_call_details.get( + "complete_streaming_response", {} + ) ) result = self.model_call_details["complete_response"] openMeterLogger.log_success_event( @@ -2418,10 +2418,10 @@ class Logging(LiteLLMLoggingBaseClass): ) else: if self.stream and complete_streaming_response: - self.model_call_details[ - "complete_response" - ] = self.model_call_details.get( - "complete_streaming_response", {} + self.model_call_details["complete_response"] = ( + self.model_call_details.get( + "complete_streaming_response", {} + ) ) result = self.model_call_details["complete_response"] @@ -2560,9 +2560,9 @@ class Logging(LiteLLMLoggingBaseClass): if complete_streaming_response is not None: print_verbose("Async success callbacks: Got a complete streaming response") - self.model_call_details[ - "async_complete_streaming_response" - ] = complete_streaming_response + self.model_call_details["async_complete_streaming_response"] = ( + complete_streaming_response + ) try: if self.model_call_details.get("cache_hit", False) is True: @@ -2573,10 +2573,10 @@ class Logging(LiteLLMLoggingBaseClass): model_call_details=self.model_call_details ) # base_model defaults to None if not set on model_info - self.model_call_details[ - "response_cost" - ] = self._response_cost_calculator( - result=complete_streaming_response + self.model_call_details["response_cost"] = ( + self._response_cost_calculator( + result=complete_streaming_response + ) ) verbose_logger.debug( @@ -2593,10 +2593,10 @@ class Logging(LiteLLMLoggingBaseClass): ) ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details[ - "standard_logging_object" - ] = self._build_standard_logging_payload( - complete_streaming_response, start_time, end_time + self.model_call_details["standard_logging_object"] = ( + self._build_standard_logging_payload( + complete_streaming_response, start_time, end_time + ) ) # print standard logging payload @@ -2623,9 +2623,9 @@ class Logging(LiteLLMLoggingBaseClass): # _success_handler_helper_fn if self.model_call_details.get("standard_logging_object") is None: ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details[ - "standard_logging_object" - ] = self._build_standard_logging_payload(result, start_time, end_time) + self.model_call_details["standard_logging_object"] = ( + self._build_standard_logging_payload(result, start_time, end_time) + ) # print standard logging payload if ( @@ -2872,18 +2872,18 @@ class Logging(LiteLLMLoggingBaseClass): ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details[ - "standard_logging_object" - ] = get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj={}, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="failure", - error_str=_redact_string(str(exception)), - original_exception=exception, - standard_built_in_tools_params=self.standard_built_in_tools_params, + self.model_call_details["standard_logging_object"] = ( + get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj={}, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="failure", + error_str=_redact_string(str(exception)), + original_exception=exception, + standard_built_in_tools_params=self.standard_built_in_tools_params, + ) ) return start_time, end_time @@ -3853,9 +3853,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 service_name=arize_config.project_name, ) - os.environ[ - "OTEL_EXPORTER_OTLP_TRACES_HEADERS" - ] = f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}" + os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( + f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}" + ) for callback in _in_memory_loggers: if ( isinstance(callback, ArizeLogger) @@ -3881,13 +3881,13 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "") # Add openinference.project.name attribute if existing_attrs: - os.environ[ - "OTEL_RESOURCE_ATTRIBUTES" - ] = f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}" + os.environ["OTEL_RESOURCE_ATTRIBUTES"] = ( + f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}" + ) else: - os.environ[ - "OTEL_RESOURCE_ATTRIBUTES" - ] = f"openinference.project.name={arize_phoenix_config.project_name}" + os.environ["OTEL_RESOURCE_ATTRIBUTES"] = ( + f"openinference.project.name={arize_phoenix_config.project_name}" + ) # Set Phoenix project name from environment variable phoenix_project_name = os.environ.get("PHOENIX_PROJECT_NAME", None) @@ -3895,19 +3895,19 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "") # Add openinference.project.name attribute if existing_attrs: - os.environ[ - "OTEL_RESOURCE_ATTRIBUTES" - ] = f"{existing_attrs},openinference.project.name={phoenix_project_name}" + os.environ["OTEL_RESOURCE_ATTRIBUTES"] = ( + f"{existing_attrs},openinference.project.name={phoenix_project_name}" + ) else: - os.environ[ - "OTEL_RESOURCE_ATTRIBUTES" - ] = f"openinference.project.name={phoenix_project_name}" + os.environ["OTEL_RESOURCE_ATTRIBUTES"] = ( + f"openinference.project.name={phoenix_project_name}" + ) # auth can be disabled on local deployments of arize phoenix if arize_phoenix_config.otlp_auth_headers is not None: - os.environ[ - "OTEL_EXPORTER_OTLP_TRACES_HEADERS" - ] = arize_phoenix_config.otlp_auth_headers + os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( + arize_phoenix_config.otlp_auth_headers + ) for callback in _in_memory_loggers: if ( @@ -4094,9 +4094,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 exporter="otlp_http", endpoint="https://langtrace.ai/api/trace", ) - os.environ[ - "OTEL_EXPORTER_OTLP_TRACES_HEADERS" - ] = f"api_key={os.getenv('LANGTRACE_API_KEY')}" + os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( + f"api_key={os.getenv('LANGTRACE_API_KEY')}" + ) for callback in _in_memory_loggers: if ( isinstance(callback, OpenTelemetry) @@ -4991,16 +4991,22 @@ class StandardLoggingPayloadSetup: additional_logging_headers: StandardLoggingAdditionalHeaders = {} + # Populate well-known typed fields with int/str coercion where needed + typed_keys: dict = {} for key in StandardLoggingAdditionalHeaders.__annotations__.keys(): - _key = key.lower() - _key = _key.replace("_", "-") + _key = key.lower().replace("_", "-") + typed_keys[_key] = key if _key in additiona_headers: try: additional_logging_headers[key] = int(additiona_headers[_key]) # type: ignore except (ValueError, TypeError): - verbose_logger.debug( - f"Could not convert {additiona_headers[_key]} to int for key {key}." - ) + additional_logging_headers[key] = additiona_headers[_key] # type: ignore + + # Preserve all remaining headers verbatim (e.g. llm_provider-x-request-id) + for k, v in additiona_headers.items(): + if k.lower() not in typed_keys: + additional_logging_headers[k] = v # type: ignore + return additional_logging_headers @staticmethod @@ -5022,10 +5028,10 @@ class StandardLoggingPayloadSetup: for key in StandardLoggingHiddenParams.__annotations__.keys(): if key in hidden_params: if key == "additional_headers": - clean_hidden_params[ - "additional_headers" - ] = StandardLoggingPayloadSetup.get_additional_headers( - hidden_params[key] + clean_hidden_params["additional_headers"] = ( + StandardLoggingPayloadSetup.get_additional_headers( + hidden_params[key] + ) ) else: clean_hidden_params[key] = hidden_params[key] # type: ignore @@ -5666,9 +5672,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]): ): for k, v in metadata["user_api_key_metadata"].items(): if k == "logging": # prevent logging user logging keys - cleaned_user_api_key_metadata[ - k - ] = "scrubbed_by_litellm_for_sensitive_keys" + cleaned_user_api_key_metadata[k] = ( + "scrubbed_by_litellm_for_sensitive_keys" + ) else: cleaned_user_api_key_metadata[k] = v diff --git a/litellm/litellm_core_utils/model_param_helper.py b/litellm/litellm_core_utils/model_param_helper.py index 66b174feac..b4fa5cb60a 100644 --- a/litellm/litellm_core_utils/model_param_helper.py +++ b/litellm/litellm_core_utils/model_param_helper.py @@ -11,6 +11,10 @@ from openai.types.completion_create_params import ( CompletionCreateParamsStreaming as TextCompletionCreateParamsStreaming, ) from openai.types.embedding_create_params import EmbeddingCreateParams +from openai.types.responses.response_create_params import ( + ResponseCreateParamsNonStreaming, + ResponseCreateParamsStreaming, +) from litellm._logging import verbose_logger from litellm.types.rerank import RerankRequest @@ -65,6 +69,9 @@ class ModelParamHelper: ModelParamHelper._get_litellm_supported_transcription_kwargs() ) rerank_kwargs = ModelParamHelper._get_litellm_supported_rerank_kwargs() + responses_api_kwargs = ( + ModelParamHelper._get_litellm_supported_responses_api_kwargs() + ) exclude_kwargs = ModelParamHelper._get_exclude_kwargs() combined_kwargs = chat_completion_kwargs.union( @@ -72,6 +79,7 @@ class ModelParamHelper: embedding_kwargs, transcription_kwargs, rerank_kwargs, + responses_api_kwargs, ) combined_kwargs = combined_kwargs.difference(exclude_kwargs) return combined_kwargs @@ -93,9 +101,9 @@ class ModelParamHelper: streaming_params: Set[str] = set( getattr(CompletionCreateParamsStreaming, "__annotations__", {}).keys() ) - litellm_provider_specific_params: Set[ - str - ] = ModelParamHelper.get_litellm_provider_specific_params_for_chat_params() + litellm_provider_specific_params: Set[str] = ( + ModelParamHelper.get_litellm_provider_specific_params_for_chat_params() + ) all_chat_completion_kwargs: Set[str] = non_streaming_params.union( streaming_params ).union(litellm_provider_specific_params) @@ -167,6 +175,21 @@ class ModelParamHelper: verbose_logger.debug("Error getting transcription kwargs %s", str(e)) return set() + @staticmethod + def _get_litellm_supported_responses_api_kwargs() -> Set[str]: + """ + Get the litellm supported responses API kwargs + + This follows the OpenAI API Spec + """ + non_streaming_params: Set[str] = set( + getattr(ResponseCreateParamsNonStreaming, "__annotations__", {}).keys() + ) + streaming_params: Set[str] = set( + getattr(ResponseCreateParamsStreaming, "__annotations__", {}).keys() + ) + return non_streaming_params.union(streaming_params) + @staticmethod def _get_exclude_kwargs() -> Set[str]: """ diff --git a/litellm/llms/anthropic/common_utils.py b/litellm/llms/anthropic/common_utils.py index a0da14bcc2..3be5a0c816 100644 --- a/litellm/llms/anthropic/common_utils.py +++ b/litellm/llms/anthropic/common_utils.py @@ -2,6 +2,7 @@ This file contains common utils for anthropic calls. """ +import copy from typing import Any, Dict, List, Optional, Union import httpx @@ -736,6 +737,69 @@ def strip_advisor_blocks_from_messages( return messages +def is_anthropic_invalid_thinking_signature_error(error_text: str) -> bool: + """ + Detect Anthropic 400 when encrypted thinking signatures in history do not match + the current deployment (e.g. user rotated API key or switched model endpoint). + + Example API message: + messages.N.content.M: Invalid `signature` in `thinking` block + """ + if not error_text: + return False + lower = error_text.lower() + return ( + "invalid" in lower + and "signature" in lower + and "thinking" in lower + and "block" in lower + ) + + +def strip_thinking_blocks_from_anthropic_messages(messages: List[Any]) -> List[Any]: + """ + Return a new message list with thinking / redacted_thinking content blocks removed + from each message. Used to recover from invalid thinking signatures on retry. + + Messages whose content is a list and becomes empty after stripping are omitted, + since Anthropic rejects empty content arrays. + """ + out: List[Any] = [] + for m in messages: + if not isinstance(m, dict): + out.append(m) + continue + mm = copy.deepcopy(m) + content = mm.get("content") + if isinstance(content, list): + filtered = [ + b + for b in content + if not ( + isinstance(b, dict) + and b.get("type") in ("thinking", "redacted_thinking") + ) + ] + if not filtered: + continue + mm["content"] = filtered + out.append(mm) + return out + + +def strip_thinking_blocks_from_anthropic_messages_request_dict( + data: Dict[str, Any], +) -> None: + """ + Mutate an Anthropic Messages-style request dict: strip thinking blocks from + ``messages`` and remove the top-level ``thinking`` extended-thinking param. + """ + msgs = data.get("messages") + if isinstance(msgs, list): + data["messages"] = strip_thinking_blocks_from_anthropic_messages(msgs) + data.pop("thinking", None) + + def process_anthropic_headers(headers: Union[httpx.Headers, dict]) -> dict: openai_headers = {} if "anthropic-ratelimit-requests-limit" in headers: diff --git a/litellm/llms/anthropic/experimental_pass_through/messages/streaming_iterator.py b/litellm/llms/anthropic/experimental_pass_through/messages/streaming_iterator.py index 6cab38932a..914855a2bd 100644 --- a/litellm/llms/anthropic/experimental_pass_through/messages/streaming_iterator.py +++ b/litellm/llms/anthropic/experimental_pass_through/messages/streaming_iterator.py @@ -27,6 +27,7 @@ class BaseAnthropicMessagesStreamingIterator: self.litellm_logging_obj = litellm_logging_obj self.request_body = request_body self.start_time = datetime.now() + self.completion_start_time: datetime | None = None async def _handle_streaming_logging(self, collected_chunks: List[bytes]): """Handle the logging after all chunks have been collected.""" @@ -35,6 +36,15 @@ class BaseAnthropicMessagesStreamingIterator: ) end_time = datetime.now() + # Set completion_start_time so TTFT is calculated from the first + # chunk rather than falling back to end_time in async_success_handler. + if self.completion_start_time is not None: + self.litellm_logging_obj.completion_start_time = ( + self.completion_start_time + ) + self.litellm_logging_obj.model_call_details[ + "completion_start_time" + ] = self.completion_start_time asyncio.create_task( PassThroughStreamingHandler._route_streaming_logging_to_handler( litellm_logging_obj=self.litellm_logging_obj, @@ -100,6 +110,8 @@ class BaseAnthropicMessagesStreamingIterator: collected_chunks = [] async for chunk in completion_stream: + if self.completion_start_time is None: + self.completion_start_time = datetime.now() encoded_chunk = self._convert_chunk_to_sse_format(chunk) collected_chunks.append(encoded_chunk) yield encoded_chunk diff --git a/litellm/llms/azure/passthrough/transformation.py b/litellm/llms/azure/passthrough/transformation.py index 4e9de4b314..9b1d95e531 100644 --- a/litellm/llms/azure/passthrough/transformation.py +++ b/litellm/llms/azure/passthrough/transformation.py @@ -1,7 +1,9 @@ from typing import TYPE_CHECKING, List, Optional, Tuple import httpx +from httpx import Response +from litellm.litellm_core_utils.litellm_logging import Logging from litellm.llms.azure.common_utils import BaseAzureLLM from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig from litellm.secret_managers.main import get_secret_str @@ -11,6 +13,8 @@ from litellm.types.router import GenericLiteLLMParams if TYPE_CHECKING: from httpx import URL + from litellm.types.utils import CostResponseTypes + class AzurePassthroughConfig(BasePassthroughConfig): def is_streaming_request(self, endpoint: str, request_data: dict) -> bool: @@ -83,3 +87,36 @@ class AzurePassthroughConfig(BasePassthroughConfig): self, api_key: Optional[str] = None, api_base: Optional[str] = None ) -> List[str]: return super().get_models(api_key, api_base) + + def logging_non_streaming_response( + self, + model: str, + custom_llm_provider: str, + httpx_response: Response, + request_data: dict, + logging_obj: Logging, + endpoint: str, + ) -> Optional["CostResponseTypes"]: + from litellm import encoding + from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig + from litellm.types.utils import ModelResponse + + if "chat/completions" not in endpoint: + return None + + openai_chat_config = OpenAIGPTConfig() + + litellm_model_response: ModelResponse = openai_chat_config.transform_response( + model=model, + messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}], + raw_response=httpx_response, + model_response=ModelResponse(), + logging_obj=logging_obj, + optional_params={}, + litellm_params={}, + api_key="", + request_data=request_data, + encoding=encoding, + ) + + return litellm_model_response diff --git a/litellm/llms/base_llm/anthropic_messages/transformation.py b/litellm/llms/base_llm/anthropic_messages/transformation.py index fdad1633e8..49aa563781 100644 --- a/litellm/llms/base_llm/anthropic_messages/transformation.py +++ b/litellm/llms/base_llm/anthropic_messages/transformation.py @@ -120,3 +120,46 @@ class BaseAnthropicMessagesConfig(ABC): return BaseLLMException( message=error_message, status_code=status_code, headers=headers ) + + @property + def max_retry_on_anthropic_messages_http_error(self) -> int: + """ + Max HTTP attempts for /v1/messages when the handler may mutate the body and + retry (e.g. strip invalid encrypted thinking signatures after a deployment or + credential change). + """ + return 2 + + def should_retry_anthropic_messages_on_http_error( + self, e: httpx.HTTPStatusError, litellm_params: dict + ) -> bool: + """ + When True, async_anthropic_messages_handler will transform the request body + and issue one more attempt (bounded by max_retry_on_anthropic_messages_http_error). + """ + from litellm.llms.anthropic.common_utils import ( + is_anthropic_invalid_thinking_signature_error, + ) + + return ( + e.response.status_code == 400 + and is_anthropic_invalid_thinking_signature_error(e.response.text) + ) + + def transform_anthropic_messages_request_on_http_error( + self, e: httpx.HTTPStatusError, request_data: dict + ) -> dict: + """ + Mutates request_data in place when retrying after a recoverable HTTP error. + """ + from litellm.llms.anthropic.common_utils import ( + is_anthropic_invalid_thinking_signature_error, + strip_thinking_blocks_from_anthropic_messages_request_dict, + ) + + if ( + e.response.status_code == 400 + and is_anthropic_invalid_thinking_signature_error(e.response.text) + ): + strip_thinking_blocks_from_anthropic_messages_request_dict(request_data) + return request_data diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py index 4e71c9584a..5abd74df87 100644 --- a/litellm/llms/bedrock/chat/converse_transformation.py +++ b/litellm/llms/bedrock/chat/converse_transformation.py @@ -1003,7 +1003,7 @@ class AmazonConverseConfig(BaseConfig): description=description, ) optional_params["outputConfig"] = output_config - else: + elif json_schema is not None: # Fallback: translate to a synthetic tool call # https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode _tool = self._create_json_tool_call_for_response_format( @@ -1025,6 +1025,12 @@ class AmazonConverseConfig(BaseConfig): ) if non_default_params.get("stream", False) is True: optional_params["fake_stream"] = True + # else: response_format=json_object with no schema. + # Don't inject the synthetic json_tool_call tool here. When no + # schema is given, _create_json_tool_call_for_response_format + # produces an empty schema (properties: {}), and the model + # returns {} instead of the requested JSON. The model already + # returns JSON when the prompt asks for it. optional_params["json_mode"] = True return optional_params @@ -2030,6 +2036,12 @@ class AmazonConverseConfig(BaseConfig): _message = Message(**chat_completion_message) initial_finish_reason = map_finish_reason(completion_response["stopReason"]) + # When json_mode filtered out all synthetic tool calls the response + # is plain content, not a pending tool invocation. Fix finish_reason + # so callers (e.g. OpenAI SDK) don't misinterpret it. + if json_mode and not filtered_tools and tools: + initial_finish_reason = "stop" + ( returned_message, returned_finish_reason, diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index 47b0699096..489a56daf8 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -30,7 +30,9 @@ from litellm.constants import ( AIOHTTP_KEEPALIVE_TIMEOUT, AIOHTTP_NEEDS_CLEANUP_CLOSED, AIOHTTP_TTL_DNS_CACHE, + COMPLETION_HTTP_FALLBACK_SECONDS, DEFAULT_SSL_CIPHERS, + HTTP_HANDLER_CONNECT_TIMEOUT_SECONDS, ) from litellm.litellm_core_utils.logging_utils import track_llm_api_timing from litellm.types.llms.custom_http import * @@ -70,7 +72,10 @@ def get_default_headers() -> dict: headers = get_default_headers() # https://www.python-httpx.org/advanced/timeouts -_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0) +_DEFAULT_TIMEOUT = httpx.Timeout( + timeout=COMPLETION_HTTP_FALLBACK_SECONDS, + connect=HTTP_HANDLER_CONNECT_TIMEOUT_SECONDS, +) def _prepare_request_data_and_content( @@ -1258,7 +1263,7 @@ def get_async_httpx_client( _new_client = AsyncHTTPHandler(**handler_params) else: _new_client = AsyncHTTPHandler( - timeout=httpx.Timeout(timeout=600.0, connect=5.0), + timeout=_DEFAULT_TIMEOUT, shared_session=shared_session, ) @@ -1307,7 +1312,7 @@ def _get_httpx_client(params: Optional[dict] = None) -> HTTPHandler: } _new_client = HTTPHandler(**handler_params) else: - _new_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) + _new_client = HTTPHandler(timeout=_DEFAULT_TIMEOUT) cache.set_cache( key=_cache_key_name, diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index d8ef7e7440..ea0c05e765 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -1816,6 +1816,73 @@ class BaseLLMHTTPHandler: logging_obj=logging_obj, ) + async def _async_post_anthropic_messages_with_http_error_retry( + self, + async_httpx_client: AsyncHTTPHandler, + request_url: str, + headers: dict, + signed_json_body: Optional[bytes], + request_body: dict, + stream: bool, + logging_obj: LiteLLMLoggingObj, + provider_config: BaseAnthropicMessagesConfig, + litellm_params: GenericLiteLLMParams, + api_key: Optional[str], + model: str, + ) -> httpx.Response: + max_attempts = max( + provider_config.max_retry_on_anthropic_messages_http_error, 1 + ) + litellm_params_dict = dict(litellm_params) + optional_params_dict = dict(litellm_params) + for attempt_idx in range(max_attempts): + try: + response = await async_httpx_client.post( + url=request_url, + headers=headers, + data=signed_json_body or json.dumps(request_body), + stream=stream or False, + logging_obj=logging_obj, + ) + response.raise_for_status() + return response + except httpx.HTTPStatusError as e: + hit_max_attempt = attempt_idx + 1 == max_attempts + should_retry = ( + provider_config.should_retry_anthropic_messages_on_http_error( + e=e, litellm_params=litellm_params_dict + ) + ) + if should_retry and not hit_max_attempt: + verbose_logger.debug( + "Anthropic /v1/messages: invalid thinking signature; " + "stripping thinking blocks and retrying (attempt %s/%s).", + attempt_idx + 2, + max_attempts, + ) + provider_config.transform_anthropic_messages_request_on_http_error( + e=e, request_data=request_body + ) + headers, signed_json_body = provider_config.sign_request( + headers=headers, + optional_params=optional_params_dict, + request_data=request_body, + api_base=request_url, + api_key=api_key, + stream=stream, + fake_stream=False, + model=model, + ) + logging_obj.model_call_details.update(request_body) + continue + raise self._handle_error(e=e, provider_config=provider_config) + except Exception as e: + raise self._handle_error(e=e, provider_config=provider_config) + + raise RuntimeError( + "unreachable: anthropic messages HTTP retry loop exited without return" + ) + async def async_anthropic_messages_handler( self, model: str, @@ -1955,19 +2022,19 @@ class BaseLLMHTTPHandler: }, ) - try: - response = await async_httpx_client.post( - url=request_url, - headers=headers, - data=signed_json_body or json.dumps(request_body), - stream=stream or False, - logging_obj=logging_obj, - ) - response.raise_for_status() - except Exception as e: - raise self._handle_error( - e=e, provider_config=anthropic_messages_provider_config - ) + response = await self._async_post_anthropic_messages_with_http_error_retry( + async_httpx_client=async_httpx_client, + request_url=request_url, + headers=headers, + signed_json_body=signed_json_body, + request_body=request_body, + stream=stream or False, + logging_obj=logging_obj, + provider_config=anthropic_messages_provider_config, + litellm_params=litellm_params, + api_key=api_key, + model=model, + ) # used for logging + cost tracking logging_obj.model_call_details["httpx_response"] = response @@ -4496,9 +4563,9 @@ class BaseLLMHTTPHandler: # Second: Execute agentic loop # Add custom_llm_provider to kwargs so the agentic loop can reconstruct the full model name kwargs_with_provider = kwargs.copy() if kwargs else {} - kwargs_with_provider[ - "custom_llm_provider" - ] = custom_llm_provider + kwargs_with_provider["custom_llm_provider"] = ( + custom_llm_provider + ) agentic_response = await callback.async_run_agentic_loop( tools=tool_calls, model=model, @@ -4614,9 +4681,9 @@ class BaseLLMHTTPHandler: # Second: Execute agentic loop # Add custom_llm_provider to kwargs so the agentic loop can reconstruct the full model name kwargs_with_provider = kwargs.copy() if kwargs else {} - kwargs_with_provider[ - "custom_llm_provider" - ] = custom_llm_provider + kwargs_with_provider["custom_llm_provider"] = ( + custom_llm_provider + ) agentic_response = ( await callback.async_run_chat_completion_agentic_loop( tools=tool_calls, @@ -5110,7 +5177,10 @@ class BaseLLMHTTPHandler: _is_async: bool = False, fake_stream: bool = False, litellm_metadata: Optional[Dict[str, Any]] = None, - ) -> Union[ImageResponse, Coroutine[Any, Any, ImageResponse],]: + ) -> Union[ + ImageResponse, + Coroutine[Any, Any, ImageResponse], + ]: """ Handles image edit requests. @@ -5322,7 +5392,10 @@ class BaseLLMHTTPHandler: fake_stream: bool = False, litellm_metadata: Optional[Dict[str, Any]] = None, api_key: Optional[str] = None, - ) -> Union[ImageResponse, Coroutine[Any, Any, ImageResponse],]: + ) -> Union[ + ImageResponse, + Coroutine[Any, Any, ImageResponse], + ]: """ Handles image generation requests. When _is_async=True, returns a coroutine instead of making the call directly. @@ -5562,7 +5635,10 @@ class BaseLLMHTTPHandler: fake_stream: bool = False, litellm_metadata: Optional[Dict[str, Any]] = None, api_key: Optional[str] = None, - ) -> Union[VideoObject, Coroutine[Any, Any, VideoObject],]: + ) -> Union[ + VideoObject, + Coroutine[Any, Any, VideoObject], + ]: """ Handles video generation requests. When _is_async=True, returns a coroutine instead of making the call directly. diff --git a/litellm/llms/ollama/chat/transformation.py b/litellm/llms/ollama/chat/transformation.py index 3d9618dfed..c990cc2e09 100644 --- a/litellm/llms/ollama/chat/transformation.py +++ b/litellm/llms/ollama/chat/transformation.py @@ -16,6 +16,7 @@ from httpx._models import Headers, Response from pydantic import BaseModel import litellm +from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.litellm_core_utils.prompt_templates.common_utils import ( _extract_reasoning_content, convert_content_list_to_str, @@ -349,7 +350,8 @@ class OllamaChatConfig(BaseConfig): response_json = raw_response.json() ## RESPONSE OBJECT - model_response.choices[0].finish_reason = "stop" + _done_reason = map_finish_reason(response_json.get("done_reason") or "stop") + model_response.choices[0].finish_reason = _done_reason response_json_message = response_json.get("message") if response_json_message is not None: if "thinking" in response_json_message: @@ -535,7 +537,7 @@ class OllamaChatCompletionResponseIterator(BaseModelResponseIterator): ) if chunk["done"] is True: - finish_reason = chunk.get("done_reason", "stop") + finish_reason = chunk.get("done_reason") or "stop" # Override finish_reason when tool_calls are present # Fixes: https://github.com/BerriAI/litellm/issues/18922 if tool_calls is not None: diff --git a/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py b/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py index 5d94cd4212..ddef3810bf 100644 --- a/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py +++ b/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py @@ -78,6 +78,19 @@ class VertexAIPartnerModelsTokenCounter(VertexBase): return endpoint + @staticmethod + def _strip_version_suffix(model: str) -> str: + """ + Strip version suffixes (e.g. @default, @20251001) from model names. + + The Vertex AI count-tokens endpoint rejects model names that include + version suffixes — for example, "claude-sonnet-4-6@default" returns + "not supported for token counting" while "claude-sonnet-4-6" works. + """ + if "@" in model: + return model.split("@")[0] + return model + async def handle_count_tokens_request( self, model: str, @@ -98,6 +111,15 @@ class VertexAIPartnerModelsTokenCounter(VertexBase): Raises: ValueError: If required parameters are missing or invalid """ + # Strip version suffixes (@default, @20251001, etc.) — the Vertex AI + # count-tokens endpoint does not accept versioned model names. + model = self._strip_version_suffix(model) + if "model" in request_data: + request_data = { + **request_data, + "model": self._strip_version_suffix(request_data["model"]), + } + # Validate request if "messages" not in request_data: raise ValueError("messages required for token counting") diff --git a/litellm/main.py b/litellm/main.py index 22dffc0bbe..73db4a11cb 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -77,6 +77,7 @@ from litellm.litellm_core_utils.audio_utils.utils import ( calculate_request_duration, get_audio_file_for_health_check, ) +from litellm.litellm_core_utils.completion_timeout import CompletionTimeout from litellm.litellm_core_utils.dd_tracing import tracer from litellm.litellm_core_utils.get_provider_specific_headers import ( ProviderSpecificHeaderUtils, @@ -1401,14 +1402,13 @@ def completion( # type: ignore # noqa: PLR0915 ) # support region-based pricing for bedrock ### TIMEOUT LOGIC ### - timeout = timeout or kwargs.get("request_timeout", 600) or 600 - # set timeout for 10 minutes by default - if isinstance(timeout, httpx.Timeout) and not supports_httpx_timeout( - custom_llm_provider - ): - timeout = timeout.read or 600 # default 10 min timeout - elif not isinstance(timeout, httpx.Timeout): - timeout = float(timeout) # type: ignore + timeout = CompletionTimeout.resolve( + timeout, + kwargs, + custom_llm_provider, + global_timeout=getattr(litellm, "request_timeout", None), + supports_httpx_timeout=supports_httpx_timeout, + ) ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ### if ( diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 2000e4e306..a1d727e1f1 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -25179,6 +25179,58 @@ "supports_web_search": true, "tpm": 800000 }, + "openrouter/google/gemini-3.1-flash-lite-preview": { + "cache_read_input_token_cost": 2.5e-08, + "cache_read_input_token_cost_per_audio_token": 5e-08, + "input_cost_per_audio_token": 5e-07, + "input_cost_per_token": 2.5e-07, + "litellm_provider": "openrouter", + "max_audio_length_hours": 8.4, + "max_audio_per_prompt": 1, + "max_images_per_prompt": 3000, + "max_input_tokens": 1048576, + "max_output_tokens": 65536, + "max_pdf_size_mb": 30, + "max_tokens": 65536, + "max_video_length": 1, + "max_videos_per_prompt": 10, + "mode": "chat", + "output_cost_per_reasoning_token": 1.5e-06, + "output_cost_per_token": 1.5e-06, + "rpm": 2000, + "source": "https://ai.google.dev/pricing/gemini-3", + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/completions", + "/v1/batch" + ], + "supported_modalities": [ + "text", + "image", + "audio", + "video" + ], + "supported_output_modalities": [ + "text" + ], + "supports_audio_input": true, + "supports_audio_output": false, + "supports_code_execution": true, + "supports_file_search": true, + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_url_context": true, + "supports_video_input": true, + "supports_vision": true, + "supports_web_search": true, + "tpm": 800000 + }, "openrouter/google/gemini-3.1-pro-preview": { "cache_read_input_token_cost": 2e-07, "cache_read_input_token_cost_above_200k_tokens": 4e-07, diff --git a/litellm/proxy/_experimental/mcp_server/mcp_context.py b/litellm/proxy/_experimental/mcp_server/mcp_context.py index 12830db1d6..a60138dd34 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_context.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_context.py @@ -14,3 +14,8 @@ from typing import Optional _mcp_active_toolset_id: ContextVar[Optional[str]] = ContextVar( "_mcp_active_toolset_id", default=None ) + +# Per-request merged InitializeResult.instructions; set in MCP HTTP/SSE handlers. +_mcp_gateway_initialize_instructions: ContextVar[Optional[str]] = ContextVar( + "_mcp_gateway_initialize_instructions", default=None +) diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index c50bfe3ab1..8aec30bfc3 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -184,6 +184,16 @@ class MCPServerManager: "gmail_send_email": "zapier_mcp_server", } """ + self._upstream_initialize_instructions_by_server_id: Dict[str, str] = {} + + def _remember_upstream_initialize_instructions( + self, server: MCPServer, client: MCPClient + ) -> None: + raw = getattr(client, "_last_initialize_instructions", None) + if raw and str(raw).strip(): + self._upstream_initialize_instructions_by_server_id[server.server_id] = str( + raw + ).strip() def get_registry(self) -> Dict[str, MCPServer]: """ @@ -204,6 +214,7 @@ class MCPServerManager: mcp_aliases: Optional dictionary mapping aliases to server names from litellm_settings """ verbose_logger.debug("Loading MCP Servers from config-----") + self._upstream_initialize_instructions_by_server_id.clear() # Track which aliases have been used to ensure only first occurrence is used used_aliases = set() @@ -351,6 +362,7 @@ class MCPServerManager: aws_service_name=server_config.get("aws_service_name", None), aws_role_name=server_config.get("aws_role_name", None), aws_session_name=server_config.get("aws_session_name", None), + instructions=server_config.get("instructions", None), ) self.config_mcp_servers[server_id] = new_server @@ -693,6 +705,7 @@ class MCPServerManager: aws_service_name=aws_creds.get("aws_service_name"), aws_role_name=aws_creds.get("aws_role_name"), aws_session_name=aws_creds.get("aws_session_name"), + instructions=mcp_server.instructions, ) return new_server @@ -1247,6 +1260,7 @@ class MCPServerManager: return tools else: tools = await self._fetch_tools_with_timeout(client, server.name) + self._remember_upstream_initialize_instructions(server, client) prefixed_or_original_tools = self._create_prefixed_tools( tools, server, add_prefix=add_prefix @@ -2383,6 +2397,7 @@ class MCPServerManager: # If proxy_logging_obj is not None, the tool call result is at index 1 (after the during hook task) result_index = 1 if proxy_logging_obj else 0 result = mcp_responses[result_index] + self._remember_upstream_initialize_instructions(mcp_server, client) return cast(CallToolResult, result) @@ -2627,6 +2642,7 @@ class MCPServerManager: ) verbose_logger.debug("Loading MCP servers from database into registry...") + self._upstream_initialize_instructions_by_server_id.clear() # perform authz check to filter the mcp servers user has access to prisma_client = get_prisma_client_or_throw( @@ -2910,6 +2926,7 @@ class MCPServerManager: await asyncio.wait_for( client.run_with_session(_noop), timeout=MCP_HEALTH_CHECK_TIMEOUT ) + self._remember_upstream_initialize_instructions(server, client) status = "healthy" except asyncio.TimeoutError: health_check_error = ( @@ -2951,6 +2968,7 @@ class MCPServerManager: token_url=server.token_url, registration_url=server.registration_url, allow_all_keys=server.allow_all_keys, + instructions=server.instructions, ) async def get_all_mcp_servers_with_health_and_teams( @@ -3046,6 +3064,7 @@ class MCPServerManager: is_byok=server.is_byok, byok_description=server.byok_description, byok_api_key_help_url=server.byok_api_key_help_url, + instructions=server.instructions, ) async def get_all_mcp_servers_unfiltered(self) -> List[LiteLLM_MCPServerTable]: diff --git a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py index 32560a2211..8131c04013 100644 --- a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py @@ -933,6 +933,7 @@ if MCP_AVAILABLE: authorization_url=request.authorization_url, registration_url=request.registration_url, oauth2_flow=_oauth2_flow, + instructions=request.instructions, ) stdio_env = global_mcp_server_manager._build_stdio_env( diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index 99578d006e..cefe2ef763 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -7,6 +7,7 @@ LiteLLM MCP Server Routes import asyncio import contextlib import time +import types import traceback import uuid from datetime import datetime @@ -37,7 +38,10 @@ from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import ( from litellm.proxy._experimental.mcp_server.discoverable_endpoints import ( get_request_base_url, ) -from litellm.proxy._experimental.mcp_server.mcp_context import _mcp_active_toolset_id +from litellm.proxy._experimental.mcp_server.mcp_context import ( + _mcp_active_toolset_id, + _mcp_gateway_initialize_instructions, +) from litellm.proxy._experimental.mcp_server.mcp_debug import MCPDebug from litellm.proxy._experimental.mcp_server.utils import ( LITELLM_MCP_SERVER_DESCRIPTION, @@ -122,6 +126,8 @@ _INITIALIZATION_LOCK = asyncio.Lock() if MCP_AVAILABLE: from mcp.server import Server + from mcp.server.lowlevel.server import NotificationOptions + from mcp.server.models import InitializationOptions # Import auth context variables and middleware from mcp.server.auth.middleware.auth_context import ( @@ -200,6 +206,21 @@ if MCP_AVAILABLE: ) return normalized + def _gateway_create_initialization_options( + self, + notification_options: Optional[NotificationOptions] = None, + experimental_capabilities: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> InitializationOptions: + opts = Server.create_initialization_options( + self, + notification_options=notification_options, + experimental_capabilities=experimental_capabilities or {}, + ) + merged = _mcp_gateway_initialize_instructions.get() + if merged is not None: + return opts.model_copy(update={"instructions": merged}) + return opts + ######################################################## ############ Initialize the MCP Server ################# ######################################################## @@ -207,6 +228,9 @@ if MCP_AVAILABLE: name=LITELLM_MCP_SERVER_NAME, version=LITELLM_MCP_SERVER_VERSION, ) + server.create_initialization_options = types.MethodType( # type: ignore[method-assign] + _gateway_create_initialization_options, server + ) sse: SseServerTransport = SseServerTransport("/mcp/sse/messages") # Create session managers @@ -1021,7 +1045,9 @@ if MCP_AVAILABLE: except (ValueError, TypeError): pass ttl = _compute_per_user_token_ttl(server, raw_expires) - await mcp_per_user_token_cache.set(user_id, server_id, access_token, ttl) + await mcp_per_user_token_cache.set( + user_id, server_id, access_token, ttl + ) return {"Authorization": f"Bearer {access_token}"} except Exception as e: @@ -1103,6 +1129,57 @@ if MCP_AVAILABLE: return server_auth_header, extra_headers + def _merge_gateway_initialize_instructions( + allowed_mcp_servers: List[MCPServer], + ) -> Optional[str]: + """YAML/DB override, else in-memory upstream text from list_tools / health_check / call_tool.""" + if not allowed_mcp_servers: + return None + + texts: List[Tuple[str, str]] = [] + for server in allowed_mcp_servers: + label = ( + server.alias + or server.server_name + or server.name + or server.server_id + or "mcp" + ) + if server.instructions and server.instructions.strip(): + texts.append((label, server.instructions.strip())) + continue + if server.spec_path: + continue + cached = global_mcp_server_manager._upstream_initialize_instructions_by_server_id.get( + server.server_id + ) + if cached and cached.strip(): + texts.append((label, cached.strip())) + + if not texts: + return None + if len(texts) == 1: + return texts[0][1] + return "\n\n---\n\n".join(f"[{lbl}]\n{txt}" for lbl, txt in texts) + + @contextlib.asynccontextmanager + async def _gateway_initialize_instructions_request_scope( + user_api_key_auth: Optional[UserAPIKeyAuth], + mcp_servers: Optional[List[str]], + client_ip: Optional[str], + ) -> AsyncIterator[None]: + allowed = await _get_allowed_mcp_servers( + user_api_key_auth=user_api_key_auth, + mcp_servers=mcp_servers, + client_ip=client_ip, + ) + merged = _merge_gateway_initialize_instructions(allowed_mcp_servers=allowed) + tok = _mcp_gateway_initialize_instructions.set(merged) + try: + yield + finally: + _mcp_gateway_initialize_instructions.reset(tok) + async def _get_tools_from_mcp_servers( # noqa: PLR0915 user_api_key_auth: Optional[UserAPIKeyAuth], mcp_auth_header: Optional[str], @@ -2670,7 +2747,12 @@ if MCP_AVAILABLE: # Request was fully handled (e.g., DELETE on non-existent session) return - await session_manager.handle_request(scope, receive, send) + async with _gateway_initialize_instructions_request_scope( + user_api_key_auth, + mcp_servers, + _client_ip, + ): + await session_manager.handle_request(scope, receive, send) except HTTPException: # Re-raise HTTP exceptions to preserve status codes and details raise @@ -2729,7 +2811,12 @@ if MCP_AVAILABLE: await initialize_session_managers() await asyncio.sleep(0.1) - await sse_session_manager.handle_request(scope, receive, send) + async with _gateway_initialize_instructions_request_scope( + user_api_key_auth, + mcp_servers, + _sse_client_ip, + ): + await sse_session_manager.handle_request(scope, receive, send) except Exception as e: verbose_logger.exception(f"Error handling MCP request: {e}") # Instead of re-raising, try to send a graceful error response diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 0bbee56d5e..7fff640c49 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -904,9 +904,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase): allowed_cache_controls: Optional[list] = [] config: Optional[dict] = {} permissions: Optional[dict] = {} - model_max_budget: Optional[ - dict - ] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} + model_max_budget: Optional[dict] = ( + {} + ) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} model_config = ConfigDict(protected_namespaces=()) model_rpm_limit: Optional[dict] = None @@ -1048,9 +1048,9 @@ class RegenerateKeyRequest(GenerateKeyRequest): spend: Optional[float] = None metadata: Optional[dict] = None new_master_key: Optional[str] = None - grace_period: Optional[ - str - ] = None # Duration to keep old key valid (e.g. "24h", "2d"); None = immediate revoke + grace_period: Optional[str] = ( + None # Duration to keep old key valid (e.g. "24h", "2d"); None = immediate revoke + ) class ResetSpendRequest(LiteLLMPydanticObjectBase): @@ -1137,6 +1137,7 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase): tool_name_to_description: Optional[Dict[str, str]] = None extra_headers: Optional[List[str]] = None static_headers: Optional[Dict[str, str]] = None + instructions: Optional[str] = None # Stdio-specific fields command: Optional[str] = None args: List[str] = Field(default_factory=list) @@ -1219,6 +1220,7 @@ class UpdateMCPServerRequest(LiteLLMPydanticObjectBase): tool_name_to_description: Optional[Dict[str, str]] = None extra_headers: Optional[List[str]] = None static_headers: Optional[Dict[str, str]] = None + instructions: Optional[str] = None # Stdio-specific fields command: Optional[str] = None args: List[str] = Field(default_factory=list) @@ -1270,6 +1272,7 @@ class LiteLLM_MCPServerTable(LiteLLMPydanticObjectBase): transport: MCPTransportType auth_type: Optional[MCPAuthType] = None credentials: Optional[MCPCredentials] = None + instructions: Optional[str] = None created_at: Optional[datetime] = None created_by: Optional[str] = None updated_at: Optional[datetime] = None @@ -1574,12 +1577,12 @@ class NewCustomerRequest(BudgetNewRequest): blocked: bool = False # allow/disallow requests for this end-user budget_id: Optional[str] = None # give either a budget_id or max_budget spend: Optional[float] = None - allowed_model_region: Optional[ - AllowedModelRegion - ] = None # require all user requests to use models in this specific region - default_model: Optional[ - str - ] = None # if no equivalent model in allowed region - default all requests to this model + allowed_model_region: Optional[AllowedModelRegion] = ( + None # require all user requests to use models in this specific region + ) + default_model: Optional[str] = ( + None # if no equivalent model in allowed region - default all requests to this model + ) object_permission: Optional[LiteLLM_ObjectPermissionBase] = None @model_validator(mode="before") @@ -1602,12 +1605,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase): blocked: bool = False # allow/disallow requests for this end-user max_budget: Optional[float] = None budget_id: Optional[str] = None # give either a budget_id or max_budget - allowed_model_region: Optional[ - AllowedModelRegion - ] = None # require all user requests to use models in this specific region - default_model: Optional[ - str - ] = None # if no equivalent model in allowed region - default all requests to this model + allowed_model_region: Optional[AllowedModelRegion] = ( + None # require all user requests to use models in this specific region + ) + default_model: Optional[str] = ( + None # if no equivalent model in allowed region - default all requests to this model + ) object_permission: Optional[LiteLLM_ObjectPermissionBase] = None @@ -1697,15 +1700,15 @@ class NewTeamRequest(TeamBase): ] = None # raise an error if 'guaranteed_throughput' is set and we're overallocating tpm model_tpm_limit: Optional[Dict[str, int]] = None - team_member_budget: Optional[ - float - ] = None # allow user to set a budget for all team members - team_member_rpm_limit: Optional[ - int - ] = None # allow user to set RPM limit for all team members - team_member_tpm_limit: Optional[ - int - ] = None # allow user to set TPM limit for all team members + team_member_budget: Optional[float] = ( + None # allow user to set a budget for all team members + ) + team_member_rpm_limit: Optional[int] = ( + None # allow user to set RPM limit for all team members + ) + team_member_tpm_limit: Optional[int] = ( + None # allow user to set TPM limit for all team members + ) team_member_key_duration: Optional[str] = None # e.g. "1d", "1w", "1m" team_member_budget_duration: Optional[str] = None # e.g. "30d", "1mo" allowed_vector_store_indexes: Optional[List[AllowedVectorStoreIndexItem]] = None @@ -1802,9 +1805,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase): class AddTeamCallback(LiteLLMPydanticObjectBase): callback_name: str - callback_type: Optional[ - Literal["success", "failure", "success_and_failure"] - ] = "success_and_failure" + callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = ( + "success_and_failure" + ) callback_vars: Dict[str, str] @model_validator(mode="before") @@ -2146,9 +2149,9 @@ class ConfigList(LiteLLMPydanticObjectBase): stored_in_db: Optional[bool] field_default_value: Any premium_field: bool = False - nested_fields: Optional[ - List[FieldDetail] - ] = None # For nested dictionary or Pydantic fields + nested_fields: Optional[List[FieldDetail]] = ( + None # For nested dictionary or Pydantic fields + ) class UserHeaderMapping(LiteLLMPydanticObjectBase): @@ -2507,9 +2510,9 @@ class UserAPIKeyAuth( user_max_budget: Optional[float] = None request_route: Optional[str] = None user: Optional[Any] = None # Expanded user object when expand=user is used - created_by_user: Optional[ - Any - ] = None # Expanded created_by user when expand=user is used + created_by_user: Optional[Any] = ( + None # Expanded created_by user when expand=user is used + ) end_user_object_permission: Optional[LiteLLM_ObjectPermissionTable] = None # Decoded upstream IdP claims (groups, roles, etc.) propagated by JWT auth machinery # and forwarded into outbound tokens by guardrails such as MCPJWTSigner. @@ -2648,9 +2651,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase): budget_id: Optional[str] = None created_at: datetime updated_at: datetime - user: Optional[ - Any - ] = None # You might want to replace 'Any' with a more specific type if available + user: Optional[Any] = ( + None # You might want to replace 'Any' with a more specific type if available + ) litellm_budget_table: Optional[LiteLLM_BudgetTable] = None user_email: Optional[str] = None @@ -3805,9 +3808,9 @@ class TeamModelDeleteRequest(BaseModel): # Organization Member Requests class OrganizationMemberAddRequest(OrgMemberAddRequest): organization_id: str - max_budget_in_organization: Optional[ - float - ] = None # Users max budget within the organization + max_budget_in_organization: Optional[float] = ( + None # Users max budget within the organization + ) class OrganizationMemberDeleteRequest(MemberDeleteRequest): @@ -4062,9 +4065,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase): Maps provider names to their budget configs. """ - providers: Dict[ - str, ProviderBudgetResponseObject - ] = {} # Dictionary mapping provider names to their budget configurations + providers: Dict[str, ProviderBudgetResponseObject] = ( + {} + ) # Dictionary mapping provider names to their budget configurations class ProxyStateVariables(TypedDict): @@ -4226,9 +4229,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): enforce_rbac: bool = False roles_jwt_field: Optional[str] = None # v2 on role mappings role_mappings: Optional[List[RoleMapping]] = None - object_id_jwt_field: Optional[ - str - ] = None # can be either user / team, inferred from the role mapping + object_id_jwt_field: Optional[str] = ( + None # can be either user / team, inferred from the role mapping + ) scope_mappings: Optional[List[ScopeMapping]] = None enforce_scope_based_access: bool = False enforce_team_based_model_access: bool = False diff --git a/litellm/proxy/credential_endpoints/endpoints.py b/litellm/proxy/credential_endpoints/endpoints.py index 64f860fc4f..4bf6219db4 100644 --- a/litellm/proxy/credential_endpoints/endpoints.py +++ b/litellm/proxy/credential_endpoints/endpoints.py @@ -340,6 +340,35 @@ async def update_credential( "updated_by": user_api_key_dict.user_id, }, ) + + # Sync in-memory credential_list (skip if not in memory - e.g., proxy restarted) + new_name = merged_credential.credential_name + existing_in_memory: Optional[CredentialItem] = None + for cred in litellm.credential_list: + if cred.credential_name == credential_name: + existing_in_memory = cred + break + + if existing_in_memory is not None: + in_memory_values = dict(existing_in_memory.credential_values or {}) + if credential.credential_values: + in_memory_values.update(credential.credential_values) + in_memory_info = dict(existing_in_memory.credential_info or {}) + if credential.credential_info: + in_memory_info.update(credential.credential_info) + updated_in_memory = CredentialItem( + credential_name=new_name, + credential_values=in_memory_values, + credential_info=in_memory_info, + ) + # Remove old entry if renamed, then use upsert_credentials to handle duplicates + if new_name != credential_name: + litellm.credential_list = [ + c for c in litellm.credential_list + if c.credential_name != credential_name + ] + CredentialAccessor.upsert_credentials([updated_in_memory]) + return {"success": True, "message": "Credential updated successfully"} except Exception as e: return handle_exception_on_proxy(e) diff --git a/litellm/proxy/db/create_views.py b/litellm/proxy/db/create_views.py index fd0baf67b3..2326f495a5 100644 --- a/litellm/proxy/db/create_views.py +++ b/litellm/proxy/db/create_views.py @@ -4,6 +4,12 @@ from litellm import verbose_logger _db = Any +# Markers that indicate a view/relation does not yet exist in the database. +# Keeping these in one place avoids repeating the check across all view blocks +# and prevents overly broad matches (e.g. bare 'undefined' would also match +# 'undefined function' or 'column undefined_col referenced in query'). +_VIEW_NOT_FOUND_MARKERS = ("does not exist", "no such table", "undefined table") + async def create_missing_views(db: _db): # noqa: PLR0915 """ @@ -18,14 +24,17 @@ async def create_missing_views(db: _db): # noqa: PLR0915 If the view doesn't exist, one will be created. """ + try: # Try to select one row from the view await db.query_raw("""SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1""") - print("LiteLLM_VerificationTokenView Exists!") # noqa - except Exception: + verbose_logger.debug("LiteLLM_VerificationTokenView Exists!") + except Exception as e: + error_msg = str(e).lower() + if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS): + raise # If an error occurs, the view does not exist, so create it - await db.execute_raw( - """ + await db.execute_raw(""" CREATE VIEW "LiteLLM_VerificationTokenView" AS SELECT v.*, @@ -37,15 +46,17 @@ async def create_missing_views(db: _db): # noqa: PLR0915 FROM "LiteLLM_VerificationToken" v LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id LEFT JOIN "LiteLLM_ProjectTable" p ON v.project_id = p.project_id; - """ - ) + """) - print("LiteLLM_VerificationTokenView Created!") # noqa + verbose_logger.debug("LiteLLM_VerificationTokenView Created!") try: await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""") - print("MonthlyGlobalSpend Exists!") # noqa - except Exception: + verbose_logger.debug("MonthlyGlobalSpend Exists!") + except Exception as e: + error_msg = str(e).lower() + if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS): + raise sql_query = """ CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS SELECT @@ -60,12 +71,15 @@ async def create_missing_views(db: _db): # noqa: PLR0915 """ await db.execute_raw(query=sql_query) - print("MonthlyGlobalSpend Created!") # noqa + verbose_logger.debug("MonthlyGlobalSpend Created!") try: await db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""") - print("Last30dKeysBySpend Exists!") # noqa - except Exception: + verbose_logger.debug("Last30dKeysBySpend Exists!") + except Exception as e: + error_msg = str(e).lower() + if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS): + raise sql_query = """ CREATE OR REPLACE VIEW "Last30dKeysBySpend" AS SELECT @@ -88,12 +102,15 @@ async def create_missing_views(db: _db): # noqa: PLR0915 """ await db.execute_raw(query=sql_query) - print("Last30dKeysBySpend Created!") # noqa + verbose_logger.debug("Last30dKeysBySpend Created!") try: await db.query_raw("""SELECT 1 FROM "Last30dModelsBySpend" LIMIT 1""") - print("Last30dModelsBySpend Exists!") # noqa - except Exception: + verbose_logger.debug("Last30dModelsBySpend Exists!") + except Exception as e: + error_msg = str(e).lower() + if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS): + raise sql_query = """ CREATE OR REPLACE VIEW "Last30dModelsBySpend" AS SELECT @@ -111,11 +128,14 @@ async def create_missing_views(db: _db): # noqa: PLR0915 """ await db.execute_raw(query=sql_query) - print("Last30dModelsBySpend Created!") # noqa + verbose_logger.debug("Last30dModelsBySpend Created!") try: await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpendPerKey" LIMIT 1""") - print("MonthlyGlobalSpendPerKey Exists!") # noqa - except Exception: + verbose_logger.debug("MonthlyGlobalSpendPerKey Exists!") + except Exception as e: + error_msg = str(e).lower() + if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS): + raise sql_query = """ CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerKey" AS SELECT @@ -132,13 +152,16 @@ async def create_missing_views(db: _db): # noqa: PLR0915 """ await db.execute_raw(query=sql_query) - print("MonthlyGlobalSpendPerKey Created!") # noqa + verbose_logger.debug("MonthlyGlobalSpendPerKey Created!") try: await db.query_raw( """SELECT 1 FROM "MonthlyGlobalSpendPerUserPerKey" LIMIT 1""" ) - print("MonthlyGlobalSpendPerUserPerKey Exists!") # noqa - except Exception: + verbose_logger.debug("MonthlyGlobalSpendPerUserPerKey Exists!") + except Exception as e: + error_msg = str(e).lower() + if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS): + raise sql_query = """ CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerUserPerKey" AS SELECT @@ -157,12 +180,15 @@ async def create_missing_views(db: _db): # noqa: PLR0915 """ await db.execute_raw(query=sql_query) - print("MonthlyGlobalSpendPerUserPerKey Created!") # noqa + verbose_logger.debug("MonthlyGlobalSpendPerUserPerKey Created!") try: await db.query_raw("""SELECT 1 FROM "DailyTagSpend" LIMIT 1""") - print("DailyTagSpend Exists!") # noqa - except Exception: + verbose_logger.debug("DailyTagSpend Exists!") + except Exception as e: + error_msg = str(e).lower() + if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS): + raise sql_query = """ CREATE OR REPLACE VIEW "DailyTagSpend" AS SELECT @@ -175,12 +201,15 @@ async def create_missing_views(db: _db): # noqa: PLR0915 """ await db.execute_raw(query=sql_query) - print("DailyTagSpend Created!") # noqa + verbose_logger.debug("DailyTagSpend Created!") try: await db.query_raw("""SELECT 1 FROM "Last30dTopEndUsersSpend" LIMIT 1""") - print("Last30dTopEndUsersSpend Exists!") # noqa - except Exception: + verbose_logger.debug("Last30dTopEndUsersSpend Exists!") + except Exception as e: + error_msg = str(e).lower() + if not any(marker in error_msg for marker in _VIEW_NOT_FOUND_MARKERS): + raise sql_query = """ CREATE VIEW "Last30dTopEndUsersSpend" AS SELECT end_user, COUNT(*) AS total_events, SUM(spend) AS total_spend @@ -193,7 +222,7 @@ async def create_missing_views(db: _db): # noqa: PLR0915 """ await db.execute_raw(query=sql_query) - print("Last30dTopEndUsersSpend Created!") # noqa + verbose_logger.debug("Last30dTopEndUsersSpend Created!") return diff --git a/litellm/proxy/db/db_transaction_queue/pod_lock_manager.py b/litellm/proxy/db/db_transaction_queue/pod_lock_manager.py index 6435498ae0..5e0ddef9ea 100644 --- a/litellm/proxy/db/db_transaction_queue/pod_lock_manager.py +++ b/litellm/proxy/db/db_transaction_queue/pod_lock_manager.py @@ -21,9 +21,18 @@ class PodLockManager: Ensures that only one pod can run a cron job at a time. """ + _COMPARE_AND_DELETE_LOCK_SCRIPT = """ +if redis.call("get", KEYS[1]) == ARGV[1] then + return redis.call("del", KEYS[1]) +else + return 0 +end +""" + def __init__(self, redis_cache: Optional[RedisCache] = None): self.pod_id = str(uuid.uuid4()) self.redis_cache = redis_cache + self._release_lock_script: Optional[Any] = None @staticmethod def get_redis_lock_key(cronjob_id: str) -> str: @@ -107,53 +116,35 @@ class PodLockManager: ): """ Release the lock if the current pod holds it. - Uses get and delete commands to ensure that only the owner can release the lock. + Uses an atomic Lua compare-and-delete to prevent TOCTOU races where a + stale owner could delete a newly reacquired lock. + Falls back to GET + DEL for cache implementations that don't support + script registration. """ if self.redis_cache is None: verbose_proxy_logger.debug("redis_cache is None, skipping release_lock") return try: - cronjob_id = cronjob_id verbose_proxy_logger.debug( "Pod %s attempting to release Redis lock for cronjob_id=%s", self.pod_id, cronjob_id, ) lock_key = PodLockManager.get_redis_lock_key(cronjob_id) - - current_value = await self.redis_cache.async_get_cache(lock_key) - if current_value is not None: - if isinstance(current_value, bytes): - current_value = current_value.decode("utf-8") - if current_value == self.pod_id: - result = await self.redis_cache.async_delete_cache(lock_key) - if result == 1: - verbose_proxy_logger.info( - "Pod %s successfully released Redis lock for cronjob_id=%s", - self.pod_id, - cronjob_id, - ) - self._emit_released_lock_event( - cronjob_id=cronjob_id, - pod_id=self.pod_id, - ) - else: - verbose_proxy_logger.warning( - "Pod %s failed to release Redis lock for cronjob_id=%s. " - "Lock will expire after its TTL.", - self.pod_id, - cronjob_id, - ) - else: - verbose_proxy_logger.debug( - "Pod %s cannot release Redis lock for cronjob_id=%s because it is held by pod %s", - self.pod_id, - cronjob_id, - current_value, - ) + result = await self._compare_and_delete_lock(lock_key=lock_key) + if result == 1: + verbose_proxy_logger.info( + "Pod %s successfully released Redis lock for cronjob_id=%s", + self.pod_id, + cronjob_id, + ) + self._emit_released_lock_event( + cronjob_id=cronjob_id, + pod_id=self.pod_id, + ) else: verbose_proxy_logger.debug( - "Pod %s attempted to release Redis lock for cronjob_id=%s, but no lock was found", + "Pod %s failed to release Redis lock for cronjob_id=%s (lock missing or held by another pod)", self.pod_id, cronjob_id, ) @@ -162,6 +153,42 @@ class PodLockManager: f"Error releasing Redis lock for {cronjob_id}: {e}" ) + async def _compare_and_delete_lock(self, lock_key: str) -> int: + """ + Atomically delete lock key only if current pod owns it. + + Falls back to get/delete for non-RedisCache implementations that do not + expose Lua script registration. + """ + script_register = getattr(self.redis_cache, "async_register_script", None) + if callable(script_register): + try: + if self._release_lock_script is None: + self._release_lock_script = script_register( + self._COMPARE_AND_DELETE_LOCK_SCRIPT + ) + result = await self._release_lock_script( + keys=[lock_key], args=[self.pod_id] + ) + return int(result or 0) + except Exception: + # Lua execution failed (e.g. Redis restart cleared loaded scripts, + # or scripting is disabled). Reset cached script handle and fall + # through to the GET + DEL fallback so the lock is still released. + self._release_lock_script = None + verbose_proxy_logger.warning( + "Lua compare-and-delete failed for lock_key=%s, falling back to GET+DEL", + lock_key, + ) + + current_value = await self.redis_cache.async_get_cache(lock_key) # type: ignore + if isinstance(current_value, bytes): + current_value = current_value.decode("utf-8") + if current_value != self.pod_id: + return 0 + result = await self.redis_cache.async_delete_cache(lock_key) # type: ignore + return int(result or 0) + @staticmethod def _emit_acquired_lock_event(cronjob_id: str, pod_id: str): asyncio.create_task( diff --git a/litellm/proxy/db/db_transaction_queue/spend_log_cleanup.py b/litellm/proxy/db/db_transaction_queue/spend_log_cleanup.py index ba9423c6ef..bc9efb52b0 100644 --- a/litellm/proxy/db/db_transaction_queue/spend_log_cleanup.py +++ b/litellm/proxy/db/db_transaction_queue/spend_log_cleanup.py @@ -82,7 +82,7 @@ class SpendLogCleanup: break # Step 1: Find logs and delete them in one go without fetching to application # Delete in batches, limited by self.batch_size - deleted_count = await prisma_client.db.execute_raw( + deleted_result = await prisma_client.db.execute_raw( """ DELETE FROM "LiteLLM_SpendLogs" WHERE "request_id" IN ( @@ -94,6 +94,17 @@ class SpendLogCleanup: cutoff_date, self.batch_size, ) + + deleted_count = 0 + if isinstance(deleted_result, int): + deleted_count = deleted_result + else: + verbose_proxy_logger.error( + f"Unexpected execute_raw return type for spend log cleanup: {type(deleted_result)}; " + "aborting cleanup to avoid infinite loop" + ) + break + verbose_proxy_logger.info(f"Deleted {deleted_count} logs in this batch") if deleted_count == 0: diff --git a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py index b4b8e68113..ed34ec6682 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py +++ b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py @@ -83,7 +83,12 @@ def _redact_pii_matches(response_json: dict) -> dict: redacted_response = copy.deepcopy(response_json) # Get assessments from the response - assessments = redacted_response.get("assessments", []) + # NOTE: We use `.get("key") or []` instead of `.get("key", [])` because + # the Bedrock API can return explicit `null` for list fields (e.g. "regexes": null). + # In Python, dict.get("key", []) returns None (not []) when the key exists + # with a None/null value. The `or []` ensures we always get an iterable, + # preventing "TypeError: 'NoneType' object is not iterable". + assessments = redacted_response.get("assessments") or [] if not assessments: return redacted_response @@ -91,13 +96,13 @@ def _redact_pii_matches(response_json: dict) -> dict: # Redact PII entities in sensitive information policy sensitive_info_policy = assessment.get("sensitiveInformationPolicy") if sensitive_info_policy: - pii_entities = sensitive_info_policy.get("piiEntities", []) + pii_entities = sensitive_info_policy.get("piiEntities") or [] for pii_entity in pii_entities: if "match" in pii_entity: pii_entity["match"] = "[REDACTED]" # Redact regex matches - regexes = sensitive_info_policy.get("regexes", []) + regexes = sensitive_info_policy.get("regexes") or [] for regex_match in regexes: if "match" in regex_match: regex_match["match"] = "[REDACTED]" @@ -105,12 +110,12 @@ def _redact_pii_matches(response_json: dict) -> dict: # Redact custom word matches in word policy word_policy = assessment.get("wordPolicy") if word_policy: - custom_words = word_policy.get("customWords", []) + custom_words = word_policy.get("customWords") or [] for custom_word in custom_words: if "match" in custom_word: custom_word["match"] = "[REDACTED]" - managed_words = word_policy.get("managedWordLists", []) + managed_words = word_policy.get("managedWordLists") or [] for managed_word in managed_words: if "match" in managed_word: managed_word["match"] = "[REDACTED]" @@ -825,7 +830,9 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): return False # Check assessments to determine if any actions were BLOCKED (vs ANONYMIZED) - assessments = response.get("assessments", []) + # NOTE: Use `or []` instead of default param to handle explicit null from Bedrock API. + # See _redact_pii_matches() for detailed explanation of the null safety pattern. + assessments = response.get("assessments") or [] if not assessments: return False @@ -833,7 +840,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): # Check topic policy topic_policy = assessment.get("topicPolicy") if topic_policy: - topics = topic_policy.get("topics", []) + topics = topic_policy.get("topics") or [] for topic in topics: if topic.get("action") == "BLOCKED": return True @@ -841,7 +848,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): # Check content policy content_policy = assessment.get("contentPolicy") if content_policy: - filters = content_policy.get("filters", []) + filters = content_policy.get("filters") or [] for filter_item in filters: if filter_item.get("action") == "BLOCKED": return True @@ -849,11 +856,11 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): # Check word policy word_policy = assessment.get("wordPolicy") if word_policy: - custom_words = word_policy.get("customWords", []) + custom_words = word_policy.get("customWords") or [] for custom_word in custom_words: if custom_word.get("action") == "BLOCKED": return True - managed_words = word_policy.get("managedWordLists", []) + managed_words = word_policy.get("managedWordLists") or [] for managed_word in managed_words: if managed_word.get("action") == "BLOCKED": return True @@ -861,12 +868,12 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): # Check sensitive information policy sensitive_info_policy = assessment.get("sensitiveInformationPolicy") if sensitive_info_policy: - pii_entities = sensitive_info_policy.get("piiEntities", []) + pii_entities = sensitive_info_policy.get("piiEntities") or [] if pii_entities: for pii_entity in pii_entities: if pii_entity.get("action") == "BLOCKED": return True - regexes = sensitive_info_policy.get("regexes", []) + regexes = sensitive_info_policy.get("regexes") or [] if regexes: for regex in regexes: if regex.get("action") == "BLOCKED": @@ -875,7 +882,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): # Check contextual grounding policy contextual_grounding_policy = assessment.get("contextualGroundingPolicy") if contextual_grounding_policy: - grounding_filters = contextual_grounding_policy.get("filters", []) + grounding_filters = contextual_grounding_policy.get("filters") or [] for grounding_filter in grounding_filters: if grounding_filter.get("action") == "BLOCKED": return True @@ -1534,7 +1541,9 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): Raises: Exception: If content is blocked by Bedrock guardrail """ - texts = inputs.get("texts", []) + # NOTE: Use `or []` to handle case where inputs["texts"] is explicitly None. + # dict.get("texts", []) would return None if the key exists with a None value. + texts = inputs.get("texts") or [] try: verbose_proxy_logger.debug( f"Bedrock Guardrail: Applying guardrail to {len(texts)} text(s)" diff --git a/litellm/proxy/guardrails/guardrail_hooks/noma/noma_v2.py b/litellm/proxy/guardrails/guardrail_hooks/noma/noma_v2.py index 1a119ec56b..071613ad5f 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/noma/noma_v2.py +++ b/litellm/proxy/guardrails/guardrail_hooks/noma/noma_v2.py @@ -274,6 +274,15 @@ class NomaV2Guardrail(CustomGuardrail): if application_id is None: application_id = self._get_non_empty_str(self.application_id) + # Fall back to API key alias for per-key traceability in Noma dashboard + # (ports v1 fallback from PR #16832). + if application_id is None: + application_id = self._get_non_empty_str( + request_data.get("litellm_metadata", {}).get("user_api_key_alias") + ) or self._get_non_empty_str( + request_data.get("metadata", {}).get("user_api_key_alias") + ) + try: payload = self._build_scan_payload( inputs=inputs, diff --git a/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py b/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py index a1623121da..367e6b2f15 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py +++ b/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py @@ -52,6 +52,20 @@ def _get_a2a_request_id( endpoint_guardrail_translation_mappings = None +def _ensure_litellm_metadata(data: dict, user_api_key_dict: UserAPIKeyAuth) -> None: + """Populate data['litellm_metadata'] from user_api_key_dict if absent.""" + if "litellm_metadata" not in data: + from litellm.llms.base_llm.guardrail_translation.base_translation import ( + BaseTranslation, + ) + + user_metadata = BaseTranslation.transform_user_api_key_dict_to_metadata( + user_api_key_dict + ) + if user_metadata: + data["litellm_metadata"] = user_metadata + + class UnifiedLLMGuardrails(CustomLogger): def __init__( self, @@ -120,6 +134,8 @@ class UnifiedLLMGuardrails(CustomLogger): CallTypes(call_type) ]() + _ensure_litellm_metadata(data, user_api_key_dict) + data = await endpoint_translation.process_input_messages( data=data, guardrail_to_apply=guardrail_to_apply, @@ -177,6 +193,8 @@ class UnifiedLLMGuardrails(CustomLogger): CallTypes(call_type) ]() + _ensure_litellm_metadata(data, user_api_key_dict) + return await endpoint_translation.process_input_messages( data=data, guardrail_to_apply=guardrail_to_apply, diff --git a/litellm/proxy/health_check.py b/litellm/proxy/health_check.py index 5d1bcf31f8..1518ed66ab 100644 --- a/litellm/proxy/health_check.py +++ b/litellm/proxy/health_check.py @@ -11,7 +11,11 @@ from typing import List, Optional import litellm logger = logging.getLogger(__name__) -from litellm.constants import DEFAULT_HEALTH_CHECK_PROMPT, HEALTH_CHECK_TIMEOUT_SECONDS +from litellm.constants import ( + BACKGROUND_HEALTH_CHECK_MAX_TOKENS, + DEFAULT_HEALTH_CHECK_PROMPT, + HEALTH_CHECK_TIMEOUT_SECONDS, +) ILLEGAL_DISPLAY_PARAMS = [ "messages", @@ -242,7 +246,9 @@ async def _perform_health_check( cleaned["model_id"] = _model_id if isinstance(is_healthy, Exception): exceptions_by_model_id[_model_id] = is_healthy - cleaned["exception_status"] = getattr(is_healthy, "status_code", 500) + cleaned["exception_status"] = getattr( + is_healthy, "status_code", 500 + ) unhealthy_endpoints.append(cleaned) return healthy_endpoints, unhealthy_endpoints, exceptions_by_model_id @@ -301,6 +307,8 @@ def _update_litellm_params_for_health_check( _health_check_max_tokens = model_info.get("health_check_max_tokens", None) if _health_check_max_tokens is not None: litellm_params["max_tokens"] = _health_check_max_tokens + elif BACKGROUND_HEALTH_CHECK_MAX_TOKENS is not None: + litellm_params["max_tokens"] = BACKGROUND_HEALTH_CHECK_MAX_TOKENS elif "*" not in ( model_info.get("health_check_model") or litellm_params.get("model") or "" ): diff --git a/litellm/proxy/health_check_utils/shared_health_check_manager.py b/litellm/proxy/health_check_utils/shared_health_check_manager.py index 2ecee5095b..5b8370fece 100644 --- a/litellm/proxy/health_check_utils/shared_health_check_manager.py +++ b/litellm/proxy/health_check_utils/shared_health_check_manager.py @@ -84,7 +84,7 @@ class SharedHealthCheckManager: "Pod %s failed to acquire health check lock", self.pod_id ) - return acquired + return bool(acquired) except Exception as e: verbose_proxy_logger.error("Error acquiring health check lock: %s", str(e)) return False diff --git a/litellm/proxy/management_helpers/team_member_permission_checks.py b/litellm/proxy/management_helpers/team_member_permission_checks.py index 7dd99d4ff1..e035168ca0 100644 --- a/litellm/proxy/management_helpers/team_member_permission_checks.py +++ b/litellm/proxy/management_helpers/team_member_permission_checks.py @@ -98,11 +98,20 @@ class TeamMemberPermissionChecks: ) # 5. Check if the team member has permissions for the endpoint - TeamMemberPermissionChecks.does_team_member_have_permissions_for_endpoint( - team_member_object=key_assigned_user_in_team, - team_table=team_table, - route=route, + has_permission = ( + TeamMemberPermissionChecks.does_team_member_have_permissions_for_endpoint( + team_member_object=key_assigned_user_in_team, + team_table=team_table, + route=route, + ) ) + if not has_permission: + raise ProxyException( + message=f"User {user_api_key_dict.user_id} does not belong to team {team_table.team_id}. Team-scoped key management endpoints can only be used for keys in your own team.", + type=ProxyErrorTypes.team_member_permission_error, + param=route, + code=401, + ) @staticmethod def does_team_member_have_permissions_for_endpoint( diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 9981c049c1..2d789b982d 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -54,7 +54,7 @@ from litellm.constants import ( LITELLM_SETTINGS_SAFE_DB_OVERRIDES, LITELLM_UI_ALLOW_HEADERS, LITELLM_UI_SESSION_DURATION, - DAILY_TAG_SPEND_BATCH_MULTIPLIER + DAILY_TAG_SPEND_BATCH_MULTIPLIER, ) from litellm.litellm_core_utils.litellm_logging import ( _init_custom_logger_compatible_class, @@ -1139,7 +1139,54 @@ async def openai_exception_handler(request: Request, exc: ProxyException): router = APIRouter() -origins = ["*"] + + +def _get_cors_config( + cors_origins_env: Optional[str] = None, + cors_credentials_env: Optional[str] = None, +): + """ + Compute CORS allowed origins and credentials flag from environment variables. + + Extracted into a function so it can be unit-tested without reloading the module. + + Args: + cors_origins_env: Value of LITELLM_CORS_ORIGINS (defaults to os.getenv). + cors_credentials_env: Value of LITELLM_CORS_ALLOW_CREDENTIALS (defaults to os.getenv). + + Returns: + Tuple[List[str], bool]: (origins, allow_credentials) + """ + _origins_raw = ( + cors_origins_env + if cors_origins_env is not None + else os.getenv("LITELLM_CORS_ORIGINS") + ) + if _origins_raw is None or _origins_raw.strip() == "": + computed_origins = ["*"] + else: + computed_origins = [o.strip() for o in _origins_raw.split(",") if o.strip()] + + # Disable credentials by default when wildcard origins are used — combining + # allow_origins=["*"] with allow_credentials=True causes Starlette to reflect + # the incoming Origin header, allowing any site to make credentialed requests. + # Set LITELLM_CORS_ALLOW_CREDENTIALS=true to explicitly restore the old behaviour + # (e.g. for non-browser clients that relied on the Access-Control-Allow-Credentials + # header being present regardless of origin). + _credentials_raw = ( + cors_credentials_env + if cors_credentials_env is not None + else os.getenv("LITELLM_CORS_ALLOW_CREDENTIALS") + ) + if _credentials_raw is not None: + computed_credentials = _credentials_raw.strip().lower() == "true" + else: + computed_credentials = "*" not in computed_origins + + return computed_origins, computed_credentials + + +origins, allow_cors_credentials = _get_cors_config() # get current directory @@ -1466,7 +1513,7 @@ current_dir = os.path.dirname(os.path.abspath(__file__)) app.add_middleware( CORSMiddleware, allow_origins=origins, - allow_credentials=True, + allow_credentials=allow_cors_credentials, allow_methods=["*"], allow_headers=["*"], expose_headers=LITELLM_UI_ALLOW_HEADERS, @@ -1870,26 +1917,20 @@ async def update_cache( # noqa: PLR0915 ## CHECK IF USER PROJECTED SPEND > SOFT LIMIT if ( existing_spend_obj.soft_budget_cooldown is False - and existing_spend_obj.litellm_budget_table is not None + and existing_spend_obj.soft_budget is not None and ( _is_projected_spend_over_limit( current_spend=new_spend, - soft_budget_limit=existing_spend_obj.litellm_budget_table[ - "soft_budget" - ], + soft_budget_limit=existing_spend_obj.soft_budget, ) is True ) ): projected_spend, projected_exceeded_date = _get_projected_spend_over_limit( current_spend=new_spend, - soft_budget_limit=existing_spend_obj.litellm_budget_table.get( - "soft_budget", None - ), + soft_budget_limit=existing_spend_obj.soft_budget, ) # type: ignore - soft_limit = existing_spend_obj.litellm_budget_table.get( - "soft_budget", float("inf") - ) + soft_limit = existing_spend_obj.soft_budget call_info = CallInfo( token=existing_spend_obj.token or "", spend=new_spend, @@ -1897,7 +1938,7 @@ async def update_cache( # noqa: PLR0915 max_budget=soft_limit, user_id=existing_spend_obj.user_id, projected_spend=projected_spend, - projected_exceeded_date=projected_exceeded_date, + projected_exceeded_date=str(projected_exceeded_date), event_group=Litellm_EntityType.KEY, ) # alert user @@ -2315,9 +2356,13 @@ def _write_health_state_to_router_cache( exception_status = getattr(original_exception, "status_code", 500) - if llm_router.health_check_ignore_transient_errors and exception_status in ( - 429, - 408, + if ( + llm_router.health_check_ignore_transient_errors + and exception_status + in ( + 429, + 408, + ) ): continue @@ -6286,7 +6331,9 @@ class ProxyStartupEvent: ### UPDATE DAILY TAG SPEND (separate scheduler job with longer interval) ### ## Reduces QPS as there are more tags for a single request - tag_spend_update_interval = int(batch_writing_interval * DAILY_TAG_SPEND_BATCH_MULTIPLIER) + tag_spend_update_interval = int( + batch_writing_interval * DAILY_TAG_SPEND_BATCH_MULTIPLIER + ) from litellm.proxy.utils import update_daily_tag_spend scheduler.add_job( @@ -7131,9 +7178,9 @@ async def chat_completion( # noqa: PLR0915 hasattr(user_api_key_dict, "organization_alias") and user_api_key_dict.organization_alias is not None ): - data["metadata"]["user_api_key_org_alias"] = ( - user_api_key_dict.organization_alias - ) + data["metadata"][ + "user_api_key_org_alias" + ] = user_api_key_dict.organization_alias if ( hasattr(user_api_key_dict, "agent_id") and user_api_key_dict.agent_id is not None @@ -7312,9 +7359,9 @@ async def completion( # noqa: PLR0915 hasattr(user_api_key_dict, "organization_alias") and user_api_key_dict.organization_alias is not None ): - data["metadata"]["user_api_key_org_alias"] = ( - user_api_key_dict.organization_alias - ) + data["metadata"][ + "user_api_key_org_alias" + ] = user_api_key_dict.organization_alias if ( hasattr(user_api_key_dict, "agent_id") and user_api_key_dict.agent_id is not None @@ -7561,9 +7608,9 @@ async def embeddings( # noqa: PLR0915 hasattr(user_api_key_dict, "organization_alias") and user_api_key_dict.organization_alias is not None ): - data["metadata"]["user_api_key_org_alias"] = ( - user_api_key_dict.organization_alias - ) + data["metadata"][ + "user_api_key_org_alias" + ] = user_api_key_dict.organization_alias if ( hasattr(user_api_key_dict, "agent_id") and user_api_key_dict.agent_id is not None diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index fce95465b5..ce3f5f131f 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -289,6 +289,7 @@ model LiteLLM_MCPServerTable { server_name String? alias String? description String? + instructions String? url String? spec_path String? transport String @default("sse") @@ -1045,6 +1046,7 @@ model LiteLLM_HealthCheckTable { @@index([model_name]) @@index([checked_at]) @@index([status]) + @@index([model_id, model_name, checked_at(sort: Desc)], map: "LiteLLM_HealthCheckTable_model_id_model_name_checked_at_idx") } // Search Tools table for storing search tool configurations diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index ac3bf6d498..9dfe14a562 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -4526,29 +4526,20 @@ class PrismaClient: async def get_all_latest_health_checks(self): """ - Get the latest health check for each model + Get the latest health check for each model. + + Uses DB-level DISTINCT ON (model_id, model_name) with ORDER BY checked_at DESC + (via Prisma ``distinct`` + ``order``) so we never load the full history into memory. """ try: - # Get all unique model names first - all_checks = await self.db.litellm_healthchecktable.find_many( - order={"checked_at": "desc"} + return await self.db.litellm_healthchecktable.find_many( + distinct=["model_id", "model_name"], + order=[ + {"model_id": "asc"}, + {"model_name": "asc"}, + {"checked_at": "desc"}, + ], ) - - # Group by model_name and get the latest for each - latest_checks = {} - for check in all_checks: - # Create a unique key: prefer model_id if available, otherwise use model_name - # This ensures we get the latest check for each unique model - if check.model_id: - key = (check.model_id, check.model_name) - else: - key = (None, check.model_name) - - # Only add if we haven't seen this key yet (since checks are ordered by checked_at desc) - if key not in latest_checks: - latest_checks[key] = check - - return list(latest_checks.values()) except Exception as e: verbose_proxy_logger.error(f"Error getting all latest health checks: {e}") return [] @@ -5322,19 +5313,6 @@ def get_error_message_str(e: Exception) -> str: return error_message -def _get_openapi_url() -> Optional[str]: - """ - Get the OpenAPI schema URL from the environment variables. - - - If NO_OPENAPI is True, return None. - - Otherwise, default to "/openapi.json". - """ - if str_to_bool(os.getenv("NO_OPENAPI")) is True: - return None - - return "/openapi.json" - - def _get_redoc_url() -> Optional[str]: """ Get the Redoc URL from the environment variables. @@ -5368,6 +5346,22 @@ def _get_docs_url() -> Optional[str]: return "/" +def _get_openapi_url() -> Optional[str]: + """ + Get the OpenAPI JSON URL from the environment variables. + + - If OPENAPI_URL is set, return it. + - If NO_OPENAPI is True, return None. + - Otherwise, default to "/openapi.json". + """ + if openapi_url := os.getenv("OPENAPI_URL"): + return openapi_url + + if str_to_bool(os.getenv("NO_OPENAPI")) is True: + return None + + return "/openapi.json" + def handle_exception_on_proxy(e: Exception) -> ProxyException: """ diff --git a/litellm/types/integrations/prometheus.py b/litellm/types/integrations/prometheus.py index 51a41f97e0..b1535208ec 100644 --- a/litellm/types/integrations/prometheus.py +++ b/litellm/types/integrations/prometheus.py @@ -396,6 +396,7 @@ class PrometheusMetricLabels: UserAPIKeyLabelNames.CLIENT_IP.value, UserAPIKeyLabelNames.USER_AGENT.value, UserAPIKeyLabelNames.MODEL_ID.value, + UserAPIKeyLabelNames.API_PROVIDER.value, ] litellm_spend_metric = [ @@ -410,6 +411,7 @@ class PrometheusMetricLabels: UserAPIKeyLabelNames.CLIENT_IP.value, UserAPIKeyLabelNames.USER_AGENT.value, UserAPIKeyLabelNames.MODEL_ID.value, + UserAPIKeyLabelNames.API_PROVIDER.value, ] litellm_input_tokens_metric = [ diff --git a/litellm/types/mcp_server/mcp_server_manager.py b/litellm/types/mcp_server/mcp_server_manager.py index a7d0968c0e..ace8c8a418 100644 --- a/litellm/types/mcp_server/mcp_server_manager.py +++ b/litellm/types/mcp_server/mcp_server_manager.py @@ -27,20 +27,21 @@ class MCPServer(BaseModel): spec_path: Optional[str] = None auth_type: Optional[MCPAuthType] = None authentication_token: Optional[str] = None + instructions: Optional[str] = None mcp_info: Optional[MCPInfo] = None - extra_headers: Optional[ - List[str] - ] = None # allow admin to specify which headers to forward from client to the MCP server + extra_headers: Optional[List[str]] = ( + None # allow admin to specify which headers to forward from client to the MCP server + ) allowed_tools: Optional[List[str]] = None disallowed_tools: Optional[List[str]] = None tool_name_to_display_name: Optional[Dict[str, str]] = None tool_name_to_description: Optional[Dict[str, str]] = None - allowed_params: Optional[ - Dict[str, List[str]] - ] = None # map of tool names to allowed parameter lists - static_headers: Optional[ - Dict[str, str] - ] = None # static headers to forward to the MCP server + allowed_params: Optional[Dict[str, List[str]]] = ( + None # map of tool names to allowed parameter lists + ) + static_headers: Optional[Dict[str, str]] = ( + None # static headers to forward to the MCP server + ) # OAuth-specific fields client_id: Optional[str] = None client_secret: Optional[str] = None diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 0e3bc8f3ed..d0bc9b7894 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -2647,6 +2647,8 @@ class StandardLoggingAdditionalHeaders(TypedDict, total=False): x_ratelimit_limit_tokens: int x_ratelimit_remaining_requests: int x_ratelimit_remaining_tokens: int + x_ratelimit_reset_requests: str + x_ratelimit_reset_tokens: str class StandardLoggingHiddenParams(TypedDict): diff --git a/litellm/utils.py b/litellm/utils.py index 8d6ea5b7e9..2125875ee1 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4815,11 +4815,21 @@ def _apply_openai_param_overrides( If user passes in allowed_openai_params, apply them to optional_params These params will get passed as is to the LLM API since the user opted in to passing them in the request + + Only params the caller actually sent are forwarded. Previously this + function unconditionally wrote `None` for any allowed param missing from + the request, which then reached the provider SDK as a top-level kwarg it + did not recognize (e.g. openai SDK raising + `AsyncCompletions.create() got an unexpected keyword argument 'enable_thinking'`). + See https://github.com/BerriAI/litellm/issues/25697 """ if allowed_openai_params: for param in allowed_openai_params: - if param not in optional_params: - optional_params[param] = non_default_params.pop(param, None) + if param in optional_params: + continue + if param not in non_default_params: + continue + optional_params[param] = non_default_params.pop(param) return optional_params diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index c624736d6b..4c807f8270 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -25164,6 +25164,58 @@ "supports_web_search": true, "tpm": 800000 }, + "openrouter/google/gemini-3.1-flash-lite-preview": { + "cache_read_input_token_cost": 2.5e-08, + "cache_read_input_token_cost_per_audio_token": 5e-08, + "input_cost_per_audio_token": 5e-07, + "input_cost_per_token": 2.5e-07, + "litellm_provider": "openrouter", + "max_audio_length_hours": 8.4, + "max_audio_per_prompt": 1, + "max_images_per_prompt": 3000, + "max_input_tokens": 1048576, + "max_output_tokens": 65536, + "max_pdf_size_mb": 30, + "max_tokens": 65536, + "max_video_length": 1, + "max_videos_per_prompt": 10, + "mode": "chat", + "output_cost_per_reasoning_token": 1.5e-06, + "output_cost_per_token": 1.5e-06, + "rpm": 2000, + "source": "https://ai.google.dev/pricing/gemini-3", + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/completions", + "/v1/batch" + ], + "supported_modalities": [ + "text", + "image", + "audio", + "video" + ], + "supported_output_modalities": [ + "text" + ], + "supports_audio_input": true, + "supports_audio_output": false, + "supports_code_execution": true, + "supports_file_search": true, + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_url_context": true, + "supports_video_input": true, + "supports_vision": true, + "supports_web_search": true, + "tpm": 800000 + }, "openrouter/google/gemini-3.1-pro-preview": { "cache_read_input_token_cost": 2e-07, "cache_read_input_token_cost_above_200k_tokens": 4e-07, diff --git a/schema.prisma b/schema.prisma index fce95465b5..ce3f5f131f 100644 --- a/schema.prisma +++ b/schema.prisma @@ -289,6 +289,7 @@ model LiteLLM_MCPServerTable { server_name String? alias String? description String? + instructions String? url String? spec_path String? transport String @default("sse") @@ -1045,6 +1046,7 @@ model LiteLLM_HealthCheckTable { @@index([model_name]) @@index([checked_at]) @@index([status]) + @@index([model_id, model_name, checked_at(sort: Desc)], map: "LiteLLM_HealthCheckTable_model_id_model_name_checked_at_idx") } // Search Tools table for storing search tool configurations diff --git a/scripts/health_check/benchmark_get_all_latest_health_checks.py b/scripts/health_check/benchmark_get_all_latest_health_checks.py new file mode 100644 index 0000000000..9161861883 --- /dev/null +++ b/scripts/health_check/benchmark_get_all_latest_health_checks.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +""" +Bench LiteLLM_HealthCheckTable + PrismaClient + - set DATABASE_URL to your Postgres + - Run ```prisma generate``` to install prisma client before running test ) + - This test writes to the default "public" database. Make sure to run cleanup after testing + +""" + +from __future__ import annotations + +import argparse +import asyncio +import gc +import os +import sys +import time +import tracemalloc +from datetime import datetime, timedelta, timezone +from typing import Any, List + +SEED_MARKER = "benchmark_get_all_latest_health_checks.py" # Utility Marker for cleanup process. + + +def _rss_kb_linux() -> int: + try: + with open("/proc/self/status", encoding="utf-8") as f: + for line in f: + if line.startswith("VmRSS:"): + return int(line.split()[1]) + except OSError: + pass + return 0 + + +def _fmt_kb(kb: int) -> str: + if kb <= 0: + return "n/a" + return f"{kb} KiB (~{kb / 1024.0:.1f} MiB)" + + +def _build_batch( + *, + batch_index: int, + batch_size: int, + num_models: int, + base_time: datetime, +) -> List[dict[str, Any]]: + rows: List[dict[str, Any]] = [] + for i in range(batch_size): + global_i = batch_index * batch_size + i + model_idx = global_i % max(num_models, 1) + model_name = f"bench-model-{model_idx}" + model_id = f"bench-mid-{model_idx}" if model_idx % 2 == 0 else None + checked_at = base_time - timedelta(seconds=global_i) + rows.append( + { + "model_name": model_name, + "model_id": model_id, + "status": "healthy" if global_i % 3 else "unhealthy", + "healthy_count": 1, + "unhealthy_count": 0, + "checked_by": SEED_MARKER, + "checked_at": checked_at, + } + ) + return rows + + +async def _seed( + prisma: Any, + *, + total_rows: int, + batch_size: int, + num_models: int, +) -> None: + db = prisma.db + base_time = datetime.now(timezone.utc) + inserted = 0 + batch_idx = 0 + while inserted < total_rows: + n = min(batch_size, total_rows - inserted) + await db.litellm_healthchecktable.create_many( + data=_build_batch( + batch_index=batch_idx, + batch_size=n, + num_models=num_models, + base_time=base_time, + ) + ) + inserted += n + batch_idx += 1 + if batch_idx % 10 == 0: + print(f" {inserted}/{total_rows}", flush=True) + print(f"Seeded {inserted} rows ({SEED_MARKER}).") + + +async def _cleanup(prisma: Any) -> None: + result = await prisma.db.litellm_healthchecktable.delete_many( + where={"checked_by": SEED_MARKER}, + ) + n = getattr(result, "count", result) + print(f"Deleted {n} rows.") + + +async def _bench(prisma: Any) -> None: + gc.collect() + rss0 = _rss_kb_linux() + print(f"RSS (after gc): {_fmt_kb(rss0)}") + + tracemalloc.start() + t0 = time.perf_counter() + try: + rows = await prisma.get_all_latest_health_checks() + finally: + elapsed = time.perf_counter() - t0 + _, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + + gc.collect() + rss1 = _rss_kb_linux() + print(f"get_all_latest_health_checks: {len(rows)} rows in {elapsed:.2f}s") + print(f"tracemalloc peak: {peak / 1e6:.2f} MiB") + print(f"RSS after: {_fmt_kb(rss1)}") + + +async def _amain() -> int: + p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument("action", choices=("seed", "bench", "cleanup")) + p.add_argument("--rows", type=int, default=10_000) + p.add_argument("--batch-size", type=int, default=1000) + p.add_argument("--num-models", type=int, default=50) + args = p.parse_args() + + database_url = os.getenv("DATABASE_URL") + if not database_url: + print("Set DATABASE_URL.", file=sys.stderr) + return 1 + + repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + if repo_root not in sys.path: + sys.path.insert(0, repo_root) + + from litellm.caching.caching import DualCache + from litellm.proxy.proxy_cli import append_query_params + from litellm.proxy.utils import PrismaClient, ProxyLogging + + db_url = append_query_params( + database_url, {"connection_limit": 100, "pool_timeout": 60} + ) + prisma = PrismaClient( + database_url=db_url, + proxy_logging_obj=ProxyLogging(user_api_key_cache=DualCache()), + ) + try: + await prisma.connect() + except Exception as e: + print(f"Connect failed: {e}", file=sys.stderr) + return 1 + + try: + if args.action == "seed": + await _seed( + prisma, + total_rows=args.rows, + batch_size=args.batch_size, + num_models=args.num_models, + ) + elif args.action == "bench": + await _bench(prisma) + else: + await _cleanup(prisma) + finally: + try: + await prisma.disconnect() + except Exception: + pass + return 0 + + +if __name__ == "__main__": + raise SystemExit(asyncio.run(_amain())) diff --git a/tests/enterprise/litellm_enterprise/enterprise_callbacks/test_prometheus_logging_callbacks.py b/tests/enterprise/litellm_enterprise/enterprise_callbacks/test_prometheus_logging_callbacks.py index 834cb235f0..8fa56029f4 100644 --- a/tests/enterprise/litellm_enterprise/enterprise_callbacks/test_prometheus_logging_callbacks.py +++ b/tests/enterprise/litellm_enterprise/enterprise_callbacks/test_prometheus_logging_callbacks.py @@ -605,6 +605,7 @@ def test_increment_top_level_request_and_spend_metrics(prometheus_logger): org_alias=None, model="gpt-3.5-turbo", model_id="model-123", + api_provider="openai", client_ip=None, user_agent=None, ) @@ -623,6 +624,7 @@ def test_increment_top_level_request_and_spend_metrics(prometheus_logger): org_alias=None, model="gpt-3.5-turbo", model_id="model-123", + api_provider="openai", client_ip=None, user_agent=None, ) diff --git a/tests/guardrails_tests/test_bedrock_guardrails.py b/tests/guardrails_tests/test_bedrock_guardrails.py index cb594c221c..146d16c242 100644 --- a/tests/guardrails_tests/test_bedrock_guardrails.py +++ b/tests/guardrails_tests/test_bedrock_guardrails.py @@ -5,7 +5,10 @@ import pytest sys.path.insert(0, os.path.abspath("../..")) import litellm -from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import BedrockGuardrail +from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import ( + BedrockGuardrail, + _redact_pii_matches, +) from litellm.proxy._types import UserAPIKeyAuth from litellm.caching import DualCache from unittest.mock import MagicMock, AsyncMock, patch @@ -1601,3 +1604,128 @@ async def test_bedrock_guardrail_post_call_success_hook_no_output_text(): # If no error is raised and result is None, then the test passes assert result is None print("✅ No output text in response test passed") + + +@pytest.mark.asyncio +async def test__redact_pii_matches_null_list_fields(): + """Test that explicit null values from Bedrock API are handled correctly. + + The Bedrock API can return explicit JSON null for list fields like + piiEntities, regexes, customWords, managedWordLists. This would cause + TypeError: 'NoneType' object is not iterable if not handled. + """ + # Test 1: null piiEntities and regexes + response_with_null_pii = { + "action": "GUARDRAIL_INTERVENED", + "assessments": [ + { + "sensitiveInformationPolicy": { + "piiEntities": None, + "regexes": None, + } + } + ], + } + redacted = _redact_pii_matches(response_with_null_pii) + assert redacted is not None + assert redacted["assessments"][0]["sensitiveInformationPolicy"]["piiEntities"] is None + assert redacted["assessments"][0]["sensitiveInformationPolicy"]["regexes"] is None + + # Test 2: null customWords and managedWordLists + response_with_null_words = { + "action": "GUARDRAIL_INTERVENED", + "assessments": [ + { + "wordPolicy": { + "customWords": None, + "managedWordLists": None, + } + } + ], + } + redacted = _redact_pii_matches(response_with_null_words) + assert redacted is not None + assert redacted["assessments"][0]["wordPolicy"]["customWords"] is None + assert redacted["assessments"][0]["wordPolicy"]["managedWordLists"] is None + + # Test 3: null assessments at top level + response_with_null_assessments = { + "action": "GUARDRAIL_INTERVENED", + "assessments": None, + } + redacted = _redact_pii_matches(response_with_null_assessments) + assert redacted is not None + + +@pytest.mark.asyncio +async def test__redact_pii_matches_malformed_response(): + """Test _redact_pii_matches with malformed response (should not crash)""" + + # Test with completely malformed response + malformed_response = { + "action": "GUARDRAIL_INTERVENED", + "assessments": "not_a_list", + } + redacted_response = _redact_pii_matches(malformed_response) + assert redacted_response == malformed_response + + # Test with missing keys + missing_keys_response = { + "action": "GUARDRAIL_INTERVENED", + } + redacted_response = _redact_pii_matches(missing_keys_response) + assert redacted_response == missing_keys_response + + +@pytest.mark.asyncio +async def test_should_raise_guardrail_blocked_exception_null_fields(): + """Test that _should_raise_guardrail_blocked_exception handles null list fields. + + Validates the or [] null-safety pattern works for all policy fields + in _should_raise_guardrail_blocked_exception. + """ + guardrail = BedrockGuardrail( + guardrailIdentifier="test-guardrail", guardrailVersion="DRAFT" + ) + + # Test with null assessments + response_null_assessments = { + "action": "GUARDRAIL_INTERVENED", + "assessments": None, + } + assert guardrail._should_raise_guardrail_blocked_exception(response_null_assessments) is False + + # Test with null topics in topicPolicy + response_null_topics = { + "action": "GUARDRAIL_INTERVENED", + "assessments": [{"topicPolicy": {"topics": None}}], + } + assert guardrail._should_raise_guardrail_blocked_exception(response_null_topics) is False + + # Test with null filters in contentPolicy + response_null_filters = { + "action": "GUARDRAIL_INTERVENED", + "assessments": [{"contentPolicy": {"filters": None}}], + } + assert guardrail._should_raise_guardrail_blocked_exception(response_null_filters) is False + + # Test with null customWords and managedWordLists in wordPolicy + response_null_words = { + "action": "GUARDRAIL_INTERVENED", + "assessments": [{"wordPolicy": {"customWords": None, "managedWordLists": None}}], + } + assert guardrail._should_raise_guardrail_blocked_exception(response_null_words) is False + + # Test with null piiEntities and regexes in sensitiveInformationPolicy + response_null_pii = { + "action": "GUARDRAIL_INTERVENED", + "assessments": [{"sensitiveInformationPolicy": {"piiEntities": None, "regexes": None}}], + } + assert guardrail._should_raise_guardrail_blocked_exception(response_null_pii) is False + + # Test with null filters in contextualGroundingPolicy + response_null_grounding = { + "action": "GUARDRAIL_INTERVENED", + "assessments": [{"contextualGroundingPolicy": {"filters": None}}], + } + assert guardrail._should_raise_guardrail_blocked_exception(response_null_grounding) is False diff --git a/tests/llm_translation/test_azure_openai.py b/tests/llm_translation/test_azure_openai.py index 6ee740b0f7..4f12e12700 100644 --- a/tests/llm_translation/test_azure_openai.py +++ b/tests/llm_translation/test_azure_openai.py @@ -5,6 +5,7 @@ sys.path.insert( 0, os.path.abspath("../../") ) # Adds the parent directory to the system path +import httpx import pytest from litellm.llms.azure.common_utils import process_azure_headers from httpx import Headers diff --git a/tests/llm_translation/test_optional_params.py b/tests/llm_translation/test_optional_params.py index b14b25f384..82a3d96b02 100644 --- a/tests/llm_translation/test_optional_params.py +++ b/tests/llm_translation/test_optional_params.py @@ -144,6 +144,45 @@ def test_get_optional_params_with_allowed_openai_params(): assert optional_params["reasoning_effort"] == reasoning_effort +def test_allowed_openai_params_does_not_forward_unset_params(): + """ + Regression test for https://github.com/BerriAI/litellm/issues/25697 + + When a user lists a param in ``allowed_openai_params`` but does not + actually send that param in the request, litellm must not forward it + to the provider SDK as ``None``. The openai SDK rejects unknown + top-level kwargs with + ``AsyncCompletions.create() got an unexpected keyword argument 'enable_thinking'``. + + Reproduces the reported config where the user listed both + ``chat_template_kwargs`` and ``enable_thinking`` in + ``allowed_openai_params`` and only sent ``chat_template_kwargs`` + (with ``enable_thinking`` nested inside it). Previously the loop + added ``optional_params["enable_thinking"] = None`` which then + crashed the openai client. + """ + from litellm.utils import _apply_openai_param_overrides + + chat_template_kwargs = {"enable_thinking": False} + optional_params: dict = {} + non_default_params = {"chat_template_kwargs": chat_template_kwargs} + + result = _apply_openai_param_overrides( + optional_params=optional_params, + non_default_params=non_default_params, + allowed_openai_params=["chat_template_kwargs", "enable_thinking"], + ) + + assert result["chat_template_kwargs"] == chat_template_kwargs + # enable_thinking was NOT sent as a top-level param — it must not be + # forwarded to the provider SDK (openai AsyncCompletions.create would + # reject an unknown kwarg, even if its value is None). + assert "enable_thinking" not in result + # And the only entry actually moved out of non_default_params is + # the one the caller sent. + assert "chat_template_kwargs" not in non_default_params + + def test_bedrock_optional_params_embeddings(): litellm.drop_params = True optional_params = get_optional_params_embeddings( diff --git a/tests/local_testing/test_azure_anthropic_sync_post.py b/tests/local_testing/test_azure_anthropic_sync_post.py new file mode 100644 index 0000000000..14558d3bff --- /dev/null +++ b/tests/local_testing/test_azure_anthropic_sync_post.py @@ -0,0 +1,46 @@ +""" +``_get_httpx_client`` + ``HTTPHandler.post`` (same pattern as Azure Anthropic sync path: +``_get_httpx_client(params={"timeout": ...})`` then ``post(..., timeout=...)``). + +Uses https://httpbin.org/delay/10 with ``timeout=5`` — the handler must raise :class:`~litellm.exceptions.Timeout` +before the 10s delay completes. Skips if httpbin is unreachable. + +Lives under ``local_testing`` (not ``make test-unit``). +""" + +import json +import os +import sys + +import httpx +import pytest + +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +) + +from litellm.exceptions import Timeout as LitellmTimeout +from litellm.llms.custom_httpx.http_handler import _get_httpx_client + +_HTTPBIN_DELAY_S = 10 +_PER_REQUEST_TIMEOUT_S = 5.0 +_CLIENT_DEFAULT_TIMEOUT_S = 60.0 + + +def test_post_delay_exceeds_per_request_timeout_raises(): + try: + httpx.get("https://httpbin.org/get", timeout=5.0) + except Exception as e: + pytest.skip(f"httpbin.org unreachable: {e}") + + handler = _get_httpx_client(params={"timeout": _CLIENT_DEFAULT_TIMEOUT_S}) + try: + with pytest.raises(LitellmTimeout): + handler.post( + f"https://httpbin.org/delay/{_HTTPBIN_DELAY_S}", + headers={"content-type": "application/json"}, + data=json.dumps({"model": "claude", "messages": []}), + timeout=_PER_REQUEST_TIMEOUT_S, + ) + finally: + handler.close() diff --git a/tests/local_testing/test_unit_test_caching.py b/tests/local_testing/test_unit_test_caching.py index fa5cf80254..e25b75e658 100644 --- a/tests/local_testing/test_unit_test_caching.py +++ b/tests/local_testing/test_unit_test_caching.py @@ -131,6 +131,55 @@ def test_get_cache_key_text_completion(): assert cache_key_2 == cache_key_3 +def test_get_cache_key_responses_api(): + """ + Regression test: two /v1/responses calls that differ only in + `instructions` (or any Responses-API-only param) must produce + different cache keys. Mirrors the chat / embedding / text-completion + cache-key tests above. + """ + cache = Cache() + + base_kwargs = { + "model": "openai/gpt-4.1", + "input": [{"role": "user", "content": "what is the weather"}], + "temperature": 0.3, + } + + kwargs_a = { + **base_kwargs, + "instructions": "summarize the weather on 10th May", + } + kwargs_b = { + **base_kwargs, + "instructions": "summarize the weather on 7th May", + } + + key_a = cache.get_cache_key(**kwargs_a) + key_b = cache.get_cache_key(**kwargs_b) + + assert isinstance(key_a, str) and len(key_a) > 0 + assert key_a != key_b, "instructions must be part of the Responses API cache key" + + # Sanity: identical payloads must still collide (cache hits still work) + key_a_again = cache.get_cache_key(**kwargs_a) + assert key_a == key_a_again + + # Spot-check a handful of other Responses-only params individually. + for param, value_x, value_y in [ + ("previous_response_id", "resp_aaa", "resp_bbb"), + ("reasoning", {"effort": "low"}, {"effort": "high"}), + ("include", ["reasoning.encrypted_content"], []), + ("max_output_tokens", 100, 500), + ("background", True, False), + ]: + kx = {**base_kwargs, param: value_x} + ky = {**base_kwargs, param: value_y} + assert cache.get_cache_key(**kx) != cache.get_cache_key( + **ky + ), f"Responses-API param `{param}` is not part of the cache key" + + def test_get_hashed_cache_key(): cache = Cache() cache_key = "model:gpt-3.5-turbo,messages:Hello world" diff --git a/tests/logging_callback_tests/test_standard_logging_payload.py b/tests/logging_callback_tests/test_standard_logging_payload.py index 163e2d9435..4c3fcdc8fc 100644 --- a/tests/logging_callback_tests/test_standard_logging_payload.py +++ b/tests/logging_callback_tests/test_standard_logging_payload.py @@ -158,12 +158,15 @@ def test_get_additional_headers(): additional_logging_headers = StandardLoggingPayloadSetup.get_additional_headers( additional_headers ) - assert additional_logging_headers == { - "x_ratelimit_limit_requests": 2000, - "x_ratelimit_remaining_requests": 1999, - "x_ratelimit_limit_tokens": 160000, - "x_ratelimit_remaining_tokens": 160000, - } + # Typed rate-limit fields are coerced to int + assert additional_logging_headers is not None + assert additional_logging_headers.get("x_ratelimit_limit_requests") == 2000 + assert additional_logging_headers.get("x_ratelimit_remaining_requests") == 1999 + assert additional_logging_headers.get("x_ratelimit_limit_tokens") == 160000 + assert additional_logging_headers.get("x_ratelimit_remaining_tokens") == 160000 + # Provider-specific headers are preserved verbatim (not dropped) + assert additional_logging_headers.get("llm_provider-request-id") == "req_01F6CycZZPSHKRCCctcS1Vto" + assert additional_logging_headers.get("llm_provider-anthropic-ratelimit-requests-reset") == "2024-10-29T23:57:40Z" def all_fields_present(standard_logging_metadata: StandardLoggingMetadata): diff --git a/tests/mcp_tests/test_mcp_server.py b/tests/mcp_tests/test_mcp_server.py index a4a28215e1..6af0758579 100644 --- a/tests/mcp_tests/test_mcp_server.py +++ b/tests/mcp_tests/test_mcp_server.py @@ -417,17 +417,22 @@ async def test_streamable_http_mcp_handler_mock(): # Mock extract_mcp_auth_context to bypass auth checks in the handler mock_auth_context = (None, None, None, {}, {}, {}) - with patch( - "litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED", - True, - ), patch( - "litellm.proxy._experimental.mcp_server.server.session_manager", - mock_session_manager, - ), patch( - "litellm.proxy._experimental.mcp_server.server.extract_mcp_auth_context", - AsyncMock(return_value=mock_auth_context), - ), patch( - "litellm.proxy._experimental.mcp_server.server.set_auth_context", + with ( + patch( + "litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED", + True, + ), + patch( + "litellm.proxy._experimental.mcp_server.server.session_manager", + mock_session_manager, + ), + patch( + "litellm.proxy._experimental.mcp_server.server.extract_mcp_auth_context", + AsyncMock(return_value=mock_auth_context), + ), + patch( + "litellm.proxy._experimental.mcp_server.server.set_auth_context", + ), ): from litellm.proxy._experimental.mcp_server.server import ( handle_streamable_http_mcp, @@ -471,17 +476,22 @@ async def test_sse_mcp_handler_mock(): [], ) - with patch( - "litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED", - True, - ), patch( - "litellm.proxy._experimental.mcp_server.server.sse_session_manager", - mock_sse_session_manager, - ), patch( - "litellm.proxy._experimental.mcp_server.server.extract_mcp_auth_context", - new=AsyncMock(return_value=mock_auth_result), - ), patch( - "litellm.proxy._experimental.mcp_server.server.set_auth_context", + with ( + patch( + "litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED", + True, + ), + patch( + "litellm.proxy._experimental.mcp_server.server.sse_session_manager", + mock_sse_session_manager, + ), + patch( + "litellm.proxy._experimental.mcp_server.server.extract_mcp_auth_context", + new=AsyncMock(return_value=mock_auth_result), + ), + patch( + "litellm.proxy._experimental.mcp_server.server.set_auth_context", + ), ): from litellm.proxy._experimental.mcp_server.server import handle_sse_mcp @@ -833,7 +843,9 @@ async def test_get_tools_from_mcp_servers(): mock_manager.get_allowed_mcp_servers = AsyncMock( return_value=["server1_id", "server2_id"] ) - mock_manager.get_mcp_server_by_id = lambda server_id: mock_server_1 if server_id == "server1_id" else mock_server_2 + mock_manager.get_mcp_server_by_id = lambda server_id: ( + mock_server_1 if server_id == "server1_id" else mock_server_2 + ) mock_manager._get_tools_from_server = AsyncMock(return_value=[mock_tool_1]) # Mock filter_server_ids_by_ip_with_info to return input unchanged (no IP filtering in test) mock_manager.filter_server_ids_by_ip_with_info = MagicMock( @@ -859,7 +871,10 @@ async def test_get_tools_from_mcp_servers(): mock_manager_2.get_allowed_mcp_servers = AsyncMock( return_value=["server1_id", "server2_id"] ) - mock_manager_2.get_mcp_server_by_id = lambda server_id: mock_server_1 if server_id == "server1_id" else mock_server_2 + mock_manager_2.get_mcp_server_by_id = lambda server_id: ( + mock_server_1 if server_id == "server1_id" else mock_server_2 + ) + async def mock_get_tools_side_effect( server, mcp_auth_header=None, @@ -900,7 +915,11 @@ async def test_get_tools_from_mcp_servers(): mock_manager.get_allowed_mcp_servers = AsyncMock( return_value=["server1_id", "server2_id", "server3_id"] ) - mock_manager.get_mcp_server_by_id = lambda server_id: mock_server_1 if server_id == "server1_id" else (mock_server_2 if server_id == "server2_id" else mock_server_3) + mock_manager.get_mcp_server_by_id = lambda server_id: ( + mock_server_1 + if server_id == "server1_id" + else (mock_server_2 if server_id == "server2_id" else mock_server_3) + ) mock_manager._get_tools_from_server = AsyncMock(return_value=[mock_tool_1]) # Mock filter_server_ids_by_ip_with_info to return input unchanged (no IP filtering in test) mock_manager.filter_server_ids_by_ip_with_info = MagicMock( @@ -1050,15 +1069,15 @@ async def test_mcp_server_manager_access_groups_from_config(): # Should find config_server for group-a, both for group-b, other_server for group-c import asyncio - server_ids_a = await MCPRequestHandler._get_mcp_servers_from_access_groups([ - "group-a" - ]) - server_ids_b = await MCPRequestHandler._get_mcp_servers_from_access_groups([ - "group-b" - ]) - server_ids_c = await MCPRequestHandler._get_mcp_servers_from_access_groups([ - "group-c" - ]) + server_ids_a = await MCPRequestHandler._get_mcp_servers_from_access_groups( + ["group-a"] + ) + server_ids_b = await MCPRequestHandler._get_mcp_servers_from_access_groups( + ["group-b"] + ) + server_ids_c = await MCPRequestHandler._get_mcp_servers_from_access_groups( + ["group-c"] + ) assert any(config_server.server_id == sid for sid in server_ids_a) assert set(server_ids_b) == set( [ @@ -1474,6 +1493,7 @@ async def test_add_update_server_with_alias(): mock_mcp_server.byok_api_key_help_url = None mock_mcp_server.created_at = None mock_mcp_server.updated_at = None + mock_mcp_server.instructions = None # Add server to manager await test_manager.add_server(mock_mcp_server) @@ -1530,6 +1550,7 @@ async def test_add_update_server_without_alias(): mock_mcp_server.byok_api_key_help_url = None mock_mcp_server.created_at = None mock_mcp_server.updated_at = None + mock_mcp_server.instructions = None # Add server to manager await test_manager.add_server(mock_mcp_server) @@ -1587,7 +1608,7 @@ async def test_add_update_server_fallback_to_server_id(): mock_mcp_server.byok_api_key_help_url = None mock_mcp_server.created_at = None mock_mcp_server.updated_at = None - + mock_mcp_server.instructions = None # Add server to manager await test_manager.add_server(mock_mcp_server) @@ -2151,8 +2172,12 @@ async def test_list_tool_rest_api_all_servers_with_auth(): for call_args in mock_get_tools.call_args_list } - assert server_auth_map.get(mock_zapier_server) == "Bearer zapier_token" - assert server_auth_map.get(mock_slack_server) == "Bearer slack_token" + assert ( + server_auth_map.get(mock_zapier_server) == "Bearer zapier_token" + ) + assert ( + server_auth_map.get(mock_slack_server) == "Bearer slack_token" + ) @pytest.mark.asyncio @@ -2690,26 +2715,33 @@ async def test_call_mcp_tool_uses_manager_permission_lookup(): expected_response = [TextContent(type="text", text="ok")] - with patch.object( - global_mcp_server_manager, - "get_allowed_mcp_servers", - new_callable=AsyncMock, - ) as mock_get_allowed, patch.object( - global_mcp_server_manager, - "get_mcp_server_by_id", - return_value=mock_server, - ), patch.object( - global_mcp_server_manager, - "_get_mcp_server_from_tool_name", - return_value=mock_server, - ) as mock_get_server, patch( - "litellm.proxy._experimental.mcp_server.server.global_mcp_tool_registry" - ) as mock_tool_registry, patch( - "litellm.proxy._experimental.mcp_server.server._handle_managed_mcp_tool", - new_callable=AsyncMock, - ) as mock_handle_managed, patch( - "litellm.proxy._experimental.mcp_server.server.MCPRequestHandler.is_tool_allowed", - return_value=True, + with ( + patch.object( + global_mcp_server_manager, + "get_allowed_mcp_servers", + new_callable=AsyncMock, + ) as mock_get_allowed, + patch.object( + global_mcp_server_manager, + "get_mcp_server_by_id", + return_value=mock_server, + ), + patch.object( + global_mcp_server_manager, + "_get_mcp_server_from_tool_name", + return_value=mock_server, + ) as mock_get_server, + patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_tool_registry" + ) as mock_tool_registry, + patch( + "litellm.proxy._experimental.mcp_server.server._handle_managed_mcp_tool", + new_callable=AsyncMock, + ) as mock_handle_managed, + patch( + "litellm.proxy._experimental.mcp_server.server.MCPRequestHandler.is_tool_allowed", + return_value=True, + ), ): mock_get_allowed.return_value = [mock_server.server_id] mock_tool_registry.get_tool.return_value = None @@ -2759,27 +2791,34 @@ async def test_call_mcp_tool_resolves_unprefixed_tool_name_and_checks_permission expected_response = [TextContent(type="text", text="ok")] - with patch.object( - global_mcp_server_manager, - "get_allowed_mcp_servers", - new_callable=AsyncMock, - ) as mock_get_allowed, patch.object( - global_mcp_server_manager, - "get_mcp_server_by_id", - return_value=mock_server, - ), patch.object( - global_mcp_server_manager, - "_get_mcp_server_from_tool_name", - return_value=mock_server, - ) as mock_get_server, patch( - "litellm.proxy._experimental.mcp_server.server.global_mcp_tool_registry" - ) as mock_tool_registry, patch( - "litellm.proxy._experimental.mcp_server.server._handle_managed_mcp_tool", - new_callable=AsyncMock, - ) as mock_handle_managed, patch( - "litellm.proxy._experimental.mcp_server.server.MCPRequestHandler.is_tool_allowed", - return_value=True, - ) as mock_is_allowed: + with ( + patch.object( + global_mcp_server_manager, + "get_allowed_mcp_servers", + new_callable=AsyncMock, + ) as mock_get_allowed, + patch.object( + global_mcp_server_manager, + "get_mcp_server_by_id", + return_value=mock_server, + ), + patch.object( + global_mcp_server_manager, + "_get_mcp_server_from_tool_name", + return_value=mock_server, + ) as mock_get_server, + patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_tool_registry" + ) as mock_tool_registry, + patch( + "litellm.proxy._experimental.mcp_server.server._handle_managed_mcp_tool", + new_callable=AsyncMock, + ) as mock_handle_managed, + patch( + "litellm.proxy._experimental.mcp_server.server.MCPRequestHandler.is_tool_allowed", + return_value=True, + ) as mock_is_allowed, + ): mock_get_allowed.return_value = [mock_server.server_id] mock_tool_registry.get_tool.return_value = None mock_handle_managed.return_value = expected_response diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py index 09f6a85938..9f5f14457e 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -10,7 +10,7 @@ import pytest from fastapi import Request from starlette.datastructures import State -from litellm.proxy.utils import _get_docs_url, _get_redoc_url +from litellm.proxy.utils import _get_docs_url, _get_openapi_url, _get_redoc_url sys.path.insert( 0, os.path.abspath("../..") @@ -735,6 +735,30 @@ def test_get_docs_url(env_vars, expected_url): result = _get_docs_url() assert result == expected_url +@pytest.mark.parametrize( + "env_vars, expected_url", + [ + ({}, "/openapi.json"), # default case + ({"OPENAPI_URL": "/custom-openapi.json"}, "/custom-openapi.json"), # custom URL + ( + {"OPENAPI_URL": "https://example.com/openapi.json"}, + "https://example.com/openapi.json", + ), # full URL + ({"NO_OPENAPI": "True"}, None), # openapi disabled + ], +) +def test_get_openapi_url(env_vars, expected_url): + # Clear relevant environment variables + for key in ["OPENAPI_URL", "NO_OPENAPI"]: + os.environ.pop(key, None) + + # Set test environment variables + for key, value in env_vars.items(): + os.environ[key] = value + + result = _get_openapi_url() + assert result == expected_url + @pytest.mark.parametrize( "request_tags, tags_to_add, expected_tags", diff --git a/tests/test_litellm/experimental_mcp_client/test_mcp_client.py b/tests/test_litellm/experimental_mcp_client/test_mcp_client.py index 13a09f54e6..dee689708c 100644 --- a/tests/test_litellm/experimental_mcp_client/test_mcp_client.py +++ b/tests/test_litellm/experimental_mcp_client/test_mcp_client.py @@ -40,6 +40,7 @@ class TestMCPClient: with pytest.raises( ValueError, match="stdio_config is required for stdio transport" ): + async def _noop(session): return None @@ -251,11 +252,11 @@ class TestMCPClient: server_url="http://example.com/sse", transport_type="sse", auth_type=MCPAuth.token, - auth_value="my-secret-token" + auth_value="my-secret-token", ) - + headers = client._get_auth_headers() - + assert "Authorization" in headers assert headers["Authorization"] == "token my-secret-token" @@ -266,27 +267,27 @@ class TestMCPClient: server_url="http://example.com/sse", transport_type="sse", auth_type=MCPAuth.bearer_token, - auth_value="bearer-token" + auth_value="bearer-token", ) headers = client._get_auth_headers() assert headers["Authorization"] == "Bearer bearer-token" - + # Test API key client = MCPClient( server_url="http://example.com/sse", transport_type="sse", auth_type=MCPAuth.api_key, - auth_value="api-key" + auth_value="api-key", ) headers = client._get_auth_headers() assert headers["X-API-Key"] == "api-key" - + # Test basic auth (gets base64 encoded) client = MCPClient( server_url="http://example.com/sse", transport_type="sse", auth_type=MCPAuth.basic, - auth_value="user:pass" + auth_value="user:pass", ) headers = client._get_auth_headers() assert headers["Authorization"].startswith("Basic ") @@ -298,11 +299,11 @@ class TestMCPClient: transport_type="sse", auth_type=MCPAuth.token, auth_value="my-token", - extra_headers={"X-Custom-Header": "custom-value"} + extra_headers={"X-Custom-Header": "custom-value"}, ) - + headers = client._get_auth_headers() - + assert headers["Authorization"] == "token my-token" assert headers["X-Custom-Header"] == "custom-value" @@ -312,5 +313,80 @@ class TestMCPClient: assert MCPAuth.token.value == "token" +# --------------------------------------------------------------------------- +# _last_initialize_instructions capture +# --------------------------------------------------------------------------- + + +class TestMCPClientInstructionsCapture: + """Tests for _last_initialize_instructions capture during session init.""" + + def test_initial_value_is_none(self): + """Fresh client has no cached instructions.""" + client = MCPClient( + server_url="http://example.com/mcp", + transport_type="http", + ) + assert client._last_initialize_instructions is None + + @pytest.mark.asyncio + @patch("litellm.experimental_mcp_client.client.ClientSession") + async def test_captures_instructions_from_initialize(self, mock_session_cls): + """Instructions from upstream initialize() are captured and stripped.""" + client = MCPClient( + server_url="http://example.com/mcp", + transport_type="http", + ) + + mock_session = AsyncMock() + init_result = MagicMock() + init_result.instructions = " upstream says hello " + mock_session.initialize = AsyncMock(return_value=init_result) + + session_ctx = MagicMock() + session_ctx.__aenter__ = AsyncMock(return_value=mock_session) + session_ctx.__aexit__ = AsyncMock(return_value=False) + mock_session_cls.return_value = session_ctx + + transport_ctx = MagicMock() + transport_ctx.__aenter__ = AsyncMock(return_value=(MagicMock(), MagicMock())) + transport_ctx.__aexit__ = AsyncMock(return_value=False) + + async def _op(session): + return "done" + + await client._execute_session_operation(transport_ctx, _op) + assert client._last_initialize_instructions == "upstream says hello" + + @pytest.mark.asyncio + @patch("litellm.experimental_mcp_client.client.ClientSession") + async def test_none_instructions_stays_none(self, mock_session_cls): + """When upstream returns no instructions the field stays None.""" + client = MCPClient( + server_url="http://example.com/mcp", + transport_type="http", + ) + + mock_session = AsyncMock() + init_result = MagicMock() + init_result.instructions = None + mock_session.initialize = AsyncMock(return_value=init_result) + + session_ctx = MagicMock() + session_ctx.__aenter__ = AsyncMock(return_value=mock_session) + session_ctx.__aexit__ = AsyncMock(return_value=False) + mock_session_cls.return_value = session_ctx + + transport_ctx = MagicMock() + transport_ctx.__aenter__ = AsyncMock(return_value=(MagicMock(), MagicMock())) + transport_ctx.__aexit__ = AsyncMock(return_value=False) + + async def _op(session): + return "done" + + await client._execute_session_operation(transport_ctx, _op) + assert client._last_initialize_instructions is None + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/test_litellm/integrations/test_prometheus_labels.py b/tests/test_litellm/integrations/test_prometheus_labels.py index 2553eb0627..69127e15b8 100644 --- a/tests/test_litellm/integrations/test_prometheus_labels.py +++ b/tests/test_litellm/integrations/test_prometheus_labels.py @@ -69,6 +69,21 @@ def test_model_id_in_required_metrics(): print(f"✅ {metric_name} contains model_id label") +def test_api_provider_in_spend_and_requests_metrics(): + """ + Test that api_provider label is present in spend and requests metrics + so users can build spend-by-provider and request-count-by-provider dashboards. + """ + api_provider_label = UserAPIKeyLabelNames.API_PROVIDER.value + + for metric_name in ["litellm_spend_metric", "litellm_requests_metric"]: + labels = PrometheusMetricLabels.get_labels(metric_name) + assert ( + api_provider_label in labels + ), f"Metric {metric_name} should contain api_provider label" + print(f"✅ {metric_name} contains api_provider label") + + def test_user_email_label_exists(): """Test that the USER_EMAIL label is properly defined""" assert UserAPIKeyLabelNames.USER_EMAIL.value == "user_email" diff --git a/tests/test_litellm/litellm_core_utils/test_litellm_logging.py b/tests/test_litellm/litellm_core_utils/test_litellm_logging.py index ddc44cb505..c3849e5869 100644 --- a/tests/test_litellm/litellm_core_utils/test_litellm_logging.py +++ b/tests/test_litellm/litellm_core_utils/test_litellm_logging.py @@ -11,8 +11,7 @@ sys.path.insert( import time from litellm.constants import SENTRY_DENYLIST, SENTRY_PII_DENYLIST -from litellm.litellm_core_utils.litellm_logging import \ - Logging as LitellmLogging +from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging from litellm.litellm_core_utils.litellm_logging import set_callbacks from litellm.types.utils import ModelResponse, TextCompletionResponse @@ -140,8 +139,7 @@ def test_sentry_environment(): def test_use_custom_pricing_for_model(): - from litellm.litellm_core_utils.litellm_logging import \ - use_custom_pricing_for_model + from litellm.litellm_core_utils.litellm_logging import use_custom_pricing_for_model litellm_params = { "custom_llm_provider": "azure", @@ -156,8 +154,7 @@ def test_use_custom_pricing_for_model_via_litellm_metadata(): Generic API call routes (/messages, /responses) store model_info under litellm_metadata, not metadata. Regression test for #23185. """ - from litellm.litellm_core_utils.litellm_logging import \ - use_custom_pricing_for_model + from litellm.litellm_core_utils.litellm_logging import use_custom_pricing_for_model litellm_params = { "litellm_metadata": { @@ -173,8 +170,7 @@ def test_use_custom_pricing_for_model_via_litellm_metadata(): def test_use_custom_pricing_not_detected_litellm_metadata_no_pricing(): """Should return False when litellm_metadata.model_info has no pricing keys.""" - from litellm.litellm_core_utils.litellm_logging import \ - use_custom_pricing_for_model + from litellm.litellm_core_utils.litellm_logging import use_custom_pricing_for_model litellm_params = { "litellm_metadata": { @@ -190,8 +186,7 @@ def test_response_cost_calculator_uses_router_model_id_from_litellm_metadata(): does not carry _hidden_params (e.g. ResponsesAPIResponse from /v1/responses streaming). Regression test for custom pricing on streaming responses.""" import litellm - from litellm.litellm_core_utils.litellm_logging import \ - Logging as LiteLLMLoggingObj + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.types.llms.openai import ResponsesAPIResponse custom_model_id = "gpt-5-custom-pricing" @@ -301,8 +296,9 @@ class TestGetRouterModelId: def test_returns_none_when_no_litellm_params(self): """Should return None when litellm_params is not set.""" - from litellm.litellm_core_utils.litellm_logging import \ - Logging as LiteLLMLoggingObj + from litellm.litellm_core_utils.litellm_logging import ( + Logging as LiteLLMLoggingObj, + ) obj = LiteLLMLoggingObj( model="test", @@ -326,10 +322,12 @@ class TestAnthropicPassthroughCustomPricing: when the logging object carries custom pricing in model_info.""" from unittest.mock import patch - from litellm.litellm_core_utils.litellm_logging import \ - Logging as LiteLLMLoggingObj - from litellm.proxy.pass_through_endpoints.llm_provider_handlers.anthropic_passthrough_logging_handler import \ - AnthropicPassthroughLoggingHandler + from litellm.litellm_core_utils.litellm_logging import ( + Logging as LiteLLMLoggingObj, + ) + from litellm.proxy.pass_through_endpoints.llm_provider_handlers.anthropic_passthrough_logging_handler import ( + AnthropicPassthroughLoggingHandler, + ) logging_obj = LiteLLMLoggingObj( model="claude-sonnet-4-20250514", @@ -438,7 +436,10 @@ class TestUpdateFromKwargs: ) # kwargs metadata is preserved, caller metadata is merged in - assert logging_obj.litellm_params["metadata"] == {"from_kwargs": True, "from_caller": True} + assert logging_obj.litellm_params["metadata"] == { + "from_kwargs": True, + "from_caller": True, + } def test_kwargs_metadata_wins_over_caller_metadata_in_conflict(self, logging_obj): """kwargs metadata takes precedence; caller litellm_params metadata is merged without overwriting.""" @@ -446,7 +447,10 @@ class TestUpdateFromKwargs: logging_obj.update_from_kwargs( kwargs=kwargs, - litellm_params={"metadata": {"from_caller": True, "shared_key": "caller_value"}, "litellm_call_id": "x"}, + litellm_params={ + "metadata": {"from_caller": True, "shared_key": "caller_value"}, + "litellm_call_id": "x", + }, ) # kwargs metadata is preserved (shared_key keeps the kwargs value), caller-only keys are added @@ -458,8 +462,9 @@ class TestUpdateFromKwargs: def test_custom_pricing_detected_via_litellm_metadata(self, logging_obj): """Custom pricing in litellm_metadata.model_info should set custom_pricing flag.""" - from litellm.litellm_core_utils.litellm_logging import \ - use_custom_pricing_for_model + from litellm.litellm_core_utils.litellm_logging import ( + use_custom_pricing_for_model, + ) lm_meta = { "model_info": { @@ -518,8 +523,7 @@ async def test_datadog_logger_not_shadowed_by_llm_obs(monkeypatch): monkeypatch.setenv("DD_SITE", "us5.datadoghq.com") from litellm.integrations.datadog.datadog import DataDogLogger - from litellm.integrations.datadog.datadog_llm_obs import \ - DataDogLLMObsLogger + from litellm.integrations.datadog.datadog_llm_obs import DataDogLLMObsLogger from litellm.litellm_core_utils import litellm_logging as logging_module logging_module._in_memory_loggers.clear() @@ -560,8 +564,7 @@ async def test_logfire_logger_accepts_env_vars_for_base_url(monkeypatch): ) # no trailing slash on purpose # Import after env vars are set (important if module-level caching exists) - from litellm.integrations.opentelemetry import \ - OpenTelemetry # logger class + from litellm.integrations.opentelemetry import OpenTelemetry # logger class from litellm.litellm_core_utils import litellm_logging as logging_module logging_module._in_memory_loggers.clear() @@ -890,8 +893,7 @@ def test_success_handler_runs_guardrail_logging_hook_when_enabled(logging_obj): def test_get_user_agent_tags(): - from litellm.litellm_core_utils.litellm_logging import \ - StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup tags = StandardLoggingPayloadSetup._get_user_agent_tags( proxy_server_request={ @@ -906,8 +908,7 @@ def test_get_user_agent_tags(): def test_get_request_tags(): - from litellm.litellm_core_utils.litellm_logging import \ - StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup tags = StandardLoggingPayloadSetup._get_request_tags( litellm_params={"metadata": {"tags": ["test-tag"]}}, @@ -934,8 +935,7 @@ def test_get_request_tags_from_metadata_and_litellm_metadata(): 4. No tags in either 5. None values for metadata/litellm_metadata """ - from litellm.litellm_core_utils.litellm_logging import \ - StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup # Test case 1: Tags in metadata only tags = StandardLoggingPayloadSetup._get_request_tags( @@ -1016,8 +1016,7 @@ def test_get_request_tags_does_not_mutate_original_tags(): would cause User-Agent tags to be duplicated because the function was mutating the original tags list instead of creating a copy. """ - from litellm.litellm_core_utils.litellm_logging import \ - StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup # Create metadata with original tags original_tags = ["custom-tag-1", "custom-tag-2"] @@ -1077,8 +1076,7 @@ def test_get_request_tags_does_not_mutate_original_tags(): def test_get_extra_header_tags(): """Test the _get_extra_header_tags method with various scenarios.""" import litellm - from litellm.litellm_core_utils.litellm_logging import \ - StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup # Store original value to restore later original_extra_headers = getattr(litellm, "extra_spend_tag_headers", None) @@ -1299,17 +1297,17 @@ async def test_e2e_generate_cold_storage_object_key_successful(): from datetime import datetime, timezone from unittest.mock import patch - from litellm.litellm_core_utils.litellm_logging import \ - StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup # Create test data start_time = datetime(2025, 1, 15, 10, 30, 45, 123456, timezone.utc) response_id = "chatcmpl-test-12345" team_alias = "test-team" - with patch("litellm.cold_storage_custom_logger", return_value="s3"), patch( - "litellm.integrations.s3.get_s3_object_key" - ) as mock_get_s3_key: + with ( + patch("litellm.cold_storage_custom_logger", return_value="s3"), + patch("litellm.integrations.s3.get_s3_object_key") as mock_get_s3_key, + ): # Mock the S3 object key generation to return a predictable result mock_get_s3_key.return_value = ( "2025-01-15/time-10-30-45-123456_chatcmpl-test-12345.json" @@ -1342,8 +1340,7 @@ async def test_e2e_generate_cold_storage_object_key_with_custom_logger_s3_path() from datetime import datetime, timezone from unittest.mock import MagicMock, patch - from litellm.litellm_core_utils.litellm_logging import \ - StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup # Create test data start_time = datetime(2025, 1, 15, 10, 30, 45, 123456, timezone.utc) @@ -1353,11 +1350,13 @@ async def test_e2e_generate_cold_storage_object_key_with_custom_logger_s3_path() mock_custom_logger = MagicMock() mock_custom_logger.s3_path = "storage" - with patch("litellm.cold_storage_custom_logger", "s3_v2"), patch( - "litellm.logging_callback_manager.get_active_custom_logger_for_callback_name" - ) as mock_get_logger, patch( - "litellm.integrations.s3.get_s3_object_key" - ) as mock_get_s3_key: + with ( + patch("litellm.cold_storage_custom_logger", "s3_v2"), + patch( + "litellm.logging_callback_manager.get_active_custom_logger_for_callback_name" + ) as mock_get_logger, + patch("litellm.integrations.s3.get_s3_object_key") as mock_get_s3_key, + ): # Setup mocks mock_get_logger.return_value = mock_custom_logger mock_get_s3_key.return_value = ( @@ -1394,8 +1393,7 @@ async def test_e2e_generate_cold_storage_object_key_with_logger_no_s3_path(): from datetime import datetime, timezone from unittest.mock import MagicMock, patch - from litellm.litellm_core_utils.litellm_logging import \ - StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup # Create test data start_time = datetime(2025, 1, 15, 10, 30, 45, 123456, timezone.utc) @@ -1405,11 +1403,13 @@ async def test_e2e_generate_cold_storage_object_key_with_logger_no_s3_path(): mock_custom_logger = MagicMock() mock_custom_logger.s3_path = None # or could be missing attribute - with patch("litellm.cold_storage_custom_logger", "s3_v2"), patch( - "litellm.logging_callback_manager.get_active_custom_logger_for_callback_name" - ) as mock_get_logger, patch( - "litellm.integrations.s3.get_s3_object_key" - ) as mock_get_s3_key: + with ( + patch("litellm.cold_storage_custom_logger", "s3_v2"), + patch( + "litellm.logging_callback_manager.get_active_custom_logger_for_callback_name" + ) as mock_get_logger, + patch("litellm.integrations.s3.get_s3_object_key") as mock_get_s3_key, + ): # Setup mocks mock_get_logger.return_value = mock_custom_logger mock_get_s3_key.return_value = ( @@ -1442,8 +1442,7 @@ async def test_e2e_generate_cold_storage_object_key_not_configured(): from unittest.mock import patch import litellm - from litellm.litellm_core_utils.litellm_logging import \ - StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup # Create test data start_time = datetime(2025, 1, 15, 10, 30, 45, 123456, timezone.utc) @@ -1467,8 +1466,7 @@ def test_get_final_response_obj_with_empty_response_obj_and_list_init(): When response_obj is empty (falsy), the method should return init_response_obj if it's a list. """ - from litellm.litellm_core_utils.litellm_logging import \ - StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup # Create test objects class TestObject1: @@ -1504,8 +1502,7 @@ def test_get_usage_as_dict(): """ Test get_usage_as_dict returns usage as plain dict from response_obj or combined_usage_object. """ - from litellm.litellm_core_utils.litellm_logging import \ - StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup from litellm.types.utils import Usage # Test case 1: None response_obj returns empty usage dict @@ -1543,8 +1540,7 @@ def test_append_system_prompt_messages(): """ Test append_system_prompt_messages prepends system message from kwargs to messages list. """ - from litellm.litellm_core_utils.litellm_logging import \ - StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup # Test case 1: system in kwargs with existing messages kwargs = {"system": "You are a helpful assistant"} @@ -1615,8 +1611,7 @@ async def test_async_success_handler_sets_standard_logging_object_for_pass_throu from datetime import datetime from unittest.mock import patch - from litellm.litellm_core_utils.litellm_logging import \ - Logging as LiteLLMLoggingObj + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.types.utils import StandardPassThroughResponseObject # Create a logging object for a pass-through endpoint @@ -1697,8 +1692,7 @@ async def test_async_success_handler_prevents_reprocessing_for_pass_through_endp from datetime import datetime from unittest.mock import patch - from litellm.litellm_core_utils.litellm_logging import \ - Logging as LiteLLMLoggingObj + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.types.utils import StandardPassThroughResponseObject # Create a logging object for a pass-through endpoint @@ -1774,8 +1768,7 @@ async def test_async_success_handler_sets_standard_logging_object_for_streaming_ from datetime import datetime from unittest.mock import patch - from litellm.litellm_core_utils.litellm_logging import \ - Logging as LiteLLMLoggingObj + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.types.utils import StandardPassThroughResponseObject # Create a logging object for a streaming pass-through endpoint @@ -1831,8 +1824,7 @@ def test_get_error_information_error_code_priority(): Test get_error_information prioritizes 'code' attribute over 'status_code' attribute and handles edge cases like empty strings and "None" string values. """ - from litellm.litellm_core_utils.litellm_logging import \ - StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup # Test case 1: Exception with 'code' attribute (ProxyException style) class ProxyException(Exception): @@ -2025,8 +2017,7 @@ async def test_async_success_handler_preserves_response_cost_for_pass_through_en by pass-through handlers (Gemini/Vertex).""" from datetime import datetime - from litellm.litellm_core_utils.litellm_logging import \ - Logging as LiteLLMLoggingObj + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.types.utils import ModelResponse, Usage logging_obj = LiteLLMLoggingObj( @@ -2366,6 +2357,56 @@ def test_merge_hidden_params_from_response_into_metadata_no_op_when_empty(): _hidden_params = {} logging_obj._merge_hidden_params_from_response_into_metadata(_NoHp()) - assert "hidden_params" not in logging_obj.model_call_details["litellm_params"][ - "metadata" - ] + assert ( + "hidden_params" + not in logging_obj.model_call_details["litellm_params"]["metadata"] + ) + + +# ── StandardLoggingPayloadSetup.get_additional_headers ─────────────────────── + + +def test_get_additional_headers_preserves_provider_request_id(): + """llm_provider-x-request-id must survive the get_additional_headers filter.""" + from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup + + raw = { + "x-ratelimit-remaining-requests": "29999", + "x-ratelimit-remaining-tokens": "149999970", + "llm_provider-x-request-id": "req_85f49b546c7b4d3180755621f36631a1", + "llm_provider-openai-organization": "my-org", + "llm_provider-openai-processing-ms": "649", + } + + result = StandardLoggingPayloadSetup.get_additional_headers(raw) + + assert result is not None + # well-known fields parsed as ints + assert result["x_ratelimit_remaining_requests"] == 29999 # type: ignore + assert result["x_ratelimit_remaining_tokens"] == 149999970 # type: ignore + # provider-specific headers must be preserved verbatim + assert result["llm_provider-x-request-id"] == "req_85f49b546c7b4d3180755621f36631a1" # type: ignore + assert result["llm_provider-openai-organization"] == "my-org" # type: ignore + assert result["llm_provider-openai-processing-ms"] == "649" # type: ignore + + +def test_get_additional_headers_returns_none_for_none_input(): + from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup + + assert StandardLoggingPayloadSetup.get_additional_headers(None) is None + + +def test_get_additional_headers_reset_fields_preserved(): + """x-ratelimit-reset-* fields (added to the TypedDict) must be captured.""" + from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup + + raw = { + "x-ratelimit-reset-requests": "1s", + "x-ratelimit-reset-tokens": "100ms", + } + + result = StandardLoggingPayloadSetup.get_additional_headers(raw) + + assert result is not None + assert result["x_ratelimit_reset_requests"] == "1s" # type: ignore + assert result["x_ratelimit_reset_tokens"] == "100ms" # type: ignore diff --git a/tests/test_litellm/llms/anthropic/test_anthropic_common_utils.py b/tests/test_litellm/llms/anthropic/test_anthropic_common_utils.py index 22470b9354..2f57ce5d18 100644 --- a/tests/test_litellm/llms/anthropic/test_anthropic_common_utils.py +++ b/tests/test_litellm/llms/anthropic/test_anthropic_common_utils.py @@ -440,14 +440,18 @@ class TestProxyOAuthHeaderForwarding: (b"content-type", b"application/json"), ] ) - + # Should preserve OAuth even with flag=False - cleaned_without_flag = clean_headers(raw_headers, forward_llm_provider_auth_headers=False) + cleaned_without_flag = clean_headers( + raw_headers, forward_llm_provider_auth_headers=False + ) assert "authorization" in cleaned_without_flag assert cleaned_without_flag["authorization"] == f"Bearer {FAKE_OAUTH_TOKEN}" - + # Should also preserve OAuth with flag=True - cleaned_with_flag = clean_headers(raw_headers, forward_llm_provider_auth_headers=True) + cleaned_with_flag = clean_headers( + raw_headers, forward_llm_provider_auth_headers=True + ) assert "authorization" in cleaned_with_flag assert cleaned_with_flag["authorization"] == f"Bearer {FAKE_OAUTH_TOKEN}" @@ -867,8 +871,6 @@ class TestValidateEnvironmentAuthToken: assert "authorization" not in headers - - class TestGetAuthToken: """Tests for AnthropicModelInfo.get_auth_token() static method.""" @@ -1092,7 +1094,10 @@ class TestPassthroughAuthToken: config = AnthropicMessagesConfig() with mock_patch.dict( "os.environ", - {"ANTHROPIC_API_KEY": FAKE_REGULAR_KEY, "ANTHROPIC_AUTH_TOKEN": FAKE_AUTH_TOKEN}, + { + "ANTHROPIC_API_KEY": FAKE_REGULAR_KEY, + "ANTHROPIC_AUTH_TOKEN": FAKE_AUTH_TOKEN, + }, clear=True, ): updated_headers, _ = config.validate_anthropic_messages_environment( @@ -1131,3 +1136,147 @@ class TestPassthroughAuthToken: ) assert url == "https://custom.example.com/v1/messages" + + +class TestAnthropicThinkingSignatureSelfHeal: + """Helpers for retrying after invalid encrypted thinking signatures.""" + + def test_is_anthropic_invalid_thinking_signature_error_positive(self): + from litellm.llms.anthropic.common_utils import ( + is_anthropic_invalid_thinking_signature_error, + ) + + raw = ( + '{"type":"error","error":{"type":"invalid_request_error",' + '"message":"messages.3.content.3: Invalid `signature` in `thinking` block"},' + '"request_id":"req_011Ca2EtQDxp7x6RGUY2jVn9"}' + ) + assert is_anthropic_invalid_thinking_signature_error(raw) is True + + def test_is_anthropic_invalid_thinking_signature_error_negative(self): + from litellm.llms.anthropic.common_utils import ( + is_anthropic_invalid_thinking_signature_error, + ) + + assert is_anthropic_invalid_thinking_signature_error("") is False + assert ( + is_anthropic_invalid_thinking_signature_error("rate limit exceeded") + is False + ) + + def test_strip_thinking_blocks_from_anthropic_messages(self): + from litellm.llms.anthropic.common_utils import ( + strip_thinking_blocks_from_anthropic_messages, + ) + + messages = [ + {"role": "user", "content": "hi"}, + { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "plan", "signature": "sig"}, + {"type": "text", "text": "hello"}, + ], + }, + ] + out = strip_thinking_blocks_from_anthropic_messages(messages) + assert len(out) == 2 + assert out[0] == messages[0] + assert len(out[1]["content"]) == 1 + assert out[1]["content"][0]["type"] == "text" + assert messages[1]["content"][0]["type"] == "thinking" + + def test_strip_thinking_blocks_drops_message_when_only_thinking_blocks(self): + from litellm.llms.anthropic.common_utils import ( + strip_thinking_blocks_from_anthropic_messages, + ) + + messages = [ + {"role": "user", "content": "hi"}, + { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "plan", "signature": "sig"}, + ], + }, + ] + out = strip_thinking_blocks_from_anthropic_messages(messages) + assert len(out) == 1 + assert out[0]["role"] == "user" + + def test_strip_thinking_blocks_from_anthropic_messages_request_dict(self): + from litellm.llms.anthropic.common_utils import ( + strip_thinking_blocks_from_anthropic_messages_request_dict, + ) + + data = { + "model": "claude-sonnet-4-20250514", + "messages": [ + { + "role": "assistant", + "content": [ + { + "type": "thinking", + "thinking": "x", + "signature": "y", + }, + ], + } + ], + "thinking": {"type": "enabled", "budget_tokens": 1024}, + } + strip_thinking_blocks_from_anthropic_messages_request_dict(data) + assert "thinking" not in data + assert data["messages"] == [] + + def test_anthropic_messages_config_http_retry_helpers(self): + import httpx + + from litellm.llms.anthropic.experimental_pass_through.messages.transformation import ( + AnthropicMessagesConfig, + ) + + config = AnthropicMessagesConfig() + assert config.max_retry_on_anthropic_messages_http_error == 2 + + req = httpx.Request("POST", "https://api.anthropic.com/v1/messages") + err_text = ( + '{"type":"error","error":{"type":"invalid_request_error",' + '"message":"messages.3.content.3: Invalid `signature` in `thinking` block"},' + '"request_id":"req_011Ca2EtQDxp7x6RGUY2jVn9"}' + ) + resp = httpx.Response(400, request=req, text=err_text) + err = httpx.HTTPStatusError("bad", request=req, response=resp) + assert config.should_retry_anthropic_messages_on_http_error(err, {}) is True + + resp_bad = httpx.Response(400, request=req, text="rate limit exceeded") + err_bad = httpx.HTTPStatusError("bad", request=req, response=resp_bad) + assert ( + config.should_retry_anthropic_messages_on_http_error(err_bad, {}) is False + ) + + resp_500 = httpx.Response(500, request=req, text=err_text) + err_500 = httpx.HTTPStatusError("bad", request=req, response=resp_500) + assert ( + config.should_retry_anthropic_messages_on_http_error(err_500, {}) is False + ) + + data = { + "model": "claude-sonnet-4-20250514", + "messages": [ + { + "role": "assistant", + "content": [ + { + "type": "thinking", + "thinking": "x", + "signature": "y", + }, + ], + } + ], + "thinking": {"type": "enabled", "budget_tokens": 1024}, + } + config.transform_anthropic_messages_request_on_http_error(err, data) + assert "thinking" not in data + assert data["messages"] == [] diff --git a/tests/test_litellm/llms/azure/passthrough/test_azure_passthrough_transformation.py b/tests/test_litellm/llms/azure/passthrough/test_azure_passthrough_transformation.py new file mode 100644 index 0000000000..529a7453d7 --- /dev/null +++ b/tests/test_litellm/llms/azure/passthrough/test_azure_passthrough_transformation.py @@ -0,0 +1,97 @@ +import json +import os +import sys +from unittest.mock import MagicMock + +import httpx + +sys.path.insert(0, os.path.abspath("../../../../..")) + +from litellm.llms.azure.passthrough.transformation import AzurePassthroughConfig +from litellm.types.utils import ModelResponse + + +def _azure_chat_completion_body(): + return { + "id": "chatcmpl-abc123", + "object": "chat.completion", + "created": 1700000000, + "model": "gpt-4.1-mini-2025-04-14", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! How can I assist you today?", + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 8, + "total_tokens": 18, + }, + } + + +def _make_httpx_response(body: dict) -> httpx.Response: + return httpx.Response( + status_code=200, + headers={"content-type": "application/json"}, + content=json.dumps(body).encode("utf-8"), + request=httpx.Request( + "POST", + "https://example.openai.azure.com/openai/deployments/gpt-4.1-mini/chat/completions", + ), + ) + + +def test_azure_passthrough_logging_non_streaming_response_chat_completions(): + """ + Returns a populated ModelResponse (with usage + content) for a chat/completions + endpoint. This is what _success_handler_helper_fn needs to build + standard_logging_object — without it, Datadog/cost-tracking/router-success all + raise on every Azure passthrough request. + """ + config = AzurePassthroughConfig() + logging_obj = MagicMock() + + result = config.logging_non_streaming_response( + model="gpt-4.1-mini", + custom_llm_provider="azure", + httpx_response=_make_httpx_response(_azure_chat_completion_body()), + request_data={ + "model": "gpt-4.1-mini", + "messages": [{"role": "user", "content": "hi"}], + }, + logging_obj=logging_obj, + endpoint="openai/deployments/gpt-4.1-mini/chat/completions", + ) + + assert isinstance(result, ModelResponse) + assert result.choices[0].message.content == "Hello! How can I assist you today?" + assert result.usage.prompt_tokens == 10 + assert result.usage.completion_tokens == 8 + assert result.usage.total_tokens == 18 + + +def test_azure_passthrough_logging_non_streaming_response_unknown_endpoint_returns_none(): + """ + Endpoints other than chat/completions (responses, messages, images) fall + through to None — matches base-class behavior and Bedrock's "unknown + endpoint" handling. Not a regression; just scoping. + """ + config = AzurePassthroughConfig() + logging_obj = MagicMock() + + result = config.logging_non_streaming_response( + model="gpt-4.1-mini", + custom_llm_provider="azure", + httpx_response=_make_httpx_response(_azure_chat_completion_body()), + request_data={}, + logging_obj=logging_obj, + endpoint="openai/responses", + ) + + assert result is None diff --git a/tests/test_litellm/llms/azure_ai/claude/test_azure_anthropic_handler.py b/tests/test_litellm/llms/azure_ai/claude/test_azure_anthropic_handler.py index ddfad420d0..ad86077224 100644 --- a/tests/test_litellm/llms/azure_ai/claude/test_azure_anthropic_handler.py +++ b/tests/test_litellm/llms/azure_ai/claude/test_azure_anthropic_handler.py @@ -222,5 +222,7 @@ class TestAzureAnthropicChatCompletion: # Verify non-streaming was handled mock_client.post.assert_called_once() + mock_get_client.assert_called_once_with(params={"timeout": timeout}) + assert mock_client.post.call_args.kwargs["timeout"] == timeout assert result is not None diff --git a/tests/test_litellm/llms/azure_ai/claude/test_main_azure_anthropic_timeout.py b/tests/test_litellm/llms/azure_ai/claude/test_main_azure_anthropic_timeout.py new file mode 100644 index 0000000000..a94034e510 --- /dev/null +++ b/tests/test_litellm/llms/azure_ai/claude/test_main_azure_anthropic_timeout.py @@ -0,0 +1,42 @@ +""" +Ensure litellm.completion() forwards timeout to Azure Anthropic handler (main.py dispatch). +""" + +import os +import sys +from unittest.mock import MagicMock, patch + +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../..")) +) + +from litellm import completion +from litellm.types.utils import ModelResponse + + +def test_main_azure_ai_claude_completion_passes_timeout_to_azure_anthropic_handler(): + captured: dict = {} + + def fake_azure_anthropic_completion(**kwargs): + captured.update(kwargs) + return ModelResponse() + + with patch( + "litellm.main.azure_anthropic_chat_completions" + ) as mock_azure_anthropic: + mock_azure_anthropic.completion = MagicMock( + side_effect=fake_azure_anthropic_completion + ) + + completion( + model="azure_ai/claude-sonnet-4-5", + messages=[{"role": "user", "content": "hi"}], + api_base="https://example.services.ai.azure.com/anthropic", + api_key="test-key", + timeout=42.5, + ) + + mock_azure_anthropic.completion.assert_called_once() + assert captured["timeout"] == 42.5 + assert captured["model"] == "claude-sonnet-4-5" + assert captured["custom_llm_provider"] == "azure_ai" diff --git a/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py b/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py index 7719f2bc8f..804d5997c4 100644 --- a/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py +++ b/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py @@ -3140,9 +3140,14 @@ def test_add_additional_properties_definitions(): assert result["definitions"]["Item"]["properties"]["details"]["additionalProperties"] is False -def test_json_object_no_schema_falls_back_to_tool_call(): - """response_format: {type: json_object} with no schema should use tool-call fallback, - even for models that support native structured outputs.""" +def test_json_object_no_schema_skips_tool_injection(): + """response_format: {type: json_object} with no schema should NOT inject + the synthetic json_tool_call tool. + + When no schema is given, _create_json_tool_call_for_response_format builds + a tool with an empty schema (properties: {}). The model follows the schema + and returns {} instead of the requested JSON. Skipping tool injection lets + the model respond naturally with the JSON the caller asked for.""" old_env = os.environ.get("LITELLM_LOCAL_MODEL_COST_MAP") old_cost = litellm.model_cost os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" @@ -3162,8 +3167,9 @@ def test_json_object_no_schema_falls_back_to_tool_call(): # Should NOT use native outputConfig (no schema provided) assert "outputConfig" not in result - # Should use tool-call fallback - assert "tools" in result + # Should NOT inject tools - empty schema causes model to return {} + assert "tools" not in result + assert "tool_choice" not in result assert result["json_mode"] is True finally: litellm.model_cost = old_cost @@ -3655,3 +3661,140 @@ def test_cache_control_injection_tool_config_not_added_without_injection_point() tools = result["toolConfig"]["tools"] # No cachePoint should be appended assert all("cachePoint" not in tool for tool in tools) + + +def test_translate_response_format_json_schema_still_injects_tool(): + """ + response_format with an explicit json_schema should still use the + synthetic tool call approach (for models that don't support native + structured outputs). + """ + config = AmazonConverseConfig() + + response_format = { + "type": "json_schema", + "json_schema": { + "name": "FactResult", + "schema": { + "type": "object", + "properties": { + "facts": { + "type": "array", + "items": {"type": "string"}, + }, + }, + "required": ["facts"], + }, + }, + } + + optional_params: dict = {} + result = config._translate_response_format_param( + value=response_format, + model="anthropic.claude-3-haiku-20240307-v1:0", + optional_params=optional_params, + non_default_params={"response_format": response_format}, + is_thinking_enabled=False, + ) + + assert result["json_mode"] is True + assert "tools" in result + assert "tool_choice" in result + + +def test_transform_response_finish_reason_stop_when_json_mode_filters_all_tools(): + """ + When json_mode is True and _filter_json_mode_tools strips all synthetic + tool calls, finish_reason should be "stop", not "tool_calls". + + Bedrock returns stopReason="tool_use" for json_tool_call responses. + After filtering, the response is plain JSON content and should not look + like a pending tool invocation to callers. + """ + from litellm.llms.bedrock.chat.converse_transformation import AmazonConverseConfig + from litellm.types.utils import ModelResponse + + response_json = { + "metrics": {"latencyMs": 100}, + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_001", + "name": "json_tool_call", + "input": { + "facts": ["Bob is a software engineer"], + }, + } + } + ], + } + }, + "stopReason": "tool_use", + "usage": { + "inputTokens": 50, + "outputTokens": 20, + "totalTokens": 70, + "cacheReadInputTokenCount": 0, + "cacheReadInputTokens": 0, + "cacheWriteInputTokenCount": 0, + "cacheWriteInputTokens": 0, + }, + } + + class MockResponse: + def json(self): + return response_json + + @property + def text(self): + return json.dumps(response_json) + + config = AmazonConverseConfig() + model_response = ModelResponse() + + # Simulate what happens when json_tool_call was injected for a + # json_schema request: optional_params has the synthetic tool + optional_params = { + "json_mode": True, + "tools": [ + { + "type": "function", + "function": { + "name": "json_tool_call", + "parameters": { + "type": "object", + "additionalProperties": True, + "properties": {}, + }, + }, + } + ], + } + + result = config._transform_response( + model="bedrock/us.anthropic.claude-haiku-4-5-20251001-v1:0", + response=MockResponse(), + model_response=model_response, + stream=False, + logging_obj=None, + optional_params=optional_params, + api_key=None, + data=None, + messages=[], + encoding=None, + ) + + # Content should have the JSON from the tool call arguments + content = result.choices[0].message.content + assert content is not None + parsed = json.loads(content) + assert parsed["facts"] == ["Bob is a software engineer"] + + # No tool_calls on the message + assert result.choices[0].message.tool_calls is None + + # finish_reason must be "stop", not "tool_calls" + assert result.choices[0].finish_reason == "stop" diff --git a/tests/test_litellm/llms/custom_httpx/test_http_handler.py b/tests/test_litellm/llms/custom_httpx/test_http_handler.py index 0a3f0fe5e6..5fad8a9908 100644 --- a/tests/test_litellm/llms/custom_httpx/test_http_handler.py +++ b/tests/test_litellm/llms/custom_httpx/test_http_handler.py @@ -15,7 +15,12 @@ sys.path.insert( ) # Adds the parent directory to the system path import litellm from litellm.llms.custom_httpx.aiohttp_transport import LiteLLMAiohttpTransport -from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, get_ssl_configuration +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + HTTPHandler, + _get_httpx_client, + get_ssl_configuration, +) @pytest.mark.asyncio @@ -658,3 +663,26 @@ async def test_httpx_handler_uses_env_user_agent(monkeypatch): assert req.headers.get("User-Agent") == "Claude Code" finally: await handler.close() + + +def test_get_httpx_client_applies_float_timeout_without_mocking_handler(): + """ + Exercise real _get_httpx_client + HTTPHandler: params={'timeout': x} must reach httpx.Client(timeout=...). + Uses an uncommon timeout value to avoid colliding with other cached clients in-process. + """ + timeout = 3847.291 + handler = _get_httpx_client(params={"timeout": timeout}) + try: + assert isinstance(handler, HTTPHandler) + assert handler.client.timeout == httpx.Timeout(timeout) + finally: + handler.close() + + +def test_get_httpx_client_applies_httpx_timeout_object_without_mocking_handler(): + t = httpx.Timeout(40.0, connect=5.0) + handler = _get_httpx_client(params={"timeout": t}) + try: + assert handler.client.timeout == t + finally: + handler.close() diff --git a/tests/test_litellm/llms/ollama/test_ollama_chat_transformation.py b/tests/test_litellm/llms/ollama/test_ollama_chat_transformation.py index 02495106a8..a1afdd3a36 100644 --- a/tests/test_litellm/llms/ollama/test_ollama_chat_transformation.py +++ b/tests/test_litellm/llms/ollama/test_ollama_chat_transformation.py @@ -15,6 +15,12 @@ from litellm.llms.ollama.chat.transformation import OllamaChatConfig, OllamaChat from litellm.types.llms.openai import AllMessageValues from litellm.utils import get_optional_params +import json +from unittest.mock import MagicMock + +import litellm +from litellm.types.utils import Choices, Message, ModelResponse + class TestEvent(BaseModel): name: str @@ -427,12 +433,6 @@ class TestOllamaToolCalling: def test_finish_reason_stop_when_no_tool_calls(self): """Test that finish_reason remains 'stop' when no tool_calls present.""" - import json - from unittest.mock import MagicMock - - import litellm - from litellm.types.utils import Choices, Message, ModelResponse - config = OllamaChatConfig() # Simulated Ollama response without tool_calls @@ -476,6 +476,140 @@ class TestOllamaToolCalling: assert result.choices[0].message.tool_calls is None +class TestOllamaFinishReasonLength: + """Tests for done_reason 'length' → finish_reason 'length' mapping. + + Ollama returns done_reason='length' when a response is truncated by num_predict + (max_tokens). Previously finish_reason was hardcoded to 'stop', hiding truncation. + The Anthropic pass-through adapter then maps OpenAI 'length' → 'max_tokens'. + """ + + def test_finish_reason_length_non_streaming(self): + """Non-streaming: done_reason='length' must propagate as finish_reason='length'.""" + config = OllamaChatConfig() + + ollama_response = { + "model": "qwen3:2b", + "created_at": "2025-01-11T00:00:00.000000Z", + "message": { + "role": "assistant", + "content": "A neural network learns through", + }, + "done": True, + "done_reason": "length", + "prompt_eval_count": 20, + "eval_count": 20, + } + + mock_response = MagicMock() + mock_response.json.return_value = ollama_response + mock_response.text = json.dumps(ollama_response) + + mock_logging = MagicMock() + + model_response = ModelResponse() + model_response.choices = [Choices(message=Message(content=""), index=0)] + + result = config.transform_response( + model="qwen3:2b", + raw_response=mock_response, + model_response=model_response, + logging_obj=mock_logging, + request_data={}, + messages=[{"role": "user", "content": "Explain neural networks."}], + optional_params={}, + litellm_params={}, + encoding=None, + api_key=None, + json_mode=False, + ) + + assert result.choices[0].finish_reason == "length", ( + f"Expected 'length' when done_reason='length', got '{result.choices[0].finish_reason}'" + ) + + def test_finish_reason_stop_non_streaming(self): + """Non-streaming: done_reason='stop' (natural finish) must stay 'stop'.""" + config = OllamaChatConfig() + + ollama_response = { + "model": "qwen3:2b", + "created_at": "2025-01-11T00:00:00.000000Z", + "message": {"role": "assistant", "content": "2 + 2 = 4."}, + "done": True, + "done_reason": "stop", + "prompt_eval_count": 10, + "eval_count": 8, + } + + mock_response = MagicMock() + mock_response.json.return_value = ollama_response + mock_response.text = json.dumps(ollama_response) + + mock_logging = MagicMock() + + model_response = ModelResponse() + model_response.choices = [Choices(message=Message(content=""), index=0)] + + result = config.transform_response( + model="qwen3:2b", + raw_response=mock_response, + model_response=model_response, + logging_obj=mock_logging, + request_data={}, + messages=[{"role": "user", "content": "What is 2+2?"}], + optional_params={}, + litellm_params={}, + encoding=None, + api_key=None, + json_mode=False, + ) + + assert result.choices[0].finish_reason == "stop", ( + f"Expected 'stop' for natural finish, got '{result.choices[0].finish_reason}'" + ) + + def test_finish_reason_length_streaming(self): + """Streaming: done_reason='length' in final chunk must produce finish_reason='length'.""" + iterator = OllamaChatCompletionResponseIterator( + streaming_response=iter([]), + sync_stream=True, + ) + + done_chunk = { + "model": "qwen3:2b", + "message": {"role": "assistant", "content": "A neural network learns through"}, + "done": True, + "done_reason": "length", + } + + result = iterator.chunk_parser(done_chunk) + + assert result.choices[0].finish_reason == "length", ( + f"Expected 'length' when done_reason='length', got '{result.choices[0].finish_reason}'" + ) + + def test_finish_reason_stop_streaming(self): + """Streaming: done_reason='stop' in final chunk must produce finish_reason='stop'.""" + iterator = OllamaChatCompletionResponseIterator( + streaming_response=iter([]), + sync_stream=True, + ) + + done_chunk = { + "model": "qwen3:2b", + "message": {"role": "assistant", "content": "2 + 2 = 4."}, + "done": True, + "done_reason": "stop", + } + + result = iterator.chunk_parser(done_chunk) + + assert result.choices[0].finish_reason == "stop", ( + f"Expected 'stop' for natural finish, got '{result.choices[0].finish_reason}'" + ) + + class TestOllamaReasoningContentStreaming: """Test that reasoning_content is properly extracted from all thinking chunks.""" diff --git a/tests/test_litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/test_count_tokens_location.py b/tests/test_litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/test_count_tokens_location.py index 6487ea25f2..53aa07d8c5 100644 --- a/tests/test_litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/test_count_tokens_location.py +++ b/tests/test_litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/test_count_tokens_location.py @@ -162,3 +162,69 @@ class TestCountTokensLocationResolution: ) assert captured["vertex_location"] == "asia-southeast1" + + +class TestCountTokensVersionSuffixStripping: + """Verify that version suffixes (@default, @20251001, etc.) are stripped + from model names before sending to the Vertex AI count-tokens endpoint. + + The Vertex AI count-tokens API rejects versioned model names with: + "claude-sonnet-4-6@default is not supported for token counting" + while "claude-sonnet-4-6" (without suffix) works correctly. + """ + + def test_strip_version_suffix_at_default(self): + counter = VertexAIPartnerModelsTokenCounter() + assert counter._strip_version_suffix("claude-sonnet-4-6@default") == "claude-sonnet-4-6" + + def test_strip_version_suffix_at_date(self): + counter = VertexAIPartnerModelsTokenCounter() + assert counter._strip_version_suffix("claude-haiku-4-5@20251001") == "claude-haiku-4-5" + + def test_strip_version_suffix_no_suffix(self): + counter = VertexAIPartnerModelsTokenCounter() + assert counter._strip_version_suffix("claude-sonnet-4-6") == "claude-sonnet-4-6" + + @pytest.mark.asyncio + async def test_handle_count_tokens_strips_version_from_request_data(self, monkeypatch): + """The model name in request_data sent to the API must have @suffix stripped.""" + counter = VertexAIPartnerModelsTokenCounter() + captured_json = {} + + async def fake_ensure_access_token(self, credentials, project_id, custom_llm_provider): + return "fake-token", "fake-project" + + def fake_build_endpoint(self, model, project_id, vertex_location, api_base=None): + return "https://fake-endpoint" + + monkeypatch.setattr( + VertexAIPartnerModelsTokenCounter, "_ensure_access_token_async", fake_ensure_access_token + ) + monkeypatch.setattr( + VertexAIPartnerModelsTokenCounter, "_build_count_tokens_endpoint", fake_build_endpoint + ) + + class FakeResponse: + status_code = 200 + def json(self): + return {"input_tokens": 10} + + class FakeClient: + async def post(self, url, headers=None, json=None, **kwargs): + captured_json.update(json or {}) + return FakeResponse() + + import litellm.llms.vertex_ai.vertex_ai_partner_models.count_tokens.handler as handler_mod + monkeypatch.setattr(handler_mod, "get_async_httpx_client", lambda **kwargs: FakeClient()) + + await counter.handle_count_tokens_request( + model="claude-sonnet-4-6@default", + request_data={ + "model": "claude-sonnet-4-6@default", + "messages": [{"role": "user", "content": "hi"}], + }, + litellm_params={"vertex_location": "us-east5"}, + ) + + # The model name sent to the API must NOT have the @default suffix + assert captured_json["model"] == "claude-sonnet-4-6" diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py index 384d428888..9df6408b0d 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py @@ -5,7 +5,12 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi import HTTPException from mcp import ReadResourceResult, Resource -from mcp.types import BlobResourceContents, Prompt, ResourceTemplate, TextResourceContents +from mcp.types import ( + BlobResourceContents, + Prompt, + ResourceTemplate, + TextResourceContents, +) from litellm.proxy._types import ( LiteLLM_MCPServerTable, @@ -157,15 +162,19 @@ async def test_get_prompts_from_mcp_servers_success(): server_b.auth_type = None server_b.extra_headers = None - with patch( - "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers", - AsyncMock(return_value=[server_a, server_b]), - ) as mock_allowed, patch( - "litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers", - return_value=(None, None), - ) as mock_headers, patch( - "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", - ) as mock_manager: + with ( + patch( + "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers", + AsyncMock(return_value=[server_a, server_b]), + ) as mock_allowed, + patch( + "litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers", + return_value=(None, None), + ) as mock_headers, + patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", + ) as mock_manager, + ): mock_manager.get_prompts_from_server = AsyncMock( side_effect=[ [Prompt(name="hello", description="hi")], @@ -213,15 +222,19 @@ async def test_get_resources_from_mcp_servers_success(): server_b.auth_type = None server_b.extra_headers = None - with patch( - "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers", - AsyncMock(return_value=[server_a, server_b]), - ) as mock_allowed, patch( - "litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers", - return_value=(None, None), - ) as mock_headers, patch( - "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", - ) as mock_manager: + with ( + patch( + "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers", + AsyncMock(return_value=[server_a, server_b]), + ) as mock_allowed, + patch( + "litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers", + return_value=(None, None), + ) as mock_headers, + patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", + ) as mock_manager, + ): mock_manager.get_resources_from_server = AsyncMock( side_effect=[ [ @@ -274,15 +287,19 @@ async def test_get_resource_templates_from_mcp_servers_success(): server.auth_type = None server.extra_headers = None - with patch( - "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers", - AsyncMock(return_value=[server]), - ) as mock_allowed, patch( - "litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers", - return_value=(None, None), - ) as mock_headers, patch( - "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", - ) as mock_manager: + with ( + patch( + "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers", + AsyncMock(return_value=[server]), + ) as mock_allowed, + patch( + "litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers", + return_value=(None, None), + ) as mock_headers, + patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", + ) as mock_manager, + ): mock_manager.get_resource_templates_from_server = AsyncMock( return_value=[ ResourceTemplate( @@ -320,15 +337,19 @@ async def test_mcp_get_prompt_success(): prompt_result = MagicMock(name="prompt_result") - with patch( - "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers", - AsyncMock(return_value=[server]), - ) as mock_allowed, patch( - "litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers", - return_value=({"Authorization": "token"}, {"X-Test": "1"}), - ) as mock_headers, patch( - "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", - ) as mock_manager: + with ( + patch( + "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers", + AsyncMock(return_value=[server]), + ) as mock_allowed, + patch( + "litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers", + return_value=({"Authorization": "token"}, {"X-Test": "1"}), + ) as mock_headers, + patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", + ) as mock_manager, + ): mock_manager.get_prompt_from_server = AsyncMock(return_value=prompt_result) result = await mcp_get_prompt( @@ -378,15 +399,19 @@ async def test_mcp_read_resource_success(): ] ) - with patch( - "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers", - AsyncMock(return_value=[server]), - ) as mock_allowed, patch( - "litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers", - return_value=({"Authorization": "token"}, {"X-Test": "1"}), - ) as mock_headers, patch( - "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", - ) as mock_manager: + with ( + patch( + "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers", + AsyncMock(return_value=[server]), + ) as mock_allowed, + patch( + "litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers", + return_value=({"Authorization": "token"}, {"X-Test": "1"}), + ) as mock_headers, + patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", + ) as mock_manager, + ): mock_manager.read_resource_from_server = AsyncMock(return_value=read_result) result = await mcp_read_resource( @@ -591,7 +616,10 @@ async def test_get_tools_from_mcp_servers_continues_when_one_server_fails(): working_server if server_id == "working_server" else failing_server ) # Mock filter_server_ids_by_ip to return server_ids unchanged (no IP filtering) - mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (server_ids, 0) + mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: ( + server_ids, + 0, + ) async def mock_get_tools_from_server( server, @@ -693,7 +721,10 @@ async def test_get_tools_from_mcp_servers_handles_all_servers_failing(): failing_server1 if server_id == "failing_server1" else failing_server2 ) # Mock filter_server_ids_by_ip to return server_ids unchanged (no IP filtering) - mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (server_ids, 0) + mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: ( + server_ids, + 0, + ) async def mock_get_tools_from_server( server, @@ -830,12 +861,14 @@ async def test_concurrent_initialize_session_managers(): mcp_server._sse_session_manager_cm = None # Mock the session managers to avoid actual MCP initialization - with patch( - "litellm.proxy._experimental.mcp_server.server.session_manager" - ) as mock_session_manager, patch( - "litellm.proxy._experimental.mcp_server.server.sse_session_manager" - ) as mock_sse_session_manager, patch( - "litellm.proxy._experimental.mcp_server.server.verbose_logger" + with ( + patch( + "litellm.proxy._experimental.mcp_server.server.session_manager" + ) as mock_session_manager, + patch( + "litellm.proxy._experimental.mcp_server.server.sse_session_manager" + ) as mock_sse_session_manager, + patch("litellm.proxy._experimental.mcp_server.server.verbose_logger"), ): # Mock the run() method to return a mock context manager mock_cm = AsyncMock() @@ -961,15 +994,19 @@ async def test_mcp_routing_with_conflicting_alias_and_group_name(): return_value=[specific_server.server_id, other_server.server_id] ) - with patch( - "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager.get_allowed_mcp_servers", - mock_get_allowed, - ), patch( - "litellm.proxy._experimental.mcp_server.server.MCPRequestHandler._get_mcp_servers_from_access_groups", - mock_db_lookup, - ), patch( - "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager._get_tools_from_server", - mock_get_tools_spy, + with ( + patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager.get_allowed_mcp_servers", + mock_get_allowed, + ), + patch( + "litellm.proxy._experimental.mcp_server.server.MCPRequestHandler._get_mcp_servers_from_access_groups", + mock_db_lookup, + ), + patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager._get_tools_from_server", + mock_get_tools_spy, + ), ): mcp_servers_from_path = _get_mcp_servers_in_path(test_path) @@ -1062,17 +1099,21 @@ async def test_oauth2_headers_passed_to_mcp_client(): async def mock_fetch_tools_with_timeout(client, server_name): return [] # Return empty list of tools - with patch.object( - global_mcp_server_manager, - "_create_mcp_client", - side_effect=mock_create_mcp_client, - ) as mock_create_client, patch.object( - global_mcp_server_manager, - "_fetch_tools_with_timeout", - side_effect=mock_fetch_tools_with_timeout, - ), patch( - "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers", - AsyncMock(return_value=[oauth2_server]), + with ( + patch.object( + global_mcp_server_manager, + "_create_mcp_client", + side_effect=mock_create_mcp_client, + ) as mock_create_client, + patch.object( + global_mcp_server_manager, + "_fetch_tools_with_timeout", + side_effect=mock_fetch_tools_with_timeout, + ), + patch( + "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers", + AsyncMock(return_value=[oauth2_server]), + ), ): # Call _get_tools_from_mcp_servers which should eventually call _create_mcp_client await _get_tools_from_mcp_servers( @@ -1138,7 +1179,10 @@ async def test_list_tools_single_server_unprefixed_names(): mock_manager.get_allowed_mcp_servers = AsyncMock(return_value=["server1"]) mock_manager.get_mcp_server_by_id = MagicMock(return_value=server) # Mock filter_server_ids_by_ip to return server_ids unchanged (no IP filtering) - mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (server_ids, 0) + mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: ( + server_ids, + 0, + ) async def mock_get_tools_from_server( server, @@ -1216,7 +1260,10 @@ async def test_list_tools_multiple_servers_prefixed_names(): server1 if server_id == "server1" else server2 ) # Mock filter_server_ids_by_ip to return server_ids unchanged (no IP filtering) - mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (server_ids, 0) + mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: ( + server_ids, + 0, + ) async def mock_get_tools_from_server( server, @@ -1270,12 +1317,15 @@ async def test_mcp_manager_allows_public_servers_without_permissions(): ) manager.registry = {public_server.server_id: public_server} - with patch( - "litellm.proxy.management_endpoints.common_utils._user_has_admin_view", - return_value=False, - ), patch( - "litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPRequestHandler.get_allowed_mcp_servers", - AsyncMock(return_value=[]), + with ( + patch( + "litellm.proxy.management_endpoints.common_utils._user_has_admin_view", + return_value=False, + ), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPRequestHandler.get_allowed_mcp_servers", + AsyncMock(return_value=[]), + ), ): allowed = await manager.get_allowed_mcp_servers(UserAPIKeyAuth()) @@ -1302,12 +1352,15 @@ async def test_mcp_manager_returns_public_when_permission_lookup_fails(): ) manager.registry = {public_server.server_id: public_server} - with patch( - "litellm.proxy.management_endpoints.common_utils._user_has_admin_view", - return_value=False, - ), patch( - "litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPRequestHandler.get_allowed_mcp_servers", - AsyncMock(side_effect=Exception("boom")), + with ( + patch( + "litellm.proxy.management_endpoints.common_utils._user_has_admin_view", + return_value=False, + ), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPRequestHandler.get_allowed_mcp_servers", + AsyncMock(side_effect=Exception("boom")), + ), ): allowed = await manager.get_allowed_mcp_servers(UserAPIKeyAuth()) @@ -1342,12 +1395,15 @@ async def test_mcp_manager_merges_public_and_restricted_servers(): scoped_server.server_id: scoped_server, } - with patch( - "litellm.proxy.management_endpoints.common_utils._user_has_admin_view", - return_value=False, - ), patch( - "litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPRequestHandler.get_allowed_mcp_servers", - AsyncMock(return_value=["restricted"]), + with ( + patch( + "litellm.proxy.management_endpoints.common_utils._user_has_admin_view", + return_value=False, + ), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPRequestHandler.get_allowed_mcp_servers", + AsyncMock(return_value=["restricted"]), + ), ): allowed = await manager.get_allowed_mcp_servers(UserAPIKeyAuth()) @@ -1399,12 +1455,15 @@ async def test_call_mcp_tool_user_unauthorized_access(): return another_server_obj return None - with patch( - "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.MCPRequestHandler.get_allowed_mcp_servers", - AsyncMock(return_value=["allowed_server", "another_server"]), - ), patch( - "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager.get_mcp_server_by_id", - side_effect=mock_get_server_by_id, + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.MCPRequestHandler.get_allowed_mcp_servers", + AsyncMock(return_value=["allowed_server", "another_server"]), + ), + patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager.get_mcp_server_by_id", + side_effect=mock_get_server_by_id, + ), ): # Try to call a tool from "restricted_server" - should raise HTTPException with 403 status with pytest.raises(HTTPException) as exc_info: @@ -1467,7 +1526,10 @@ async def test_list_tools_filters_by_key_team_permissions(): mock_manager.get_allowed_mcp_servers = AsyncMock(return_value=["server1"]) mock_manager.get_mcp_server_by_id = lambda server_id: server # Mock filter_server_ids_by_ip to return server_ids unchanged (no IP filtering) - mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (server_ids, 0) + mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: ( + server_ids, + 0, + ) async def mock_get_tools_from_server( server, @@ -1573,7 +1635,10 @@ async def test_list_tools_with_team_tool_permissions_inheritance(): mock_manager.get_allowed_mcp_servers = AsyncMock(return_value=["server1"]) mock_manager.get_mcp_server_by_id = lambda server_id: server # Mock filter_server_ids_by_ip to return server_ids unchanged (no IP filtering) - mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (server_ids, 0) + mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: ( + server_ids, + 0, + ) async def mock_get_tools_from_server( server, @@ -1665,7 +1730,10 @@ async def test_list_tools_with_no_tool_permissions_shows_all(): mock_manager.get_allowed_mcp_servers = AsyncMock(return_value=["server1"]) mock_manager.get_mcp_server_by_id = lambda server_id: server # Mock filter_server_ids_by_ip to return server_ids unchanged (no IP filtering) - mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (server_ids, 0) + mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: ( + server_ids, + 0, + ) async def mock_get_tools_from_server( server, @@ -1760,7 +1828,10 @@ async def test_list_tools_strips_prefix_when_matching_permissions(): mock_manager.get_allowed_mcp_servers = AsyncMock(return_value=["gitmcp_server"]) mock_manager.get_mcp_server_by_id = MagicMock(return_value=server) # Mock filter_server_ids_by_ip to return server_ids unchanged (no IP filtering) - mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: (server_ids, 0) + mock_manager.filter_server_ids_by_ip_with_info = lambda server_ids, client_ip: ( + server_ids, + 0, + ) async def mock_get_tools_from_server( server, @@ -2002,12 +2073,15 @@ class TestMCPServerManagerReload: mock_prisma.db.litellm_mcpservertable.find_many = AsyncMock( return_value=[db_row] ) - with patch( - "litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw", - return_value=mock_prisma, - ), patch.object( - manager, "build_mcp_server_from_table", AsyncMock() - ) as mock_build: + with ( + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw", + return_value=mock_prisma, + ), + patch.object( + manager, "build_mcp_server_from_table", AsyncMock() + ) as mock_build, + ): await manager.reload_servers_from_database() mock_build.assert_not_awaited() @@ -2045,14 +2119,17 @@ class TestMCPServerManagerReload: mock_prisma.db.litellm_mcpservertable.find_many = AsyncMock( return_value=[db_row] ) - with patch( - "litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw", - return_value=mock_prisma, - ), patch.object( - manager, - "build_mcp_server_from_table", - AsyncMock(return_value=rebuilt_server), - ) as mock_build: + with ( + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw", + return_value=mock_prisma, + ), + patch.object( + manager, + "build_mcp_server_from_table", + AsyncMock(return_value=rebuilt_server), + ) as mock_build, + ): await manager.reload_servers_from_database() mock_build.assert_awaited_once_with(db_row) @@ -2090,26 +2167,32 @@ async def test_call_mcp_tool_logs_failure_via_post_call_failure_hook(): user_auth = UserAPIKeyAuth(api_key="test-key", user_id="test-user") - with patch.object( - global_mcp_server_manager, - "get_allowed_mcp_servers", - new_callable=AsyncMock, - return_value=[mock_server.server_id], - ), patch.object( - global_mcp_server_manager, - "get_mcp_server_by_id", - return_value=mock_server, - ), patch( - "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers_from_mcp_server_names", - new_callable=AsyncMock, - return_value=[mock_server], - ), patch( - "litellm.proxy._experimental.mcp_server.server.execute_mcp_tool", - new_callable=AsyncMock, - side_effect=Exception("boom"), - ), patch( - "litellm.proxy.proxy_server.proxy_logging_obj", - proxy_logging_mock, + with ( + patch.object( + global_mcp_server_manager, + "get_allowed_mcp_servers", + new_callable=AsyncMock, + return_value=[mock_server.server_id], + ), + patch.object( + global_mcp_server_manager, + "get_mcp_server_by_id", + return_value=mock_server, + ), + patch( + "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers_from_mcp_server_names", + new_callable=AsyncMock, + return_value=[mock_server], + ), + patch( + "litellm.proxy._experimental.mcp_server.server.execute_mcp_tool", + new_callable=AsyncMock, + side_effect=Exception("boom"), + ), + patch( + "litellm.proxy.proxy_server.proxy_logging_obj", + proxy_logging_mock, + ), ): with pytest.raises(Exception): await call_mcp_tool( @@ -2157,23 +2240,30 @@ async def test_get_tools_from_mcp_servers_logs_list_tools_to_spendlogs_when_enab dummy_logging_obj.model_call_details = {"metadata": {"spend_logs_metadata": {}}} dummy_logging_obj.async_success_handler = AsyncMock() - with patch( - "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers", - new=AsyncMock(return_value=[server_a]), - ), patch( - "litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers", - return_value=(None, None), - ), patch( - "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", - ) as mock_manager, patch( - "litellm.proxy._experimental.mcp_server.server.filter_tools_by_allowed_tools", - side_effect=lambda tools, _server: tools, - ), patch( - "litellm.proxy._experimental.mcp_server.server.filter_tools_by_key_team_permissions", - new=AsyncMock(side_effect=lambda tools, **_: tools), - ), patch( - "litellm.proxy._experimental.mcp_server.server.function_setup", - return_value=(dummy_logging_obj, None), + with ( + patch( + "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers", + new=AsyncMock(return_value=[server_a]), + ), + patch( + "litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers", + return_value=(None, None), + ), + patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", + ) as mock_manager, + patch( + "litellm.proxy._experimental.mcp_server.server.filter_tools_by_allowed_tools", + side_effect=lambda tools, _server: tools, + ), + patch( + "litellm.proxy._experimental.mcp_server.server.filter_tools_by_key_team_permissions", + new=AsyncMock(side_effect=lambda tools, **_: tools), + ), + patch( + "litellm.proxy._experimental.mcp_server.server.function_setup", + return_value=(dummy_logging_obj, None), + ), ): mock_manager._get_tools_from_server = AsyncMock(return_value=[tool_1]) @@ -2188,7 +2278,9 @@ async def test_get_tools_from_mcp_servers_logs_list_tools_to_spendlogs_when_enab assert tools == [tool_1] dummy_logging_obj.async_success_handler.assert_awaited_once() - assert dummy_logging_obj.async_success_handler.await_args.kwargs["result"] == [tool_1] + assert dummy_logging_obj.async_success_handler.await_args.kwargs["result"] == [ + tool_1 + ] spend_meta = dummy_logging_obj.model_call_details["metadata"]["spend_logs_metadata"] assert spend_meta["tool_count_total"] == 1 @@ -2381,26 +2473,34 @@ async def test_get_tools_from_mcp_servers_injects_stored_oauth2_token(): oauth2_server.extra_headers = None # Simulate the DB returning a valid credential for this user+server - prefetched_creds = {SERVER_ID: {"access_token": STORED_TOKEN, "server_id": SERVER_ID}} + prefetched_creds = { + SERVER_ID: {"access_token": STORED_TOKEN, "server_id": SERVER_ID} + } tool_1 = MagicMock() tool_1.name = "atlassian_test-search" - with patch( - "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers", - new=AsyncMock(return_value=[oauth2_server]), - ), patch( - # Patch the bulk prefetch so no real DB connection is needed - "litellm.proxy._experimental.mcp_server.server._prefetch_oauth_creds_for_user", - new=AsyncMock(return_value=prefetched_creds), - ) as mock_prefetch, patch( - "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", - ) as mock_manager, patch( - "litellm.proxy._experimental.mcp_server.server.filter_tools_by_allowed_tools", - side_effect=lambda tools, _server: tools, - ), patch( - "litellm.proxy._experimental.mcp_server.server.filter_tools_by_key_team_permissions", - new=AsyncMock(side_effect=lambda tools, **_: tools), + with ( + patch( + "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers", + new=AsyncMock(return_value=[oauth2_server]), + ), + patch( + # Patch the bulk prefetch so no real DB connection is needed + "litellm.proxy._experimental.mcp_server.server._prefetch_oauth_creds_for_user", + new=AsyncMock(return_value=prefetched_creds), + ) as mock_prefetch, + patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", + ) as mock_manager, + patch( + "litellm.proxy._experimental.mcp_server.server.filter_tools_by_allowed_tools", + side_effect=lambda tools, _server: tools, + ), + patch( + "litellm.proxy._experimental.mcp_server.server.filter_tools_by_key_team_permissions", + new=AsyncMock(side_effect=lambda tools, **_: tools), + ), ): mock_manager._get_tools_from_server = AsyncMock(return_value=[tool_1]) @@ -2421,3 +2521,201 @@ async def test_get_tools_from_mcp_servers_injects_stored_oauth2_token(): assert call_kwargs["extra_headers"] == {"Authorization": f"Bearer {STORED_TOKEN}"} assert tools == [tool_1] + + +# --------------------------------------------------------------------------- +# _merge_gateway_initialize_instructions + ContextVar / InitializationOptions +# --------------------------------------------------------------------------- + + +def _make_instruction_server( + server_id="s1", + name="s1", + *, + alias=None, + server_name=None, + instructions=None, + spec_path=None, + url="https://example.com", +): + return MCPServer( + server_id=server_id, + name=name, + alias=alias, + server_name=server_name, + url=url, + transport=MCPTransport.http, + instructions=instructions, + spec_path=spec_path, + ) + + +class TestMergeGatewayInitializeInstructions: + """Tests for _merge_gateway_initialize_instructions.""" + + def _merge(self, servers): + try: + from litellm.proxy._experimental.mcp_server.server import ( + _merge_gateway_initialize_instructions, + ) + except ImportError: + pytest.skip("MCP server not available") + return _merge_gateway_initialize_instructions(servers) + + def test_empty_server_list_returns_none(self): + """No servers yields no instructions.""" + assert self._merge([]) is None + + def test_single_server_yaml_instructions(self): + """A single server with YAML instructions returns them verbatim.""" + s = _make_instruction_server(instructions="Use add() for sums.") + assert self._merge([s]) == "Use add() for sums." + + def test_yaml_instructions_strips_whitespace(self): + """Leading/trailing whitespace is stripped.""" + s = _make_instruction_server(instructions=" padded \n") + assert self._merge([s]) == "padded" + + def test_yaml_override_beats_upstream_cache(self): + """YAML/DB instructions take precedence over upstream cache.""" + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + + global_mcp_server_manager._upstream_initialize_instructions_by_server_id[ + "s1" + ] = "upstream" + try: + s = _make_instruction_server(instructions="yaml wins") + assert self._merge([s]) == "yaml wins" + finally: + global_mcp_server_manager._upstream_initialize_instructions_by_server_id.pop( + "s1", None + ) + + def test_upstream_cache_used_when_no_yaml(self): + """Upstream cached instructions are used when no YAML override is set.""" + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + + global_mcp_server_manager._upstream_initialize_instructions_by_server_id[ + "s1" + ] = "from upstream" + try: + s = _make_instruction_server(instructions=None) + assert self._merge([s]) == "from upstream" + finally: + global_mcp_server_manager._upstream_initialize_instructions_by_server_id.pop( + "s1", None + ) + + def test_spec_path_servers_skipped(self): + """OpenAPI (spec_path) servers do not contribute instructions.""" + s = _make_instruction_server(spec_path="/openapi.json", url=None) + assert self._merge([s]) is None + + def test_no_instructions_no_cache_returns_none(self): + """Server with no instructions and no cache yields None.""" + s = _make_instruction_server() + assert self._merge([s]) is None + + def test_multiple_servers_merged_with_labels(self): + """Multiple servers get label-prefixed and separator-joined.""" + s1 = _make_instruction_server( + server_id="a", name="a", alias="Alpha", instructions="instr A" + ) + s2 = _make_instruction_server( + server_id="b", name="b", alias="Beta", instructions="instr B" + ) + result = self._merge([s1, s2]) + assert result is not None + assert "[Alpha]" in result and "[Beta]" in result + assert "instr A" in result and "instr B" in result + assert "---" in result + + def test_single_server_no_label_wrapping(self): + """A single server's instructions are not wrapped with a label.""" + s = _make_instruction_server(alias="MyServer", instructions="single") + result = self._merge([s]) + assert result == "single" + assert "[MyServer]" not in result + + def test_mixed_yaml_cache_specpath(self): + """YAML, upstream-cache, and spec_path servers are handled correctly together.""" + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + + global_mcp_server_manager._upstream_initialize_instructions_by_server_id[ + "c" + ] = "cached C" + try: + s_yaml = _make_instruction_server( + server_id="a", name="a", alias="A", instructions="yaml A" + ) + s_spec = _make_instruction_server( + server_id="b", name="b", alias="B", spec_path="/spec.json", url=None + ) + s_cached = _make_instruction_server(server_id="c", name="c", alias="C") + result = self._merge([s_yaml, s_spec, s_cached]) + assert "yaml A" in result + assert "cached C" in result + assert "[B]" not in result + finally: + global_mcp_server_manager._upstream_initialize_instructions_by_server_id.pop( + "c", None + ) + + +class TestGatewayCreateInitializationOptions: + """Tests for the patched server.create_initialization_options via ContextVar.""" + + def test_no_contextvar_returns_default_options(self): + """When ContextVar is None, instructions are absent.""" + try: + from litellm.proxy._experimental.mcp_server.mcp_context import ( + _mcp_gateway_initialize_instructions, + ) + from litellm.proxy._experimental.mcp_server.server import server + except ImportError: + pytest.skip("MCP server not available") + + tok = _mcp_gateway_initialize_instructions.set(None) + try: + opts = server.create_initialization_options() + assert getattr(opts, "instructions", None) is None + finally: + _mcp_gateway_initialize_instructions.reset(tok) + + def test_contextvar_set_injects_instructions(self): + """When ContextVar has a value, it appears in InitializationOptions.""" + try: + from litellm.proxy._experimental.mcp_server.mcp_context import ( + _mcp_gateway_initialize_instructions, + ) + from litellm.proxy._experimental.mcp_server.server import server + except ImportError: + pytest.skip("MCP server not available") + + tok = _mcp_gateway_initialize_instructions.set("hello from merge") + try: + opts = server.create_initialization_options() + assert opts.instructions == "hello from merge" + finally: + _mcp_gateway_initialize_instructions.reset(tok) + + def test_contextvar_reset_removes_instructions(self): + """After resetting the ContextVar, instructions disappear.""" + try: + from litellm.proxy._experimental.mcp_server.mcp_context import ( + _mcp_gateway_initialize_instructions, + ) + from litellm.proxy._experimental.mcp_server.server import server + except ImportError: + pytest.skip("MCP server not available") + + tok = _mcp_gateway_initialize_instructions.set("temporary") + _mcp_gateway_initialize_instructions.reset(tok) + opts = server.create_initialization_options() + assert getattr(opts, "instructions", None) is None diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py index 656a9c616e..ac5349e710 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py @@ -43,10 +43,10 @@ def _reload_mcp_manager_module(): # After reload, server.py still holds a stale reference to the old # global_mcp_server_manager. Update it so tests that exercise server.py # functions (e.g. _get_tools_from_mcp_servers) use the fresh instance. - server_module = sys.modules.get( - "litellm.proxy._experimental.mcp_server.server" - ) - if server_module is not None and hasattr(server_module, "global_mcp_server_manager"): + server_module = sys.modules.get("litellm.proxy._experimental.mcp_server.server") + if server_module is not None and hasattr( + server_module, "global_mcp_server_manager" + ): server_module.global_mcp_server_manager = reloaded.global_mcp_server_manager return reloaded @@ -223,9 +223,7 @@ class TestMCPServerManager: with caplog.at_level(logging.WARNING, logger="LiteLLM"): await manager.load_servers_from_config(config) - assert any( - "invalid alias 'bad/name'" in message for message in caplog.messages - ) + assert any("invalid alias 'bad/name'" in message for message in caplog.messages) @pytest.mark.asyncio async def test_load_servers_from_config_accepts_valid_alias(self, caplog): @@ -492,7 +490,12 @@ class TestMCPServerManager: mock_client = AsyncMock() mock_client.list_prompts = AsyncMock(return_value=[mock_prompt]) - with patch.object(manager, "_create_mcp_client", new_callable=AsyncMock, return_value=mock_client): + with patch.object( + manager, + "_create_mcp_client", + new_callable=AsyncMock, + return_value=mock_client, + ): prompts = await manager.get_prompts_from_server(server, add_prefix=True) mock_client.list_prompts.assert_awaited_once() @@ -520,7 +523,12 @@ class TestMCPServerManager: mock_client = AsyncMock() mock_client.get_prompt = AsyncMock(return_value=mock_result) - with patch.object(manager, "_create_mcp_client", new_callable=AsyncMock, return_value=mock_client): + with patch.object( + manager, + "_create_mcp_client", + new_callable=AsyncMock, + return_value=mock_client, + ): result = await manager.get_prompt_from_server( server=server, prompt_name="hello", @@ -551,13 +559,23 @@ class TestMCPServerManager: mock_client = AsyncMock() mock_resources = [Resource(name="file", uri="https://example.com/file")] mock_client.list_resources = AsyncMock(return_value=mock_resources) - prefixed_resources = [Resource(name="alias-server-file", uri="https://example.com/file")] + prefixed_resources = [ + Resource(name="alias-server-file", uri="https://example.com/file") + ] - with patch.object(manager, "_create_mcp_client", new_callable=AsyncMock, return_value=mock_client) as mock_create_client, patch.object( - manager, - "_create_prefixed_resources", - return_value=prefixed_resources, - ) as mock_prefix: + with ( + patch.object( + manager, + "_create_mcp_client", + new_callable=AsyncMock, + return_value=mock_client, + ) as mock_create_client, + patch.object( + manager, + "_create_prefixed_resources", + return_value=prefixed_resources, + ) as mock_prefix, + ): result = await manager.get_resources_from_server( server=server, mcp_auth_header="auth", @@ -602,11 +620,19 @@ class TestMCPServerManager: ) ] - with patch.object(manager, "_create_mcp_client", new_callable=AsyncMock, return_value=mock_client) as mock_create_client, patch.object( - manager, - "_create_prefixed_resource_templates", - return_value=prefixed_templates, - ) as mock_prefix: + with ( + patch.object( + manager, + "_create_mcp_client", + new_callable=AsyncMock, + return_value=mock_client, + ) as mock_create_client, + patch.object( + manager, + "_create_prefixed_resource_templates", + return_value=prefixed_templates, + ) as mock_prefix, + ): result = await manager.get_resource_templates_from_server( server=server, mcp_auth_header="auth", @@ -650,7 +676,12 @@ class TestMCPServerManager: ) mock_client.read_resource = AsyncMock(return_value=read_result) - with patch.object(manager, "_create_mcp_client", new_callable=AsyncMock, return_value=mock_client) as mock_create_client: + with patch.object( + manager, + "_create_mcp_client", + new_callable=AsyncMock, + return_value=mock_client, + ) as mock_create_client: result = await manager.read_resource_from_server( server=server, url="https://example.com/resource", @@ -661,7 +692,9 @@ class TestMCPServerManager: mock_create_client.assert_called_once() called_kwargs = mock_create_client.call_args.kwargs assert called_kwargs["extra_headers"] == {"X-Test": "1", "X-Static": "1"} - mock_client.read_resource.assert_awaited_once_with("https://example.com/resource") + mock_client.read_resource.assert_awaited_once_with( + "https://example.com/resource" + ) assert result is read_result @pytest.mark.asyncio @@ -724,22 +757,27 @@ class TestMCPServerManager: registration_url=None, ) - with patch( - "litellm.proxy._experimental.mcp_server.mcp_server_manager.get_async_httpx_client", - return_value=mock_client, - ), patch.object( - manager, - "_fetch_oauth_metadata_from_resource", - AsyncMock(return_value=([], None)), - ), patch.object( - manager, - "_attempt_well_known_discovery", - AsyncMock(return_value=([], None)), - ), patch.object( - manager, - "_fetch_authorization_server_metadata", - AsyncMock(return_value=mock_metadata), - ) as mock_fetch_auth: + with ( + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.get_async_httpx_client", + return_value=mock_client, + ), + patch.object( + manager, + "_fetch_oauth_metadata_from_resource", + AsyncMock(return_value=([], None)), + ), + patch.object( + manager, + "_attempt_well_known_discovery", + AsyncMock(return_value=([], None)), + ), + patch.object( + manager, + "_fetch_authorization_server_metadata", + AsyncMock(return_value=mock_metadata), + ) as mock_fetch_auth, + ): result = await manager._descovery_metadata(server_url) mock_fetch_auth.assert_awaited_once_with(["https://example.com"]) @@ -779,9 +817,8 @@ class TestMCPServerManager: assert server.scopes == ["config"] # config overrides discovery assert server.authorization_url == "https://config.example.com/auth" assert server.token_url == "https://discovered.example.com/token" - assert ( - server.registration_url == "https://discovered.example.com/register" - ) + assert server.registration_url == "https://discovered.example.com/register" + @pytest.mark.asyncio async def test_config_oauth_initialize_tool_name_to_mcp_server_name_mapping(self): manager = MCPServerManager() @@ -801,7 +838,7 @@ class TestMCPServerManager: # Initialize the tool mapping await manager._initialize_tool_name_to_mcp_server_name_mapping() assert manager.tool_name_to_mcp_server_name_mapping == {} - + @pytest.mark.asyncio async def test_list_tools_handles_missing_server_alias(self): """Test that list_tools handles servers without alias gracefully""" @@ -1017,7 +1054,9 @@ class TestMCPServerManager: # Capture the extra_headers passed to _create_mcp_client captured_extra_headers = None - async def capture_create_mcp_client(server, mcp_auth_header, extra_headers, stdio_env): + async def capture_create_mcp_client( + server, mcp_auth_header, extra_headers, stdio_env + ): nonlocal captured_extra_headers captured_extra_headers = extra_headers return mock_client @@ -1314,15 +1353,19 @@ class TestMCPServerManager: return tool_func - with patch( - "litellm.proxy._experimental.mcp_server.openapi_to_mcp_generator.create_tool_function", - side_effect=fake_create_tool_function, - ), patch( - "litellm.proxy._experimental.mcp_server.openapi_to_mcp_generator.build_input_schema", - return_value={"type": "object", "properties": {}, "required": []}, - ), patch( - "litellm.proxy._experimental.mcp_server.tool_registry.global_mcp_tool_registry.register_tool", - return_value=None, + with ( + patch( + "litellm.proxy._experimental.mcp_server.openapi_to_mcp_generator.create_tool_function", + side_effect=fake_create_tool_function, + ), + patch( + "litellm.proxy._experimental.mcp_server.openapi_to_mcp_generator.build_input_schema", + return_value={"type": "object", "properties": {}, "required": []}, + ), + patch( + "litellm.proxy._experimental.mcp_server.tool_registry.global_mcp_tool_registry.register_tool", + return_value=None, + ), ): await manager._register_openapi_tools( spec_path=str(spec_path), @@ -2161,7 +2204,9 @@ class TestMCPServerManager: # Register the server and map a tool to it manager.registry = {"test-server": server} manager.tool_name_to_mcp_server_name_mapping["test_tool"] = "test-server" - manager.tool_name_to_mcp_server_name_mapping["test-server-test_tool"] = "test-server" + manager.tool_name_to_mcp_server_name_mapping["test-server-test_tool"] = ( + "test-server" + ) # Create mock client that tracks call_tool usage mock_client = AsyncMock() @@ -2252,11 +2297,16 @@ class TestMCPServerManager: # Verify MCPRequestHandler.get_allowed_mcp_servers was called with user_api_key_auth mock_get_allowed.assert_called_once() call_args = mock_get_allowed.call_args - assert call_args[0][0] is user_api_key_auth # First positional arg should be user_api_key_auth + assert ( + call_args[0][0] is user_api_key_auth + ) # First positional arg should be user_api_key_auth assert call_args[0][0].user_id == "user-123" assert call_args[0][0].object_permission_id == "perm_123" assert call_args[0][0].object_permission is not None - assert call_args[0][0].object_permission.mcp_servers == ["test_server_1", "test_server_2"] + assert call_args[0][0].object_permission.mcp_servers == [ + "test_server_1", + "test_server_2", + ] # Verify result contains the expected servers assert "test_server_1" in result @@ -2483,5 +2533,82 @@ class TestHasClientCredentialsOAuth2Flow: assert server.needs_user_oauth_token is False +# --------------------------------------------------------------------------- +# Upstream initialize-instructions cache +# --------------------------------------------------------------------------- + + +class TestMCPServerManagerUpstreamInstructionsCache: + """Tests for the upstream initialize-instructions cache.""" + + def test_get_returns_none_when_empty(self): + """Empty cache returns None for any key.""" + manager = MCPServerManager() + assert ( + manager._upstream_initialize_instructions_by_server_id.get("nonexistent") + is None + ) + + def test_remember_stores_stripped_value(self): + """_remember_upstream_initialize_instructions stores a stripped string.""" + manager = MCPServerManager() + fake_server = MagicMock(server_id="srv") + fake_client = MagicMock(_last_initialize_instructions=" hello \n") + manager._remember_upstream_initialize_instructions(fake_server, fake_client) + assert ( + manager._upstream_initialize_instructions_by_server_id.get("srv") == "hello" + ) + + def test_remember_ignores_empty_string(self): + """Whitespace-only instructions are not stored.""" + manager = MCPServerManager() + fake_server = MagicMock(server_id="srv") + fake_client = MagicMock(_last_initialize_instructions=" ") + manager._remember_upstream_initialize_instructions(fake_server, fake_client) + assert manager._upstream_initialize_instructions_by_server_id.get("srv") is None + + def test_remember_ignores_none(self): + """None instructions are not stored.""" + manager = MCPServerManager() + fake_server = MagicMock(server_id="srv") + fake_client = MagicMock(_last_initialize_instructions=None) + manager._remember_upstream_initialize_instructions(fake_server, fake_client) + assert manager._upstream_initialize_instructions_by_server_id.get("srv") is None + + @pytest.mark.asyncio + async def test_load_servers_from_config_clears_cache(self): + """Reloading config clears any previously cached upstream instructions.""" + manager = MCPServerManager() + manager._upstream_initialize_instructions_by_server_id["old"] = "stale" + await manager.load_servers_from_config( + mcp_servers_config={ + "fresh_srv": { + "url": "https://example.com", + "instructions": "from yaml", + } + } + ) + assert manager._upstream_initialize_instructions_by_server_id.get("old") is None + + @pytest.mark.asyncio + async def test_load_servers_reads_instructions_from_config(self): + """instructions field from YAML config is persisted on the MCPServer.""" + manager = MCPServerManager() + await manager.load_servers_from_config( + mcp_servers_config={ + "srv_a": { + "url": "https://a.example.com", + "instructions": "A instructions", + }, + "srv_b": { + "url": "https://b.example.com", + }, + } + ) + by_name = {s.server_name: s for s in manager.config_mcp_servers.values()} + assert "srv_a" in by_name and by_name["srv_a"].instructions == "A instructions" + assert "srv_b" in by_name and by_name["srv_b"].instructions is None + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_sigv4_auth.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_sigv4_auth.py index 7c142e3a77..32b988ddb2 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_sigv4_auth.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_sigv4_auth.py @@ -399,7 +399,9 @@ class TestMCPServerManagerSigV4: server = next(iter(manager.config_mcp_servers.values())) assert server.auth_type == MCPAuth.aws_sigv4 assert server.aws_access_key_id == "AKIAIOSFODNN7EXAMPLE" - assert server.aws_secret_access_key == "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + assert ( + server.aws_secret_access_key == "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + ) assert server.aws_region_name == "us-east-1" assert server.aws_service_name == "bedrock-agentcore" @@ -529,7 +531,9 @@ class TestMCPServerManagerSigV4: "aws_session_name": "my-session", } - result = manager._extract_aws_credentials(creds, credentials_are_encrypted=False) + result = manager._extract_aws_credentials( + creds, credentials_are_encrypted=False + ) assert result["aws_role_name"] == "arn:aws:iam::123456789012:role/TestRole" assert result["aws_session_name"] == "my-session" @@ -615,12 +619,15 @@ class TestCredentialMergeOnUpdate: credentials={"aws_region_name": "eu-west-1"}, ) - with patch( - "litellm.proxy._experimental.mcp_server.db._get_salt_key", - return_value=None, - ), patch( - "litellm.proxy._experimental.mcp_server.db.encrypt_value_helper", - side_effect=lambda value, new_encryption_key: value, + with ( + patch( + "litellm.proxy._experimental.mcp_server.db._get_salt_key", + return_value=None, + ), + patch( + "litellm.proxy._experimental.mcp_server.db.encrypt_value_helper", + side_effect=lambda value, new_encryption_key: value, + ), ): await update_mcp_server(mock_prisma, data, "test-user") @@ -685,12 +692,15 @@ class TestCredentialMergeOnUpdate: credentials={"aws_region_name": "us-east-1"}, ) - with patch( - "litellm.proxy._experimental.mcp_server.db._get_salt_key", - return_value=None, - ), patch( - "litellm.proxy._experimental.mcp_server.db.encrypt_value_helper", - side_effect=lambda value, new_encryption_key: value, + with ( + patch( + "litellm.proxy._experimental.mcp_server.db._get_salt_key", + return_value=None, + ), + patch( + "litellm.proxy._experimental.mcp_server.db.encrypt_value_helper", + side_effect=lambda value, new_encryption_key: value, + ), ): await update_mcp_server(mock_prisma, data, "test-user") @@ -728,12 +738,15 @@ class TestCredentialMergeOnUpdate: credentials={"auth_value": "my-key"}, ) - with patch( - "litellm.proxy._experimental.mcp_server.db._get_salt_key", - return_value=None, - ), patch( - "litellm.proxy._experimental.mcp_server.db.encrypt_value_helper", - side_effect=lambda value, new_encryption_key: f"enc:{value}", + with ( + patch( + "litellm.proxy._experimental.mcp_server.db._get_salt_key", + return_value=None, + ), + patch( + "litellm.proxy._experimental.mcp_server.db.encrypt_value_helper", + side_effect=lambda value, new_encryption_key: f"enc:{value}", + ), ): await update_mcp_server(mock_prisma, data, "test-user") @@ -772,12 +785,15 @@ class TestCredentialMergeOnUpdate: credentials={"scopes": ["read", "write"]}, ) - with patch( - "litellm.proxy._experimental.mcp_server.db._get_salt_key", - return_value=None, - ), patch( - "litellm.proxy._experimental.mcp_server.db.encrypt_value_helper", - side_effect=lambda value, new_encryption_key: value, + with ( + patch( + "litellm.proxy._experimental.mcp_server.db._get_salt_key", + return_value=None, + ), + patch( + "litellm.proxy._experimental.mcp_server.db.encrypt_value_helper", + side_effect=lambda value, new_encryption_key: value, + ), ): await update_mcp_server(mock_prisma, data, "test-user") @@ -803,7 +819,9 @@ class TestSigV4BuildFromTable: table_record.server_name = "sigv4_server" table_record.alias = None table_record.description = None - table_record.url = "https://bedrock-agentcore.us-east-1.amazonaws.com/invocations" + table_record.url = ( + "https://bedrock-agentcore.us-east-1.amazonaws.com/invocations" + ) table_record.spec_path = None table_record.transport = "http" table_record.auth_type = "aws_sigv4" @@ -838,6 +856,7 @@ class TestSigV4BuildFromTable: table_record.tool_name_to_description = None table_record.byok_api_key_help_url = None table_record.oauth2_flow = None + table_record.instructions = None manager = MCPServerManager() @@ -895,6 +914,7 @@ class TestSigV4BuildFromTable: table_record.tool_name_to_description = None table_record.byok_api_key_help_url = None table_record.oauth2_flow = None + table_record.instructions = None manager = MCPServerManager() @@ -934,7 +954,9 @@ class TestDecryptCredentials: with patch( "litellm.proxy._experimental.mcp_server.db.decrypt_value_helper", - side_effect=lambda value, key, exception_type="error", return_original_value=False: value.replace("enc:", ""), + side_effect=lambda value, key, exception_type="error", return_original_value=False: value.replace( + "enc:", "" + ), ): result = decrypt_credentials(credentials=creds) @@ -956,7 +978,9 @@ class TestDecryptCredentials: with patch( "litellm.proxy._experimental.mcp_server.db.decrypt_value_helper", - side_effect=lambda value, key, exception_type="error", return_original_value=False: value.replace("enc:", ""), + side_effect=lambda value, key, exception_type="error", return_original_value=False: value.replace( + "enc:", "" + ), ): result = decrypt_credentials(credentials=creds) @@ -988,15 +1012,21 @@ class TestRotateCredentials: ) mock_prisma.db.litellm_mcpservertable.update = AsyncMock() - with patch( - "litellm.proxy._experimental.mcp_server.db._get_salt_key", - return_value="old-key", - ), patch( - "litellm.proxy._experimental.mcp_server.db.decrypt_value_helper", - side_effect=lambda value, key, exception_type="error", return_original_value=False: value.replace("enc_old:", ""), - ), patch( - "litellm.proxy._experimental.mcp_server.db.encrypt_value_helper", - side_effect=lambda value, new_encryption_key: f"enc_new:{value}", + with ( + patch( + "litellm.proxy._experimental.mcp_server.db._get_salt_key", + return_value="old-key", + ), + patch( + "litellm.proxy._experimental.mcp_server.db.decrypt_value_helper", + side_effect=lambda value, key, exception_type="error", return_original_value=False: value.replace( + "enc_old:", "" + ), + ), + patch( + "litellm.proxy._experimental.mcp_server.db.encrypt_value_helper", + side_effect=lambda value, new_encryption_key: f"enc_new:{value}", + ), ): await rotate_mcp_server_credentials_master_key( mock_prisma, "admin", "new-key" diff --git a/tests/test_litellm/proxy/db/db_transaction_queue/test_pod_lock_manager.py b/tests/test_litellm/proxy/db/db_transaction_queue/test_pod_lock_manager.py index e83fd75c3a..eb795fbc01 100644 --- a/tests/test_litellm/proxy/db/db_transaction_queue/test_pod_lock_manager.py +++ b/tests/test_litellm/proxy/db/db_transaction_queue/test_pod_lock_manager.py @@ -307,3 +307,83 @@ async def test_lock_takeover_race_condition(mock_redis): cronjob_id="test_job", ) assert result2 == False + + +@pytest.mark.asyncio +async def test_release_lock_uses_atomic_compare_delete_script_when_available( + pod_lock_manager, mock_redis +): + """ + Test that release_lock prefers atomic compare-and-delete Lua script when + redis cache exposes script registration. + """ + script_callable = AsyncMock(return_value=1) + mock_redis.async_register_script = MagicMock(return_value=script_callable) + + await pod_lock_manager.release_lock(cronjob_id="test_job") + + lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job") + mock_redis.async_register_script.assert_called_once_with( + PodLockManager._COMPARE_AND_DELETE_LOCK_SCRIPT + ) + script_callable.assert_called_once_with( + keys=[lock_key], args=[pod_lock_manager.pod_id] + ) + mock_redis.async_get_cache.assert_not_called() + mock_redis.async_delete_cache.assert_not_called() + + +@pytest.mark.asyncio +async def test_release_lock_reuses_registered_script(pod_lock_manager, mock_redis): + """ + Test script registration is cached on manager instance and reused. + """ + script_callable = AsyncMock(return_value=0) + mock_redis.async_register_script = MagicMock(return_value=script_callable) + + await pod_lock_manager.release_lock(cronjob_id="test_job") + await pod_lock_manager.release_lock(cronjob_id="test_job") + + assert mock_redis.async_register_script.call_count == 1 + + +@pytest.mark.asyncio +async def test_release_lock_lua_path_emits_released_event( + pod_lock_manager, mock_redis +): + """ + Test that _emit_released_lock_event is called when the Lua path returns 1 + (successful release). + """ + script_callable = AsyncMock(return_value=1) + mock_redis.async_register_script = MagicMock(return_value=script_callable) + + with patch.object(pod_lock_manager, "_emit_released_lock_event") as mock_emit: + await pod_lock_manager.release_lock(cronjob_id="test_job") + + mock_emit.assert_called_once_with( + cronjob_id="test_job", pod_id=pod_lock_manager.pod_id + ) + + +@pytest.mark.asyncio +async def test_release_lock_falls_back_to_get_del_when_lua_execution_fails( + pod_lock_manager, mock_redis +): + """ + Test that release_lock falls back to GET+DEL when Lua script execution + raises (e.g. Redis restart cleared loaded scripts). + """ + script_callable = AsyncMock(side_effect=Exception("NOSCRIPT")) + mock_redis.async_register_script = MagicMock(return_value=script_callable) + mock_redis.async_get_cache.return_value = pod_lock_manager.pod_id + mock_redis.async_delete_cache.return_value = 1 + + await pod_lock_manager.release_lock(cronjob_id="test_job") + + # Lua failed — should have fallen back to GET+DEL + lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job") + mock_redis.async_get_cache.assert_called_once_with(lock_key) + mock_redis.async_delete_cache.assert_called_once_with(lock_key) + # Cached script handle should be reset so next call re-registers + assert pod_lock_manager._release_lock_script is None diff --git a/tests/test_litellm/proxy/db/test_create_views.py b/tests/test_litellm/proxy/db/test_create_views.py new file mode 100644 index 0000000000..1a90b4c204 --- /dev/null +++ b/tests/test_litellm/proxy/db/test_create_views.py @@ -0,0 +1,155 @@ +""" +Tests for create_missing_views exception handling fix. + +Verifies that real DB errors (auth failures, connection errors, etc.) +are re-raised instead of being silently swallowed, while genuine +"view not found" errors still trigger view creation. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, call + + +@pytest.mark.asyncio +async def test_create_views_reraises_connection_error(): + """should re-raise exceptions that are NOT 'does not exist' errors (e.g. connection errors).""" + from litellm.proxy.db.create_views import create_missing_views + + mock_db = MagicMock() + mock_db.query_raw = AsyncMock( + side_effect=Exception("connection refused: unable to connect to database") + ) + mock_db.execute_raw = AsyncMock() + + with pytest.raises(Exception, match="connection refused"): + await create_missing_views(mock_db) + + mock_db.execute_raw.assert_not_called() + + +@pytest.mark.asyncio +async def test_create_views_reraises_permission_error(): + """should re-raise permission denied errors, not treat them as missing views.""" + from litellm.proxy.db.create_views import create_missing_views + + mock_db = MagicMock() + mock_db.query_raw = AsyncMock( + side_effect=Exception( + "permission denied for table LiteLLM_VerificationTokenView" + ) + ) + mock_db.execute_raw = AsyncMock() + + with pytest.raises(Exception, match="permission denied"): + await create_missing_views(mock_db) + + mock_db.execute_raw.assert_not_called() + + +@pytest.mark.asyncio +async def test_create_views_creates_view_on_does_not_exist(): + """should call execute_raw to create view when error contains 'does not exist'.""" + from litellm.proxy.db.create_views import create_missing_views + + mock_db = MagicMock() + mock_db.query_raw = AsyncMock( + side_effect=[ + Exception('relation "LiteLLM_VerificationTokenView" does not exist'), + None, # MonthlyGlobalSpend exists + None, # Last30dKeysBySpend exists + None, # Last30dModelsBySpend exists + None, # MonthlyGlobalSpendPerKey exists + None, # MonthlyGlobalSpendPerUserPerKey exists + None, # DailyTagSpend exists + None, # Last30dTopEndUsersSpend exists + ] + ) + mock_db.execute_raw = AsyncMock(return_value=None) + + await create_missing_views(mock_db) + + mock_db.execute_raw.assert_called_once() + created_sql = mock_db.execute_raw.call_args[0][0] + assert 'CREATE VIEW "LiteLLM_VerificationTokenView"' in created_sql + + +@pytest.mark.asyncio +async def test_create_views_creates_view_on_undefined_error(): + """should treat 'undefined' errors as 'view not found' and attempt creation.""" + from litellm.proxy.db.create_views import create_missing_views + + mock_db = MagicMock() + mock_db.query_raw = AsyncMock( + side_effect=[ + Exception("undefined table LiteLLM_VerificationTokenView"), + None, + None, + None, + None, + None, + None, + None, + ] + ) + mock_db.execute_raw = AsyncMock(return_value=None) + + await create_missing_views(mock_db) + + mock_db.execute_raw.assert_called_once() + + +@pytest.mark.asyncio +async def test_create_views_skips_creation_when_view_exists(): + """should not call execute_raw when all views already exist.""" + from litellm.proxy.db.create_views import create_missing_views + + mock_db = MagicMock() + mock_db.query_raw = AsyncMock(return_value=[{"?column?": 1}]) + mock_db.execute_raw = AsyncMock() + + await create_missing_views(mock_db) + + mock_db.execute_raw.assert_not_called() + + +@pytest.mark.asyncio +async def test_create_views_reraises_undefined_function_error(): + """should re-raise 'undefined function' errors — bare 'undefined' is too broad + and would previously misclassify DB function errors as missing-view signals.""" + from litellm.proxy.db.create_views import create_missing_views + + mock_db = MagicMock() + mock_db.query_raw = AsyncMock( + side_effect=Exception("ERROR: undefined function pg_get_viewdef()") + ) + mock_db.execute_raw = AsyncMock() + + with pytest.raises(Exception, match="undefined function"): + await create_missing_views(mock_db) + + mock_db.execute_raw.assert_not_called() + + +@pytest.mark.asyncio +async def test_create_views_creates_view_on_undefined_table_error(): + """should treat 'undefined table' as a missing-view signal and attempt creation.""" + from litellm.proxy.db.create_views import create_missing_views + + mock_db = MagicMock() + mock_db.query_raw = AsyncMock( + side_effect=[ + Exception('undefined table "LiteLLM_VerificationTokenView"'), + None, + None, + None, + None, + None, + None, + None, + ] + ) + mock_db.execute_raw = AsyncMock(return_value=None) + + await create_missing_views(mock_db) + + mock_db.execute_raw.assert_called_once() diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_bedrock_guardrails.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_bedrock_guardrails.py index 010ead425c..0472cc565e 100644 --- a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_bedrock_guardrails.py +++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_bedrock_guardrails.py @@ -1190,6 +1190,483 @@ async def test_bedrock_guardrail_blocked_content_with_masking_enabled(): print("✅ BLOCKED content with masking enabled raises exception correctly") +# ────────────────────────────────────────────────────────────────────────────── +# Null-safety tests for Bedrock guardrail responses +# +# The Bedrock ApplyGuardrail API can return explicit null/None for list fields +# such as "regexes", "piiEntities", "topics", "filters", "customWords", and +# "managedWordLists" when a particular policy category is present in the +# assessment but has no matches. +# +# Python's dict.get("key", []) returns None (NOT []) when the key exists with +# a None value. The `or []` fallback ensures we always iterate over a list. +# +# Without the fix, iterating over None raises: +# TypeError: 'NoneType' object is not iterable +# which surfaces to callers as: +# openai.InternalServerError: Error code: 500 +# {'error': {'message': "Bedrock guardrail failed: 'NoneType' object is not iterable", ...}} +# ────────────────────────────────────────────────────────────────────────────── + + +class TestRedactPiiMatchesNullSafety: + """Tests for _redact_pii_matches handling of null/None list fields from Bedrock API.""" + + @pytest.mark.asyncio + async def test_should_handle_null_regexes_in_sensitive_info_policy(self): + """Bedrock can return regexes: null while piiEntities has data. + + Real-world scenario: guardrail detects PII (e.g. EMAIL) but has no + custom regex patterns configured, so the API returns regexes: null. + """ + response = { + "action": "NONE", + "actionReason": "No action.", + "assessments": [ + { + "sensitiveInformationPolicy": { + "piiEntities": [ + { + "action": "NONE", + "detected": True, + "match": "joebloggs@gmail.com", + "type": "EMAIL", + } + ], + "regexes": None, # Explicit null from Bedrock API + }, + } + ], + } + + # Should not raise TypeError: 'NoneType' object is not iterable + redacted = _redact_pii_matches(response) + + # PII match should be redacted + pii = redacted["assessments"][0]["sensitiveInformationPolicy"]["piiEntities"] + assert pii[0]["match"] == "[REDACTED]" + assert pii[0]["type"] == "EMAIL" + + @pytest.mark.asyncio + async def test_should_handle_null_pii_entities_in_sensitive_info_policy(self): + """Bedrock can return piiEntities: null while regexes has data.""" + response = { + "action": "NONE", + "assessments": [ + { + "sensitiveInformationPolicy": { + "piiEntities": None, # null from Bedrock API + "regexes": [ + { + "name": "CUSTOM_PATTERN", + "match": "secret-abc-123", + "action": "BLOCKED", + } + ], + }, + } + ], + } + + redacted = _redact_pii_matches(response) + + regexes = redacted["assessments"][0]["sensitiveInformationPolicy"]["regexes"] + assert regexes[0]["match"] == "[REDACTED]" + + @pytest.mark.asyncio + async def test_should_handle_null_custom_words_and_managed_words(self): + """Bedrock can return null for customWords and managedWordLists in wordPolicy.""" + response = { + "action": "NONE", + "assessments": [ + { + "wordPolicy": { + "customWords": None, # null from Bedrock API + "managedWordLists": None, # null from Bedrock API + }, + } + ], + } + + # Should not raise TypeError + redacted = _redact_pii_matches(response) + + # Values should remain None (no crash) + assert redacted["assessments"][0]["wordPolicy"]["customWords"] is None + assert redacted["assessments"][0]["wordPolicy"]["managedWordLists"] is None + + @pytest.mark.asyncio + async def test_should_handle_null_assessments_list(self): + """Bedrock can return assessments: null.""" + response = { + "action": "NONE", + "assessments": None, # null from Bedrock API + } + + # Should not raise TypeError + redacted = _redact_pii_matches(response) + assert redacted["assessments"] is None + + @pytest.mark.asyncio + async def test_should_handle_all_null_policy_sub_lists_together(self): + """All sub-list fields are null at the same time — worst-case scenario.""" + response = { + "action": "GUARDRAIL_INTERVENED", + "assessments": [ + { + "sensitiveInformationPolicy": { + "piiEntities": None, + "regexes": None, + }, + "wordPolicy": { + "customWords": None, + "managedWordLists": None, + }, + "topicPolicy": None, + "contentPolicy": None, + "contextualGroundingPolicy": None, + } + ], + } + + # Should not raise any exception + redacted = _redact_pii_matches(response) + assert redacted is not None + + +class TestShouldRaiseGuardrailBlockedExceptionNullSafety: + """Tests for _should_raise_guardrail_blocked_exception handling of null list fields.""" + + def _create_guardrail(self) -> BedrockGuardrail: + return BedrockGuardrail( + guardrailIdentifier="test-guardrail", guardrailVersion="DRAFT" + ) + + @pytest.mark.asyncio + async def test_should_handle_all_null_policy_sub_lists(self): + """All policy sub-lists are null — should not crash, should return False.""" + guardrail = self._create_guardrail() + + response = { + "action": "GUARDRAIL_INTERVENED", + "assessments": [ + { + "topicPolicy": { + "topics": None, # null from Bedrock API + }, + "contentPolicy": { + "filters": None, # null + }, + "wordPolicy": { + "customWords": None, # null + "managedWordLists": None, # null + }, + "sensitiveInformationPolicy": { + "piiEntities": None, # null + "regexes": None, # null + }, + "contextualGroundingPolicy": { + "filters": None, # null + }, + } + ], + } + + # No BLOCKED actions found (all lists null) → should return False + result = guardrail._should_raise_guardrail_blocked_exception(response) + assert result is False + + @pytest.mark.asyncio + async def test_should_detect_blocked_despite_other_null_lists(self): + """A mix of null lists and a real BLOCKED action — should still detect it.""" + guardrail = self._create_guardrail() + + response = { + "action": "GUARDRAIL_INTERVENED", + "assessments": [ + { + "topicPolicy": { + "topics": None, # null — should not crash + }, + "contentPolicy": { + "filters": [ + { + "type": "HATE", + "confidence": "HIGH", + "action": "BLOCKED", + } + ], + }, + "wordPolicy": { + "customWords": None, # null + "managedWordLists": None, # null + }, + "sensitiveInformationPolicy": { + "piiEntities": None, # null + "regexes": None, # null + }, + "contextualGroundingPolicy": None, # entire policy is null + } + ], + } + + # Should return True because contentPolicy has a BLOCKED filter + result = guardrail._should_raise_guardrail_blocked_exception(response) + assert result is True + + @pytest.mark.asyncio + async def test_should_handle_null_assessments_list(self): + """assessments itself is null — should return False.""" + guardrail = self._create_guardrail() + + response = { + "action": "GUARDRAIL_INTERVENED", + "assessments": None, # null from Bedrock API + } + + result = guardrail._should_raise_guardrail_blocked_exception(response) + assert result is False + + @pytest.mark.asyncio + async def test_should_handle_null_topics_with_blocked_word_policy(self): + """topics is null but wordPolicy has a BLOCKED customWord.""" + guardrail = self._create_guardrail() + + response = { + "action": "GUARDRAIL_INTERVENED", + "assessments": [ + { + "topicPolicy": { + "topics": None, + }, + "wordPolicy": { + "customWords": [ + {"match": "badword", "action": "BLOCKED"} + ], + "managedWordLists": None, + }, + } + ], + } + + result = guardrail._should_raise_guardrail_blocked_exception(response) + assert result is True + + @pytest.mark.asyncio + async def test_should_handle_null_pii_with_blocked_regex(self): + """piiEntities is null but regexes has a BLOCKED match.""" + guardrail = self._create_guardrail() + + response = { + "action": "GUARDRAIL_INTERVENED", + "assessments": [ + { + "sensitiveInformationPolicy": { + "piiEntities": None, + "regexes": [ + {"name": "SSN", "match": "123-45-6789", "action": "BLOCKED"} + ], + }, + } + ], + } + + result = guardrail._should_raise_guardrail_blocked_exception(response) + assert result is True + + @pytest.mark.asyncio + async def test_should_handle_null_grounding_filters(self): + """contextualGroundingPolicy.filters is null — should not crash.""" + guardrail = self._create_guardrail() + + response = { + "action": "GUARDRAIL_INTERVENED", + "assessments": [ + { + "contextualGroundingPolicy": { + "filters": None, + }, + } + ], + } + + result = guardrail._should_raise_guardrail_blocked_exception(response) + assert result is False + + @pytest.mark.asyncio + async def test_should_not_crash_when_action_is_not_intervened(self): + """If action != GUARDRAIL_INTERVENED, null lists should never be reached.""" + guardrail = self._create_guardrail() + + response = { + "action": "NONE", + "assessments": [ + { + "sensitiveInformationPolicy": { + "piiEntities": None, + "regexes": None, + }, + } + ], + } + + result = guardrail._should_raise_guardrail_blocked_exception(response) + assert result is False + + +class TestApplyGuardrailNullSafety: + """Tests for apply_guardrail handling of null/None texts input.""" + + @pytest.mark.asyncio + async def test_should_handle_none_texts_in_inputs(self): + """inputs[\"texts\"] is explicitly None — should not crash.""" + guardrail = BedrockGuardrail( + guardrailIdentifier="test-guardrail", guardrailVersion="DRAFT" + ) + + inputs = {"texts": None} # Explicit None + + mock_credentials = MagicMock() + + with patch.object( + guardrail.async_handler, "post", new_callable=AsyncMock + ) as mock_post, patch.object( + guardrail, "_load_credentials", return_value=(mock_credentials, "us-east-1") + ), patch.object( + guardrail, "_prepare_request", return_value=MagicMock() + ): + # With empty texts (from None → []), no Bedrock API call should be made + result = await guardrail.apply_guardrail( + inputs=inputs, + request_data={}, + input_type="request", + ) + + # Should return empty texts without crashing + assert result.get("texts") == [] + # No Bedrock API call should be made for empty input + mock_post.assert_not_called() + + @pytest.mark.asyncio + async def test_should_handle_missing_texts_key(self): + """inputs has no \"texts\" key at all — should not crash.""" + guardrail = BedrockGuardrail( + guardrailIdentifier="test-guardrail", guardrailVersion="DRAFT" + ) + + inputs = {} # No "texts" key + + mock_credentials = MagicMock() + + with patch.object( + guardrail.async_handler, "post", new_callable=AsyncMock + ) as mock_post, patch.object( + guardrail, "_load_credentials", return_value=(mock_credentials, "us-east-1") + ), patch.object( + guardrail, "_prepare_request", return_value=MagicMock() + ): + result = await guardrail.apply_guardrail( + inputs=inputs, + request_data={}, + input_type="request", + ) + + assert result.get("texts") == [] + mock_post.assert_not_called() + + + +@pytest.mark.asyncio +async def test_bedrock_guardrail_blocked_vs_anonymized_actions(): + """Test that BLOCKED actions raise exceptions but ANONYMIZED actions do not""" + guardrail = BedrockGuardrail( + guardrailIdentifier="test-guardrail", guardrailVersion="DRAFT" + ) + + # Test 1: ANONYMIZED action should NOT raise exception + anonymized_response = { + "action": "GUARDRAIL_INTERVENED", + "outputs": [{"text": "Hello, my phone number is {PHONE}"}], + "assessments": [ + { + "sensitiveInformationPolicy": { + "piiEntities": [ + { + "type": "PHONE", + "match": "+1 412 555 1212", + "action": "ANONYMIZED", + } + ] + } + } + ], + } + + should_raise = guardrail._should_raise_guardrail_blocked_exception( + anonymized_response + ) + assert should_raise is False, "ANONYMIZED actions should not raise exceptions" + + # Test 2: BLOCKED action should raise exception + blocked_response = { + "action": "GUARDRAIL_INTERVENED", + "outputs": [{"text": "I can't provide that information."}], + "assessments": [ + { + "topicPolicy": { + "topics": [ + {"name": "Sensitive Topic", "type": "DENY", "action": "BLOCKED"} + ] + } + } + ], + } + + should_raise = guardrail._should_raise_guardrail_blocked_exception(blocked_response) + assert should_raise is True, "BLOCKED actions should raise exceptions" + + # Test 3: Mixed actions - should raise if ANY action is BLOCKED + mixed_response = { + "action": "GUARDRAIL_INTERVENED", + "outputs": [{"text": "I can't provide that information."}], + "assessments": [ + { + "sensitiveInformationPolicy": { + "piiEntities": [ + { + "type": "PHONE", + "match": "+1 412 555 1212", + "action": "ANONYMIZED", + } + ] + }, + "topicPolicy": { + "topics": [ + {"name": "Blocked Topic", "type": "DENY", "action": "BLOCKED"} + ] + }, + } + ], + } + + should_raise = guardrail._should_raise_guardrail_blocked_exception(mixed_response) + assert ( + should_raise is True + ), "Mixed actions with any BLOCKED should raise exceptions" + + # Test 4: NONE action should not raise exception + none_response = { + "action": "NONE", + "outputs": [], + "assessments": [], + } + + should_raise = guardrail._should_raise_guardrail_blocked_exception(none_response) + assert should_raise is False, "NONE action should not raise exceptions" + + print("\u2705 BLOCKED vs ANONYMIZED actions test passed") + + # --------------------------------------------------------------------------- # L3: _extract_blocked_assessments + _get_http_exception_for_blocked_guardrail # Regression coverage for case 2026-04-10-internal-bedrock-guardrail-streaming-error. @@ -1338,4 +1815,3 @@ def test_get_http_exception_no_blocked_assessments_omits_field(): assert isinstance(exc, HTTPException) assert "assessments" not in exc.detail assert exc.detail["guardrailIdentifier"] == "amgllac6xf3r" - diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_grayswan.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_grayswan.py index 109ad0bfdc..f2e7447239 100644 --- a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_grayswan.py +++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_grayswan.py @@ -4,6 +4,7 @@ import pytest from fastapi import HTTPException from litellm.integrations.custom_guardrail import ModifyResponseException +from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.guardrails.guardrail_hooks.grayswan.grayswan import ( GraySwanGuardrail, GraySwanGuardrailAPIError, @@ -435,7 +436,10 @@ def test_format_violation_message() -> None: assert "Gray Swan Cygnal Guardrail" in message assert "the input query has a violation score of 0.85" in message assert "violating the rule(s): 1, 3, 5" in message - assert "Mutation effort to make the harmful intention disguised was DETECTED" in message + assert ( + "Mutation effort to make the harmful intention disguised was DETECTED" + in message + ) # IPI should not be in message since it's False assert "Indirect Prompt Injection was DETECTED" not in message @@ -446,4 +450,58 @@ def test_format_violation_message() -> None: assert "Gray Swan Cygnal Guardrail" in message assert "the model response has a violation score of 0.85" in message assert "violating the rule(s): 1, 3, 5" in message - assert "Mutation effort to make the harmful intention disguised was DETECTED" in message + assert ( + "Mutation effort to make the harmful intention disguised was DETECTED" + in message + ) + + +def test_prepare_payload_includes_litellm_metadata( + grayswan_guardrail: GraySwanGuardrail, +) -> None: + """Verify _prepare_payload forwards litellm_metadata from request_data.""" + messages = [{"role": "user", "content": "hello"}] + request_data = { + "litellm_metadata": { + "user_api_key_user_id": "user-123", + "user_api_key_team_id": "team-456", + "user_api_key_spend": 0, + } + } + + payload = grayswan_guardrail._prepare_payload(messages, {}, request_data) + + assert payload is not None + assert "litellm_metadata" in payload + assert payload["litellm_metadata"]["user_api_key_user_id"] == "user-123" + assert payload["litellm_metadata"]["user_api_key_team_id"] == "team-456" + + +def test_ensure_litellm_metadata_populates_from_user_api_key_dict() -> None: + """Verify _ensure_litellm_metadata populates litellm_metadata.""" + from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import ( + _ensure_litellm_metadata, + ) + + user_auth = UserAPIKeyAuth(user_id="u1", team_id="t1", api_key="sk-test-hashed") + data: dict = {} + + _ensure_litellm_metadata(data, user_auth) + + assert "litellm_metadata" in data + assert data["litellm_metadata"]["user_api_key_user_id"] == "u1" + assert data["litellm_metadata"]["user_api_key_team_id"] == "t1" + + +def test_ensure_litellm_metadata_noop_when_already_present() -> None: + """Verify _ensure_litellm_metadata does not overwrite existing litellm_metadata.""" + from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import ( + _ensure_litellm_metadata, + ) + + user_auth = UserAPIKeyAuth(user_id="should-not-appear") + data: dict = {"litellm_metadata": {"existing": "value"}} + + _ensure_litellm_metadata(data, user_auth) + + assert data["litellm_metadata"] == {"existing": "value"} diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_noma_v2.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_noma_v2.py index d5fc1bdc69..7a3566fecb 100644 --- a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_noma_v2.py +++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_noma_v2.py @@ -150,13 +150,19 @@ class TestNomaV2Configuration: application_id="dynamic-app", ) - payload["request_data"]["metadata"]["headers"]["x-noma-application-id"] = "mutated-value" + payload["request_data"]["metadata"]["headers"][ + "x-noma-application-id" + ] = "mutated-value" payload["request_data"]["messages"][0]["content"] = "changed-content" - assert request_data["metadata"]["headers"]["x-noma-application-id"] == "header-app" + assert ( + request_data["metadata"]["headers"]["x-noma-application-id"] == "header-app" + ) assert request_data["messages"][0]["content"] == "hello" - def test_build_scan_payload_passes_model_call_details_as_is(self, noma_v2_guardrail): + def test_build_scan_payload_passes_model_call_details_as_is( + self, noma_v2_guardrail + ): class _LoggingObj: def __init__(self) -> None: self.model_call_details = { @@ -193,7 +199,9 @@ class TestNomaV2Configuration: assert request_data["litellm_logging_obj"] == "" @pytest.mark.asyncio - async def test_call_noma_scan_sanitizes_response_model_dump_object(self, noma_v2_guardrail): + async def test_call_noma_scan_sanitizes_response_model_dump_object( + self, noma_v2_guardrail + ): import json class _FakeModelResponse: @@ -221,7 +229,9 @@ class TestNomaV2Configuration: json.dumps(sent_payload) assert sent_payload["request_data"]["response"]["id"] == "resp-1" - def test_sanitize_payload_for_transport_falls_back_to_safe_dumps(self, noma_v2_guardrail): + def test_sanitize_payload_for_transport_falls_back_to_safe_dumps( + self, noma_v2_guardrail + ): with patch( "litellm.proxy.guardrails.guardrail_hooks.noma.noma_v2.json.dumps", side_effect=TypeError("cannot serialize"), @@ -230,12 +240,16 @@ class TestNomaV2Configuration: "litellm.proxy.guardrails.guardrail_hooks.noma.noma_v2.safe_dumps", return_value='{"fallback": true}', ) as mock_safe_dumps: - sanitized = noma_v2_guardrail._sanitize_payload_for_transport({"inputs": {"texts": ["hello"]}}) + sanitized = noma_v2_guardrail._sanitize_payload_for_transport( + {"inputs": {"texts": ["hello"]}} + ) mock_safe_dumps.assert_called_once() assert sanitized == {"fallback": True} - def test_sanitize_payload_for_transport_logs_warning_when_payload_becomes_empty(self, noma_v2_guardrail): + def test_sanitize_payload_for_transport_logs_warning_when_payload_becomes_empty( + self, noma_v2_guardrail + ): with patch( "litellm.proxy.guardrails.guardrail_hooks.noma.noma_v2.safe_json_loads", return_value={}, @@ -243,14 +257,18 @@ class TestNomaV2Configuration: with patch( "litellm.proxy.guardrails.guardrail_hooks.noma.noma_v2.verbose_proxy_logger.warning" ) as mock_warning: - sanitized = noma_v2_guardrail._sanitize_payload_for_transport({"inputs": {"texts": ["hello"]}}) + sanitized = noma_v2_guardrail._sanitize_payload_for_transport( + {"inputs": {"texts": ["hello"]}} + ) assert sanitized == {} mock_warning.assert_called_once_with( "Noma v2 guardrail: payload serialization failed, falling back to empty payload" ) - def test_sanitize_payload_for_transport_logs_warning_on_non_dict_output(self, noma_v2_guardrail): + def test_sanitize_payload_for_transport_logs_warning_on_non_dict_output( + self, noma_v2_guardrail + ): with patch( "litellm.proxy.guardrails.guardrail_hooks.noma.noma_v2.safe_json_loads", return_value=["not-a-dict"], @@ -258,7 +276,9 @@ class TestNomaV2Configuration: with patch( "litellm.proxy.guardrails.guardrail_hooks.noma.noma_v2.verbose_proxy_logger.warning" ) as mock_warning: - sanitized = noma_v2_guardrail._sanitize_payload_for_transport({"inputs": {"texts": ["hello"]}}) + sanitized = noma_v2_guardrail._sanitize_payload_for_transport( + {"inputs": {"texts": ["hello"]}} + ) assert sanitized == {} mock_warning.assert_called_once_with( @@ -271,7 +291,9 @@ class TestNomaV2Configuration: class TestNomaV2ActionBehavior: - def test_resolve_action_from_response_raises_on_unknown_action(self, noma_v2_guardrail): + def test_resolve_action_from_response_raises_on_unknown_action( + self, noma_v2_guardrail + ): with pytest.raises(ValueError, match="missing valid action"): noma_v2_guardrail._resolve_action_from_response({"action": "INVALID"}) @@ -296,7 +318,9 @@ class TestNomaV2ActionBehavior: assert result == inputs @pytest.mark.asyncio - async def test_native_action_guardrail_intervened_updates_supported_fields(self, noma_v2_guardrail): + async def test_native_action_guardrail_intervened_updates_supported_fields( + self, noma_v2_guardrail + ): inputs = { "texts": ["Name: Jane"], "images": ["https://old.example/image.png"], @@ -322,7 +346,10 @@ class TestNomaV2ActionBehavior: { "id": "call_1", "type": "function", - "function": {"name": "new_tool", "arguments": '{"safe":"true"}'}, + "function": { + "name": "new_tool", + "arguments": '{"safe":"true"}', + }, } ], } @@ -336,7 +363,9 @@ class TestNomaV2ActionBehavior: assert result["texts"] == ["Name: *******"] assert result["images"] == ["https://new.example/image.png"] - assert result["tools"] == [{"type": "function", "function": {"name": "new_tool"}}] + assert result["tools"] == [ + {"type": "function", "function": {"name": "new_tool"}} + ] assert result["tool_calls"] == [ { "id": "call_1", @@ -367,7 +396,9 @@ class TestNomaV2ActionBehavior: assert exc_info.value.detail["details"]["blocked_reason"] == "blocked by policy" @pytest.mark.asyncio - async def test_intervened_without_modifications_returns_original_inputs(self, noma_v2_guardrail): + async def test_intervened_without_modifications_returns_original_inputs( + self, noma_v2_guardrail + ): inputs = {"texts": ["Name: Jane"]} with patch.object( noma_v2_guardrail, @@ -464,7 +495,9 @@ class TestNomaV2ApplicationIdResolution: assert payload["application_id"] == "dynamic-app" @pytest.mark.asyncio - async def test_apply_guardrail_uses_configured_application_id(self, noma_v2_guardrail): + async def test_apply_guardrail_uses_configured_application_id( + self, noma_v2_guardrail + ): call_mock = AsyncMock(return_value={"action": "NONE"}) with patch.object( noma_v2_guardrail, @@ -482,7 +515,86 @@ class TestNomaV2ApplicationIdResolution: assert payload["application_id"] == "test-app" @pytest.mark.asyncio - async def test_apply_guardrail_omits_application_id_when_not_explicit(self): + async def test_apply_guardrail_falls_back_to_key_alias_from_litellm_metadata( + self, noma_v2_guardrail + ): + """When no explicit application_id is set, fall back to user_api_key_alias + so that each API key gets its own application entry in the Noma dashboard.""" + noma_v2_guardrail.application_id = None + call_mock = AsyncMock(return_value={"action": "NONE"}) + request_data = { + "metadata": {}, + "litellm_metadata": {"user_api_key_alias": "test-key-alias"}, + } + with patch.object( + noma_v2_guardrail, + "get_guardrail_dynamic_request_body_params", + return_value={}, + ): + with patch.object(noma_v2_guardrail, "_call_noma_scan", call_mock): + await noma_v2_guardrail.apply_guardrail( + inputs={"texts": ["hello"]}, + request_data=request_data, + input_type="request", + ) + + payload = call_mock.call_args.kwargs["payload"] + assert payload["application_id"] == "test-key-alias" + + @pytest.mark.asyncio + async def test_apply_guardrail_falls_back_to_key_alias_from_metadata( + self, noma_v2_guardrail + ): + """user_api_key_alias in metadata (set by proxy_server.py) is also resolved.""" + noma_v2_guardrail.application_id = None + call_mock = AsyncMock(return_value={"action": "NONE"}) + request_data = { + "metadata": {"user_api_key_alias": "test-service-key"}, + } + with patch.object( + noma_v2_guardrail, + "get_guardrail_dynamic_request_body_params", + return_value={}, + ): + with patch.object(noma_v2_guardrail, "_call_noma_scan", call_mock): + await noma_v2_guardrail.apply_guardrail( + inputs={"texts": ["hello"]}, + request_data=request_data, + input_type="request", + ) + + payload = call_mock.call_args.kwargs["payload"] + assert payload["application_id"] == "test-service-key" + + @pytest.mark.asyncio + async def test_apply_guardrail_configured_application_id_takes_precedence_over_key_alias( + self, noma_v2_guardrail + ): + """Explicit application_id (config/env) wins over key_alias fallback.""" + call_mock = AsyncMock(return_value={"action": "NONE"}) + request_data = { + "metadata": {"user_api_key_alias": "should-not-be-used"}, + } + with patch.object( + noma_v2_guardrail, + "get_guardrail_dynamic_request_body_params", + return_value={}, + ): + with patch.object(noma_v2_guardrail, "_call_noma_scan", call_mock): + await noma_v2_guardrail.apply_guardrail( + inputs={"texts": ["hello"]}, + request_data=request_data, + input_type="request", + ) + + payload = call_mock.call_args.kwargs["payload"] + assert payload["application_id"] == "test-app" + + @pytest.mark.asyncio + async def test_apply_guardrail_omits_application_id_when_no_fallback_available( + self, + ): + """When nothing is set — no config, no dynamic params, no key alias — omit entirely.""" guardrail_no_config = NomaV2Guardrail( api_key="test-api-key", application_id=None, @@ -490,7 +602,6 @@ class TestNomaV2ApplicationIdResolution: event_hook="pre_call", default_on=True, ) - call_mock = AsyncMock(return_value={"action": "NONE"}) with patch.object( guardrail_no_config, @@ -506,26 +617,3 @@ class TestNomaV2ApplicationIdResolution: payload = call_mock.call_args.kwargs["payload"] assert "application_id" not in payload - - @pytest.mark.asyncio - async def test_apply_guardrail_ignores_request_metadata_application_id(self, noma_v2_guardrail): - noma_v2_guardrail.application_id = None - call_mock = AsyncMock(return_value={"action": "NONE"}) - request_data = { - "metadata": {"headers": {"x-noma-application-id": "header-app"}}, - "litellm_metadata": {"user_api_key_alias": "alias-app"}, - } - with patch.object( - noma_v2_guardrail, - "get_guardrail_dynamic_request_body_params", - return_value={}, - ): - with patch.object(noma_v2_guardrail, "_call_noma_scan", call_mock): - await noma_v2_guardrail.apply_guardrail( - inputs={"texts": ["hello"]}, - request_data=request_data, - input_type="request", - ) - - payload = call_mock.call_args.kwargs["payload"] - assert "application_id" not in payload diff --git a/tests/test_litellm/proxy/management_helpers/test_team_member_permission_checks.py b/tests/test_litellm/proxy/management_helpers/test_team_member_permission_checks.py index 6aa08dddd0..f7a3d310a3 100644 --- a/tests/test_litellm/proxy/management_helpers/test_team_member_permission_checks.py +++ b/tests/test_litellm/proxy/management_helpers/test_team_member_permission_checks.py @@ -8,7 +8,7 @@ sys.path.insert( 0, os.path.abspath("../../..") ) # Adds the parent directory to the system path -from litellm.proxy._types import KeyManagementRoutes, Member +from litellm.proxy._types import KeyManagementRoutes, Member, ProxyException from litellm.proxy.management_helpers.team_member_permission_checks import ( BASELINE_TEAM_MEMBER_PERMISSIONS, TeamMemberPermissionChecks, @@ -188,3 +188,74 @@ class TestGetDefaultTeamParam: assert _get_default_team_param("budget_duration") == "7d" assert _get_default_team_param("tpm_limit") == 1000 assert _get_default_team_param("rpm_limit") == 100 + + +class TestCanTeamMemberExecuteKeyManagementEndpoint: + @pytest.mark.asyncio + async def test_raises_when_user_not_in_keys_team(self, monkeypatch): + """Non-members should be blocked from team-scoped key management endpoints.""" + from litellm.proxy.management_endpoints import key_management_endpoints + from litellm.proxy.management_helpers import team_member_permission_checks as module + + async def _mock_get_team_object(**kwargs): + team = MagicMock() + team.team_id = "team-b" + team.team_member_permissions = ["/key/update"] + return team + + monkeypatch.setattr(module, "get_team_object", _mock_get_team_object) + monkeypatch.setattr(key_management_endpoints, "_get_user_in_team", lambda **kwargs: None) + + user_api_key_dict = MagicMock() + user_api_key_dict.user_role = "internal_user" + user_api_key_dict.user_id = "user-a" + user_api_key_dict.parent_otel_span = None + + existing_key_row = MagicMock() + existing_key_row.team_id = "team-b" + + with pytest.raises(ProxyException) as exc: + await TeamMemberPermissionChecks.can_team_member_execute_key_management_endpoint( + user_api_key_dict=user_api_key_dict, + route=KeyManagementRoutes.KEY_UPDATE, + prisma_client=MagicMock(), + user_api_key_cache=MagicMock(), + existing_key_row=existing_key_row, + ) + assert str(exc.value.code) == "401" + assert exc.value.type == "team_member_permission_error" + + @pytest.mark.asyncio + async def test_allows_team_admin_in_keys_team(self, monkeypatch): + """Team admins of the key's team should be allowed.""" + from litellm.proxy.management_endpoints import key_management_endpoints + from litellm.proxy.management_helpers import team_member_permission_checks as module + + async def _mock_get_team_object(**kwargs): + team = MagicMock() + team.team_id = "team-a" + team.team_member_permissions = ["/key/update"] + return team + + monkeypatch.setattr(module, "get_team_object", _mock_get_team_object) + monkeypatch.setattr( + key_management_endpoints, + "_get_user_in_team", + lambda **kwargs: Member(role="admin", user_id="user-a"), + ) + + user_api_key_dict = MagicMock() + user_api_key_dict.user_role = "internal_user" + user_api_key_dict.user_id = "user-a" + user_api_key_dict.parent_otel_span = None + + existing_key_row = MagicMock() + existing_key_row.team_id = "team-a" + + await TeamMemberPermissionChecks.can_team_member_execute_key_management_endpoint( + user_api_key_dict=user_api_key_dict, + route=KeyManagementRoutes.KEY_UPDATE, + prisma_client=MagicMock(), + user_api_key_cache=MagicMock(), + existing_key_row=existing_key_row, + ) diff --git a/tests/test_litellm/proxy/test_cors_config.py b/tests/test_litellm/proxy/test_cors_config.py new file mode 100644 index 0000000000..c654d266b7 --- /dev/null +++ b/tests/test_litellm/proxy/test_cors_config.py @@ -0,0 +1,141 @@ +""" +Tests for CORS configuration security fix. + +All tests import _get_cors_config directly from proxy_server so they exercise +real production code rather than a local mirror. +""" + +import pytest + + +def test_cors_wildcard_disables_credentials(): + """should disable credentials when LITELLM_CORS_ORIGINS is not set (defaults to wildcard).""" + from litellm.proxy.proxy_server import _get_cors_config + + origins, allow_credentials = _get_cors_config(cors_origins_env="") + assert origins == ["*"] + assert allow_credentials is False + + +def test_cors_empty_string_disables_credentials(): + """should disable credentials when LITELLM_CORS_ORIGINS is empty or whitespace.""" + from litellm.proxy.proxy_server import _get_cors_config + + for empty in ("", " ", "\t"): + origins, allow_credentials = _get_cors_config(cors_origins_env=empty) + assert origins == ["*"], f"Expected wildcard for input {repr(empty)}" + assert ( + allow_credentials is False + ), f"Expected no credentials for input {repr(empty)}" + + +def test_cors_single_specific_origin_enables_credentials(): + """should enable credentials when a single explicit origin is configured.""" + from litellm.proxy.proxy_server import _get_cors_config + + origins, allow_credentials = _get_cors_config( + cors_origins_env="https://admin.example.com" + ) + assert origins == ["https://admin.example.com"] + assert allow_credentials is True + + +def test_cors_multiple_specific_origins_enables_credentials(): + """should enable credentials and correctly parse comma-separated origins.""" + from litellm.proxy.proxy_server import _get_cors_config + + origins, allow_credentials = _get_cors_config( + cors_origins_env="https://app.example.com, https://admin.example.com, https://api.example.com" + ) + assert origins == [ + "https://app.example.com", + "https://admin.example.com", + "https://api.example.com", + ] + assert allow_credentials is True + + +def test_cors_wildcard_string_in_env_disables_credentials(): + """should disable credentials when LITELLM_CORS_ORIGINS is explicitly set to '*'.""" + from litellm.proxy.proxy_server import _get_cors_config + + origins, allow_credentials = _get_cors_config(cors_origins_env="*") + assert "*" in origins + assert allow_credentials is False + + +def test_cors_origins_strips_whitespace(): + """should strip surrounding whitespace from each origin entry.""" + from litellm.proxy.proxy_server import _get_cors_config + + origins, _ = _get_cors_config( + cors_origins_env=" https://a.com , https://b.com " + ) + assert origins == ["https://a.com", "https://b.com"] + + +def test_cors_origins_skips_blank_entries(): + """should skip blank entries caused by trailing/double commas.""" + from litellm.proxy.proxy_server import _get_cors_config + + origins, allow_credentials = _get_cors_config( + cors_origins_env="https://a.com,,https://b.com," + ) + assert origins == ["https://a.com", "https://b.com"] + assert allow_credentials is True + + +def test_cors_explicit_credentials_true_overrides_wildcard(): + """should enable credentials when LITELLM_CORS_ALLOW_CREDENTIALS=true even + if wildcard origins are in use (opt-in for existing deployments).""" + from litellm.proxy.proxy_server import _get_cors_config + + origins, allow_credentials = _get_cors_config( + cors_origins_env="", + cors_credentials_env="true", + ) + assert "*" in origins + assert allow_credentials is True + + +def test_cors_explicit_credentials_false_overrides_specific_origins(): + """should disable credentials when LITELLM_CORS_ALLOW_CREDENTIALS=false even + if specific origins are configured.""" + from litellm.proxy.proxy_server import _get_cors_config + + origins, allow_credentials = _get_cors_config( + cors_origins_env="https://admin.example.com", + cors_credentials_env="false", + ) + assert origins == ["https://admin.example.com"] + assert allow_credentials is False + + +def test_cors_explicit_credentials_case_insensitive(): + """should accept TRUE/FALSE case-insensitively for LITELLM_CORS_ALLOW_CREDENTIALS.""" + from litellm.proxy.proxy_server import _get_cors_config + + _, allow_true = _get_cors_config(cors_origins_env="", cors_credentials_env="TRUE") + _, allow_false = _get_cors_config( + cors_origins_env="https://x.com", cors_credentials_env="FALSE" + ) + assert allow_true is True + assert allow_false is False + + +def test_proxy_server_cors_invariant(): + """should verify that proxy_server module-level origins and allow_cors_credentials + are consistent — catches any future drift in the module-level call to _get_cors_config. + """ + import os + + import litellm.proxy.proxy_server as proxy_server + + if os.getenv("LITELLM_CORS_ALLOW_CREDENTIALS") is None: + assert proxy_server.allow_cors_credentials == ( + "*" not in proxy_server.origins + ), ( + f"Invariant broken: allow_cors_credentials={proxy_server.allow_cors_credentials} " + f"but origins={proxy_server.origins}. " + "When origins contains '*', allow_credentials must be False." + ) diff --git a/tests/test_litellm/proxy/test_health_check_functions.py b/tests/test_litellm/proxy/test_health_check_functions.py index 13d2131efa..8b23e526b6 100644 --- a/tests/test_litellm/proxy/test_health_check_functions.py +++ b/tests/test_litellm/proxy/test_health_check_functions.py @@ -406,12 +406,6 @@ async def test_save_background_health_checks_to_db_exception_handling(): @pytest.mark.asyncio async def test_get_all_latest_health_checks_with_model_id(mock_prisma): """Test get_all_latest_health_checks properly groups by model_id""" - # Create mock checks with same model_name but different model_id - mock_check1 = MagicMock() - mock_check1.model_id = "model-123" - mock_check1.model_name = "gpt-3.5-turbo" - mock_check1.checked_at = datetime.now(timezone.utc) - timedelta(minutes=10) - mock_check2 = MagicMock() mock_check2.model_id = "model-456" mock_check2.model_name = "gpt-3.5-turbo" @@ -424,7 +418,7 @@ async def test_get_all_latest_health_checks_with_model_id(mock_prisma): # Order by checked_at desc mock_prisma.db.litellm_healthchecktable.find_many = AsyncMock( - return_value=[mock_check3, mock_check2, mock_check1] + return_value=[mock_check3, mock_check2] ) result = await mock_prisma.get_all_latest_health_checks() @@ -445,18 +439,13 @@ async def test_get_all_latest_health_checks_with_model_id(mock_prisma): @pytest.mark.asyncio async def test_get_all_latest_health_checks_without_model_id(mock_prisma): """Test get_all_latest_health_checks groups by model_name when model_id is None""" - mock_check1 = MagicMock() - mock_check1.model_id = None - mock_check1.model_name = "gpt-3.5-turbo" - mock_check1.checked_at = datetime.now(timezone.utc) - timedelta(minutes=10) - mock_check2 = MagicMock() mock_check2.model_id = None mock_check2.model_name = "gpt-3.5-turbo" mock_check2.checked_at = datetime.now(timezone.utc) - timedelta(minutes=1) # Latest mock_prisma.db.litellm_healthchecktable.find_many = AsyncMock( - return_value=[mock_check2, mock_check1] + return_value=[mock_check2] ) result = await mock_prisma.get_all_latest_health_checks() @@ -467,6 +456,41 @@ async def test_get_all_latest_health_checks_without_model_id(mock_prisma): assert result[0].checked_at == mock_check2.checked_at # Latest +@pytest.mark.asyncio +async def test_get_all_latest_health_checks_same_name_with_and_without_model_id(mock_prisma): + """ + Same model_name can appear twice after DISTINCT ON: once keyed by (model_id, name) + and once by (NULL, name) — different Postgres groups than a single row with id. + """ + now = datetime.now(timezone.utc) + with_id = MagicMock() + with_id.model_id = "deployment-abc" + with_id.model_name = "gpt-4" + with_id.checked_at = now - timedelta(minutes=2) + + without_id = MagicMock() + without_id.model_id = None + without_id.model_name = "gpt-4" + without_id.checked_at = now - timedelta(minutes=1) + + mock_prisma.db.litellm_healthchecktable.find_many = AsyncMock( + return_value=[without_id, with_id] + ) + + result = await mock_prisma.get_all_latest_health_checks() + + assert len(result) == 2 + names = {r.model_name for r in result} + assert names == {"gpt-4"} + ids = {r.model_id for r in result} + assert "deployment-abc" in ids + assert None in ids + + by_key = {(r.model_id, r.model_name): r for r in result} + assert by_key[("deployment-abc", "gpt-4")].checked_at == with_id.checked_at + assert by_key[(None, "gpt-4")].checked_at == without_id.checked_at + + @pytest.mark.asyncio async def test_perform_health_check_and_save_passes_model_id_to_perform_health_check(): """Test that _perform_health_check_and_save passes model_id to perform_health_check so health checks run by model id.""" diff --git a/tests/test_litellm/proxy/test_health_check_max_tokens.py b/tests/test_litellm/proxy/test_health_check_max_tokens.py index e26f7fb9f2..4d417b40b5 100644 --- a/tests/test_litellm/proxy/test_health_check_max_tokens.py +++ b/tests/test_litellm/proxy/test_health_check_max_tokens.py @@ -1,7 +1,10 @@ +from unittest.mock import AsyncMock, MagicMock, patch + import pytest -from litellm.proxy.health_check import _update_litellm_params_for_health_check + from litellm.litellm_core_utils.health_check_helpers import HealthCheckHelpers -from unittest.mock import AsyncMock, patch, MagicMock +from litellm.proxy import health_check as hc_module +from litellm.proxy.health_check import _update_litellm_params_for_health_check @pytest.mark.asyncio @@ -50,10 +53,13 @@ async def test_ahealth_check_wildcard_models_respects_max_tokens(): Test that ahealth_check_wildcard_models respects max_tokens if passed, otherwise defaults to 10. """ - with patch( - "litellm.litellm_core_utils.llm_request_utils.pick_cheapest_chat_models_from_llm_provider", - return_value=["gpt-4o-mini"], - ), patch("litellm.acompletion", new_callable=AsyncMock): + with ( + patch( + "litellm.litellm_core_utils.llm_request_utils.pick_cheapest_chat_models_from_llm_provider", + return_value=["gpt-4o-mini"], + ), + patch("litellm.acompletion", new_callable=AsyncMock), + ): # Test Case 1: No max_tokens passed, should default to 10 model_params = {} await HealthCheckHelpers.ahealth_check_wildcard_models( @@ -73,3 +79,50 @@ async def test_ahealth_check_wildcard_models_respects_max_tokens(): litellm_logging_obj=MagicMock(), ) assert model_params["max_tokens"] == 3 + + +@pytest.mark.asyncio +async def test_background_health_check_max_tokens_env_var(monkeypatch): + """ + Test that BACKGROUND_HEALTH_CHECK_MAX_TOKENS env var is used as global default + for explicit (non-wildcard) models. + """ + monkeypatch.setattr(hc_module, "BACKGROUND_HEALTH_CHECK_MAX_TOKENS", 10) + + model_info = {} + litellm_params = {"model": "azure/gpt-4"} + + updated_params = _update_litellm_params_for_health_check(model_info, litellm_params) + + assert updated_params["max_tokens"] == 10 + + +@pytest.mark.asyncio +async def test_per_model_overrides_global_env_var(monkeypatch): + """ + Test that per-model health_check_max_tokens takes priority over + BACKGROUND_HEALTH_CHECK_MAX_TOKENS env var. + """ + monkeypatch.setattr(hc_module, "BACKGROUND_HEALTH_CHECK_MAX_TOKENS", 10) + + model_info = {"health_check_max_tokens": 5} + litellm_params = {"model": "azure/gpt-4"} + + updated_params = _update_litellm_params_for_health_check(model_info, litellm_params) + + assert updated_params["max_tokens"] == 5 + + +@pytest.mark.asyncio +async def test_global_env_var_applies_to_wildcard_models(monkeypatch): + """ + Test that BACKGROUND_HEALTH_CHECK_MAX_TOKENS env var also applies to wildcard models. + """ + monkeypatch.setattr(hc_module, "BACKGROUND_HEALTH_CHECK_MAX_TOKENS", 15) + + model_info = {} + litellm_params = {"model": "openai/*"} + + updated_params = _update_litellm_params_for_health_check(model_info, litellm_params) + + assert updated_params["max_tokens"] == 15 diff --git a/tests/test_litellm/proxy/test_spend_log_cleanup.py b/tests/test_litellm/proxy/test_spend_log_cleanup.py index 3a01437908..4923d70a43 100644 --- a/tests/test_litellm/proxy/test_spend_log_cleanup.py +++ b/tests/test_litellm/proxy/test_spend_log_cleanup.py @@ -287,9 +287,48 @@ def test_string_retention_still_works(): general_settings={"maximum_spend_logs_retention_period": setting} ) assert cleaner._should_delete_spend_logs() is True, f"Failed for {setting}" - assert cleaner.retention_seconds == expected_seconds, ( - f"Expected {expected_seconds} for {setting}, got {cleaner.retention_seconds}" - ) + assert ( + cleaner.retention_seconds == expected_seconds + ), f"Expected {expected_seconds} for {setting}, got {cleaner.retention_seconds}" + + +@pytest.mark.asyncio +async def test_delete_old_logs_aborts_on_non_int_execute_raw_return(): + """should abort deletion loop immediately when execute_raw returns a non-int + (e.g. None or dict), preventing an infinite loop.""" + mock_prisma_client = MagicMock() + mock_db = MagicMock() + mock_db.execute_raw = AsyncMock(return_value=None) + mock_prisma_client.db = mock_db + + cleaner = SpendLogCleanup( + general_settings={"maximum_spend_logs_retention_period": "7d"} + ) + + cutoff_date = datetime.now(timezone.utc) - timedelta(days=7) + total_deleted = await cleaner._delete_old_logs(mock_prisma_client, cutoff_date) + + assert mock_db.execute_raw.call_count == 1 + assert total_deleted == 0 + + +@pytest.mark.asyncio +async def test_delete_old_logs_continues_on_valid_int_return(): + """should continue deletion loop across batches when execute_raw returns valid int counts.""" + mock_prisma_client = MagicMock() + mock_db = MagicMock() + mock_db.execute_raw = AsyncMock(side_effect=[500, 300, 0]) + mock_prisma_client.db = mock_db + + cleaner = SpendLogCleanup( + general_settings={"maximum_spend_logs_retention_period": "7d"} + ) + + cutoff_date = datetime.now(timezone.utc) - timedelta(days=7) + total_deleted = await cleaner._delete_old_logs(mock_prisma_client, cutoff_date) + + assert mock_db.execute_raw.call_count == 3 + assert total_deleted == 800 def test_cleanup_batch_size_env_var(monkeypatch): diff --git a/tests/test_litellm/test_completion_timeout_resolution.py b/tests/test_litellm/test_completion_timeout_resolution.py new file mode 100644 index 0000000000..b1ec32e623 --- /dev/null +++ b/tests/test_litellm/test_completion_timeout_resolution.py @@ -0,0 +1,145 @@ +"""Unit tests for litellm.litellm_core_utils.completion_timeout.CompletionTimeout.""" + +import os +import sys + +import httpx + +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) +) + +from litellm.litellm_core_utils.completion_timeout import CompletionTimeout +from litellm.utils import supports_httpx_timeout + + +def test_explicit_timeout_wins(): + assert ( + CompletionTimeout.resolve( + 12.5, + {"timeout": 99.0, "request_timeout": 88.0}, + "openai", + global_timeout=None, + supports_httpx_timeout=supports_httpx_timeout, + ) + == 12.5 + ) + + +def test_kwargs_timeout_when_param_none(): + assert ( + CompletionTimeout.resolve( + None, + {"timeout": 21.0}, + "azure_ai", + global_timeout=None, + supports_httpx_timeout=supports_httpx_timeout, + ) + == 21.0 + ) + + +def test_request_timeout_alias_in_kwargs(): + assert ( + CompletionTimeout.resolve( + None, + {"request_timeout": 33.0}, + "bedrock", + global_timeout=None, + supports_httpx_timeout=supports_httpx_timeout, + ) + == 33.0 + ) + + +def test_global_timeout_from_litellm_settings(): + assert ( + CompletionTimeout.resolve( + None, + {}, + "vertex_ai", + global_timeout=360.0, + supports_httpx_timeout=supports_httpx_timeout, + ) + == 360.0 + ) + + +def test_global_timeout_package_default_coerced_to_600_for_completion(): + """Package default 6000s → 600s for completion-only path.""" + assert ( + CompletionTimeout.resolve( + None, + {}, + "openai", + global_timeout=6000.0, + supports_httpx_timeout=supports_httpx_timeout, + ) + == 600.0 + ) + + +def test_explicit_request_timeout_6000_preserved(): + """Explicit deployment/request timeout must not be truncated by the package sentinel.""" + assert ( + CompletionTimeout.resolve( + None, + {"request_timeout": 6000.0}, + "openai", + global_timeout=None, + supports_httpx_timeout=supports_httpx_timeout, + ) + == 6000.0 + ) + + +def test_explicit_model_timeout_6000_preserved(): + assert ( + CompletionTimeout.resolve( + 6000.0, + {"timeout": 1.0, "request_timeout": 2.0}, + "openai", + global_timeout=None, + supports_httpx_timeout=supports_httpx_timeout, + ) + == 6000.0 + ) + + +def test_fallback_600_when_no_global_timeout(): + assert ( + CompletionTimeout.resolve( + None, + {}, + "azure_ai", + global_timeout=None, + supports_httpx_timeout=supports_httpx_timeout, + ) + == 600.0 + ) + + +def test_httpx_timeout_coerced_for_provider_without_httpx_timeout_support(): + t = httpx.Timeout(50.0, connect=2.0) + out = CompletionTimeout.resolve( + t, + {}, + "azure_ai", + global_timeout=None, + supports_httpx_timeout=supports_httpx_timeout, + ) + assert out == 50.0 + assert not isinstance(out, httpx.Timeout) + + +def test_httpx_timeout_preserved_for_openai(): + t = httpx.Timeout(40.0, connect=5.0) + out = CompletionTimeout.resolve( + t, + {}, + "openai", + global_timeout=None, + supports_httpx_timeout=supports_httpx_timeout, + ) + assert out is t + assert isinstance(out, httpx.Timeout) diff --git a/tests/test_litellm/test_cost_calculator.py b/tests/test_litellm/test_cost_calculator.py index 446316a02d..ebe175b250 100644 --- a/tests/test_litellm/test_cost_calculator.py +++ b/tests/test_litellm/test_cost_calculator.py @@ -1876,7 +1876,7 @@ def test_gemini_without_cache_tokens_details(): "promptTokensDetails": [ {"modality": "TEXT", "tokenCount": 6}, {"modality": "IMAGE", "tokenCount": 258}, - ] + ], # No cacheTokensDetails } } @@ -2014,3 +2014,27 @@ def test_additional_costs_only_for_azure_ai(): completion_tokens=50, ) assert result is None, "Vertex AI should have no additional costs" + + +def test_openrouter_gemini_3_1_flash_lite_preview_pricing(): + """ + Test that openrouter/google/gemini-3.1-flash-lite-preview has a pricing entry. + + Regression test for https://github.com/BerriAI/litellm/issues/25604 + + The model exists and is callable via OpenRouter, but was missing from + model_prices_and_context_window.json when other Gemini 3.x variants were present. + This caused ValueError: This model isn't mapped yet during router pre-call checks. + """ + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" + litellm.model_cost = litellm.get_model_cost_map(url="") + + model_name = "openrouter/google/gemini-3.1-flash-lite-preview" + model_info = litellm.model_cost.get(model_name) + + assert model_info is not None, f"Missing model pricing entry: {model_name}" + assert model_info["litellm_provider"] == "openrouter" + assert model_info["input_cost_per_token"] == 2.5e-07 + assert model_info["output_cost_per_token"] == 1.5e-06 + assert model_info["max_input_tokens"] == 1048576 + assert model_info["max_output_tokens"] == 65536 diff --git a/tests/test_litellm/test_model_param_helper.py b/tests/test_litellm/test_model_param_helper.py index c6e4b864a2..a62779aeab 100644 --- a/tests/test_litellm/test_model_param_helper.py +++ b/tests/test_litellm/test_model_param_helper.py @@ -31,3 +31,32 @@ def test_get_standard_logging_model_parameters_excludes_prompt_content(): assert "prompt" not in result assert "input" not in result assert result == {"temperature": 0.5} + + +def test_get_all_llm_api_params_includes_responses_api(): + """ + Regression guard for the Responses API cache-key bug: + Responses-API-only kwargs must be present in the cache-key allow-list, + otherwise Cache.get_cache_key() silently drops them and two requests + that differ only in (e.g.) `instructions` collide on the same key. + """ + all_params = ModelParamHelper._get_all_llm_api_params() + responses_only_params = { + "instructions", + "previous_response_id", + "reasoning", + "include", + "store", + "background", + "max_output_tokens", + "max_tool_calls", + "prompt_cache_key", + "prompt_cache_retention", + "context_management", + "conversation", + "safety_identifier", + } + missing = responses_only_params - all_params + assert ( + missing == set() + ), f"Responses-API kwargs missing from cache-key allow-list: {sorted(missing)}" diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/organizations/useOrganizations.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/organizations/useOrganizations.ts index 323270f436..0e7cd8342e 100644 --- a/ui/litellm-dashboard/src/app/(dashboard)/hooks/organizations/useOrganizations.ts +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/organizations/useOrganizations.ts @@ -3,7 +3,7 @@ import { Organization, organizationInfoCall, organizationListCall } from "@/comp import { useQuery, useQueryClient, UseQueryResult } from "@tanstack/react-query"; import { createQueryKeys } from "../common/queryKeysFactory"; -const organizationKeys = createQueryKeys("organizations"); +export const organizationKeys = createQueryKeys("organizations"); export const useOrganizations = (): UseQueryResult => { const { accessToken, userId, userRole } = useAuthorized(); return useQuery({ diff --git a/ui/litellm-dashboard/src/app/(dashboard)/teams/TeamsView.tsx b/ui/litellm-dashboard/src/app/(dashboard)/teams/TeamsView.tsx index fcad42d3a7..94a0e03304 100644 --- a/ui/litellm-dashboard/src/app/(dashboard)/teams/TeamsView.tsx +++ b/ui/litellm-dashboard/src/app/(dashboard)/teams/TeamsView.tsx @@ -1,4 +1,6 @@ import React, { useState, useEffect } from "react"; +import { useQueryClient } from "@tanstack/react-query"; +import { organizationKeys } from "@/app/(dashboard)/hooks/organizations/useOrganizations"; import { teamDeleteCall, Organization } from "@/components/networking"; import { fetchTeams } from "@/components/common_components/fetch_teams"; import { Form } from "antd"; @@ -54,6 +56,7 @@ const TeamsView: React.FC = ({ organizations, premiumUser = false, }) => { + const queryClient = useQueryClient(); const [currentOrg, setCurrentOrg] = useState(null); const [showFilters, setShowFilters] = useState(false); const [filters, setFilters] = useState({ @@ -138,6 +141,7 @@ const TeamsView: React.FC = ({ try { await teamDeleteCall(accessToken, teamToDelete); + queryClient.invalidateQueries({ queryKey: organizationKeys.all }); // Successfully completed the deletion. Update the state to trigger a rerender. fetchTeams(accessToken, userID, userRole, currentOrg, setTeams); } catch (error) { diff --git a/ui/litellm-dashboard/src/app/(dashboard)/teams/components/modals/CreateTeamModal.tsx b/ui/litellm-dashboard/src/app/(dashboard)/teams/components/modals/CreateTeamModal.tsx index ecaa3c08a4..f42371f83a 100644 --- a/ui/litellm-dashboard/src/app/(dashboard)/teams/components/modals/CreateTeamModal.tsx +++ b/ui/litellm-dashboard/src/app/(dashboard)/teams/components/modals/CreateTeamModal.tsx @@ -13,9 +13,11 @@ import AgentSelector from "@/components/agent_management/AgentSelector"; import PremiumLoggingSettings from "@/components/common_components/PremiumLoggingSettings"; import ModelAliasManager from "@/components/common_components/ModelAliasManager"; import React, { useEffect, useState } from "react"; +import { useQueryClient } from "@tanstack/react-query"; import NotificationsManager from "@/components/molecules/notifications_manager"; import { fetchMCPAccessGroups, getGuardrailsList, getPoliciesList, Organization, Team, teamCreateCall } from "@/components/networking"; import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized"; +import { organizationKeys } from "@/app/(dashboard)/hooks/organizations/useOrganizations"; import MCPToolPermissions from "@/components/mcp_server_management/MCPToolPermissions"; interface ModelAliases { @@ -71,6 +73,7 @@ const CreateTeamModal = ({ setIsTeamModalVisible, }: CreateTeamModalProps) => { const { userId: userID, userRole, accessToken, premiumUser } = useAuthorized(); + const queryClient = useQueryClient(); const [form] = Form.useForm(); const [userModels, setUserModels] = useState([]); const [currentOrgForCreateTeam, setCurrentOrgForCreateTeam] = useState(null); @@ -273,6 +276,7 @@ const CreateTeamModal = ({ } const response: any = await teamCreateCall(accessToken, formValues); + queryClient.invalidateQueries({ queryKey: organizationKeys.all }); if (teams !== null) { setTeams([...teams, response]); } else { diff --git a/ui/litellm-dashboard/src/components/guardrails/guardrail_optional_params.tsx b/ui/litellm-dashboard/src/components/guardrails/guardrail_optional_params.tsx index 2e51980419..b352360e50 100644 --- a/ui/litellm-dashboard/src/components/guardrails/guardrail_optional_params.tsx +++ b/ui/litellm-dashboard/src/components/guardrails/guardrail_optional_params.tsx @@ -192,8 +192,8 @@ const GuardrailOptionalParams: React.FC = ({ ) : field.type === "bool" || field.type === "boolean" ? ( ) : field.type === "number" ? ( diff --git a/ui/litellm-dashboard/src/components/mcp_tools/ToolTestPanel.test.tsx b/ui/litellm-dashboard/src/components/mcp_tools/ToolTestPanel.test.tsx index 883e30e492..08df1fbea2 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/ToolTestPanel.test.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/ToolTestPanel.test.tsx @@ -89,7 +89,7 @@ describe("ToolTestPanel defaults", () => { expect(screen.getByLabelText("message")).toHaveValue(""); expect(screen.getByLabelText("attempts")).toHaveValue(0); expect(screen.getByLabelText("ratio")).toHaveValue(0.4); - expect(screen.getByDisplayValue("True")).toBeInTheDocument(); + expect(screen.getByTitle("True")).toBeInTheDocument(); const keywordsTextarea = screen.getByTestId("textarea-keywords"); expect(JSON.parse(keywordsTextarea.value)).toEqual([""]); diff --git a/ui/litellm-dashboard/src/components/mcp_tools/ToolTestPanel.tsx b/ui/litellm-dashboard/src/components/mcp_tools/ToolTestPanel.tsx index abd955f329..7a592785a4 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/ToolTestPanel.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/ToolTestPanel.tsx @@ -1,7 +1,7 @@ import React from "react"; import { Button, TextInput } from "@tremor/react"; import { MCPTool, InputSchema, InputSchemaProperty } from "./types"; -import { Form, Tooltip } from "antd"; +import { Form, Select, Tooltip } from "antd"; import { InfoCircleOutlined } from "@ant-design/icons"; import NotificationsManager from "../molecules/notifications_manager"; @@ -480,14 +480,14 @@ export function ToolTestPanel({ )} {prop.type === "boolean" && ( - + True + False + )} {(prop.type === "object" || prop.type === "array") && ( diff --git a/ui/litellm-dashboard/src/components/organization/organization_view.test.tsx b/ui/litellm-dashboard/src/components/organization/organization_view.test.tsx index a54c7e1fba..8325bea0fc 100644 --- a/ui/litellm-dashboard/src/components/organization/organization_view.test.tsx +++ b/ui/litellm-dashboard/src/components/organization/organization_view.test.tsx @@ -1,14 +1,15 @@ import React from "react"; -import { render, screen, waitFor } from "@testing-library/react"; +import { screen, waitFor } from "@testing-library/react"; import userEvent from "@testing-library/user-event"; -import { vi, test, expect } from "vitest"; +import { vi, test, expect, beforeEach } from "vitest"; +import { renderWithProviders } from "../../../tests/test-utils"; import OrganizationInfoView from "./organization_view"; +import { useOrganization } from "@/app/(dashboard)/hooks/organizations/useOrganizations"; -// Mock networking calls used by the component +// Mock networking calls used by the component's mutation handlers vi.mock("../networking", () => { return { __esModule: true, - organizationInfoCall: vi.fn(), organizationMemberAddCall: vi.fn(), organizationMemberUpdateCall: vi.fn(), organizationMemberDeleteCall: vi.fn(), @@ -16,6 +17,20 @@ vi.mock("../networking", () => { }; }); +// Mock the React Query hook the component now reads org data from. The component +// also imports organizationKeys (used inside mutation handlers for invalidation), +// so provide a stub shape here too. +vi.mock("@/app/(dashboard)/hooks/organizations/useOrganizations", () => ({ + useOrganization: vi.fn(), + organizationKeys: { + all: ["organizations"], + list: () => ["organizations", "list", { params: {} }], + detail: (id: string) => ["organizations", "detail", id], + }, +})); + +const mockUseOrganization = vi.mocked(useOrganization); + // Mock noisy/heavy child components to keep this test focused on render vi.mock("../object_permissions_view", () => ({ __esModule: true, @@ -81,11 +96,14 @@ const mockOrg = { metadata: null, }; -test("renders organization view after loading data", async () => { - const { organizationInfoCall } = await import("../networking"); - (organizationInfoCall as unknown as ReturnType).mockResolvedValueOnce(mockOrg); +beforeEach(() => { + mockUseOrganization.mockReset(); +}); - const { findAllByText } = render( +test("renders organization view after loading data", async () => { + mockUseOrganization.mockReturnValue({ data: mockOrg, isLoading: false } as any); + + const { findAllByText } = renderWithProviders( {}} @@ -103,11 +121,10 @@ test("renders organization view after loading data", async () => { }); test("should display empty state when organization has no members", async () => { - const { organizationInfoCall } = await import("../networking"); - (organizationInfoCall as unknown as ReturnType).mockResolvedValueOnce(mockOrg); + mockUseOrganization.mockReturnValue({ data: mockOrg, isLoading: false } as any); const user = userEvent.setup(); - render( + renderWithProviders( {}} @@ -131,14 +148,13 @@ test("should display empty state when organization has no members", async () => }); test("should display team aliases when teams are available", async () => { - const { organizationInfoCall } = await import("../networking"); const orgWithTeams = { ...mockOrg, teams: [{ team_id: "team_123" }, { team_id: "team_456" }], }; - (organizationInfoCall as unknown as ReturnType).mockResolvedValueOnce(orgWithTeams); + mockUseOrganization.mockReturnValue({ data: orgWithTeams, isLoading: false } as any); - render( + renderWithProviders( {}} @@ -157,7 +173,6 @@ test("should display team aliases when teams are available", async () => { }); test("should display team ID as fallback when alias is not found", async () => { - const { organizationInfoCall } = await import("../networking"); mockUseTeams.mockReturnValueOnce({ data: [ { @@ -171,9 +186,9 @@ test("should display team ID as fallback when alias is not found", async () => { ...mockOrg, teams: [{ team_id: "team_999" }], }; - (organizationInfoCall as unknown as ReturnType).mockResolvedValueOnce(orgWithUnknownTeam); + mockUseOrganization.mockReturnValue({ data: orgWithUnknownTeam, isLoading: false } as any); - render( + renderWithProviders( {}} diff --git a/ui/litellm-dashboard/src/components/organization/organization_view.tsx b/ui/litellm-dashboard/src/components/organization/organization_view.tsx index 4bbcc9f43b..128e99ae84 100644 --- a/ui/litellm-dashboard/src/components/organization/organization_view.tsx +++ b/ui/litellm-dashboard/src/components/organization/organization_view.tsx @@ -1,4 +1,6 @@ import { useTeams } from "@/app/(dashboard)/hooks/teams/useTeams"; +import { organizationKeys, useOrganization } from "@/app/(dashboard)/hooks/organizations/useOrganizations"; +import { useQueryClient } from "@tanstack/react-query"; import { formatNumberWithCommas, copyToClipboard as utilCopyToClipboard } from "@/utils/dataUtils"; import { createTeamAliasMap } from "@/utils/teamUtils"; import { ArrowLeftIcon } from "@heroicons/react/outline"; @@ -14,7 +16,7 @@ import { import { Button, Form, Input, Select, Tabs, Typography } from "antd"; import type { ColumnsType } from "antd/es/table"; import { CheckIcon, CopyIcon } from "lucide-react"; -import React, { useEffect, useMemo, useState } from "react"; +import React, { useMemo, useState } from "react"; import MemberTable from "../common_components/MemberTable"; import UserSearchModal from "../common_components/user_search_modal"; import MCPServerSelector from "../mcp_server_management/MCPServerSelector"; @@ -23,7 +25,6 @@ import NotificationsManager from "../molecules/notifications_manager"; import { Member, Organization, - organizationInfoCall, organizationMemberAddCall, organizationMemberDeleteCall, organizationMemberUpdateCall, @@ -53,8 +54,8 @@ const OrganizationInfoView: React.FC = ({ userModels, editOrg, }) => { - const [orgData, setOrgData] = useState(null); - const [loading, setLoading] = useState(true); + const queryClient = useQueryClient(); + const { data: orgData, isLoading: loading } = useOrganization(organizationId); const [form] = Form.useForm(); const [isEditing, setIsEditing] = useState(false); const [isAddMemberModalVisible, setIsAddMemberModalVisible] = useState(false); @@ -67,24 +68,6 @@ const OrganizationInfoView: React.FC = ({ const teamAliasMap = useMemo(() => createTeamAliasMap(teams), [teams]); - const fetchOrgInfo = async () => { - try { - setLoading(true); - if (!accessToken) return; - const response = await organizationInfoCall(accessToken, organizationId); - setOrgData(response); - } catch (error) { - NotificationsManager.fromBackend("Failed to load organization information"); - console.error("Error fetching organization info:", error); - } finally { - setLoading(false); - } - }; - - useEffect(() => { - fetchOrgInfo(); - }, [organizationId, accessToken]); - const handleMemberAdd = async (values: any) => { try { if (accessToken == null) { @@ -101,7 +84,7 @@ const OrganizationInfoView: React.FC = ({ NotificationsManager.success("Organization member added successfully"); setIsAddMemberModalVisible(false); form.resetFields(); - fetchOrgInfo(); + queryClient.invalidateQueries({ queryKey: organizationKeys.all }); } catch (error) { NotificationsManager.fromBackend("Failed to add organization member"); console.error("Error adding organization member:", error); @@ -122,7 +105,7 @@ const OrganizationInfoView: React.FC = ({ NotificationsManager.success("Organization member updated successfully"); setIsEditMemberModalVisible(false); form.resetFields(); - fetchOrgInfo(); + queryClient.invalidateQueries({ queryKey: organizationKeys.all }); } catch (error) { NotificationsManager.fromBackend("Failed to update organization member"); console.error("Error updating organization member:", error); @@ -137,7 +120,7 @@ const OrganizationInfoView: React.FC = ({ NotificationsManager.success("Organization member deleted successfully"); setIsEditMemberModalVisible(false); form.resetFields(); - fetchOrgInfo(); + queryClient.invalidateQueries({ queryKey: organizationKeys.all }); } catch (error) { NotificationsManager.fromBackend("Failed to delete organization member"); console.error("Error deleting organization member:", error); @@ -187,7 +170,7 @@ const OrganizationInfoView: React.FC = ({ NotificationsManager.success("Organization settings updated successfully"); setIsEditing(false); - fetchOrgInfo(); + queryClient.invalidateQueries({ queryKey: organizationKeys.all }); } catch (error) { NotificationsManager.fromBackend("Failed to update organization settings"); console.error("Error updating organization:", error); diff --git a/ui/litellm-dashboard/src/components/policies/index.test.tsx b/ui/litellm-dashboard/src/components/policies/index.test.tsx new file mode 100644 index 0000000000..4a33e8905a --- /dev/null +++ b/ui/litellm-dashboard/src/components/policies/index.test.tsx @@ -0,0 +1,220 @@ +import React from "react"; +import { screen, waitFor, within } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { renderWithProviders } from "../../../tests/test-utils"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import PoliciesPanel from "./index"; + +/** + * Ant Design's static Modal.confirm often does not run onOk in the real app (React 18+). + * In jsdom it may still run; we mock confirm as a no-op so the test fails until the panel + * uses a controlled DeleteResourceModal instead of Modal.confirm. + */ +vi.mock("antd", async (importOriginal) => { + const mod = await importOriginal(); + return { + ...mod, + Modal: Object.assign(mod.Modal, { + confirm: vi.fn(), + }), + }; +}); + +const EXPECTED_ATTACHMENT_ID = "att-11111111-2222-3333-4444-555555555555" as const; + +const networkingMocks = vi.hoisted(() => ({ + deletePolicyAttachmentCall: vi.fn().mockResolvedValue(undefined), + getPoliciesList: vi.fn().mockResolvedValue({ policies: [] }), + getPolicyAttachmentsList: vi.fn().mockResolvedValue({ + attachments: [ + { + attachment_id: "att-11111111-2222-3333-4444-555555555555", + policy_name: "test-policy", + scope: null, + teams: [], + keys: [], + models: [], + tags: [], + }, + ], + }), + getGuardrailsList: vi.fn().mockResolvedValue({ guardrails: [] }), + getPolicyInfo: vi.fn().mockResolvedValue({}), + deletePolicyCall: vi.fn().mockResolvedValue(undefined), + createPolicyCall: vi.fn(), + updatePolicyCall: vi.fn(), + createPolicyAttachmentCall: vi.fn(), + createGuardrailCall: vi.fn(), + enrichPolicyTemplate: vi.fn(), +})); + +vi.mock("../networking", () => ({ + ...networkingMocks, +})); + +vi.mock("./impact_popover", () => ({ + default: () =>