fix(proxy): strip LiteLLM policy tracking from OpenAI batch metadata (#28425)

* fix(proxy): strip LiteLLM policy tracking from OpenAI batch metadata

Batch create was failing with `Invalid type for 'metadata.applied_policies':
expected a string, but got an array instead` whenever a policy attachment
matched the request. The policy engine helpers wrote `applied_policies`,
`applied_guardrails`, and `policy_sources` into `data["metadata"]`
unconditionally, and `/v1/batches` forwarded that dict straight to OpenAI,
which only accepts string values.

- Route proxy-internal tracking into `litellm_metadata` for batch/file
  routes via a shared `_get_or_create_proxy_metadata_bucket` helper.
- Sanitize `data["metadata"]` in `create_batch` to drop known internal
  keys and non-string values before building the OpenAI request.
- Cover both behaviors with unit + endpoint tests.

Co-authored-by: Cursor <cursoragent@cursor.com>

* fix(proxy): merge metadata buckets for batch policy response headers

Ensure `get_logging_caching_headers` reads both `metadata` and `litellm_metadata`, so policy/guardrail response headers are still emitted on batch routes when the request also carries user-supplied `metadata`. Also log any non-string OpenAI metadata values that are dropped, at debug level.

Co-authored-by: Cursor <cursoragent@cursor.com>

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Shivam Rawat
2026-05-26 11:35:42 -07:00
committed by GitHub
parent 533eab4dbd
commit fbff60e9d9
4 changed files with 222 additions and 30 deletions
@@ -15,6 +15,9 @@ from litellm.batches.main import CancelBatchRequest, RetrieveBatchRequest
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.proxy.common_utils.callback_utils import (
sanitize_openai_provider_metadata,
)
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.common_utils.openai_endpoint_utils import (
get_custom_llm_provider_from_request_headers,
@@ -120,6 +123,9 @@ async def create_batch( # noqa: PLR0915
or get_custom_llm_provider_from_request_headers(request=request)
or "openai"
)
if isinstance(data.get("metadata"), dict):
data["metadata"] = sanitize_openai_provider_metadata(data["metadata"])
_create_batch_data = LiteLLMBatchCreateRequest(**data)
# Apply team-level batch output expiry enforcement
+98 -30
View File
@@ -409,11 +409,15 @@ def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str,
def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]:
_metadata = request_data.get("metadata", None)
if not _metadata:
_metadata = request_data.get("litellm_metadata", None)
if not isinstance(_metadata, dict):
_metadata = {}
_metadata: Dict = {}
metadata_bucket = request_data.get("metadata")
litellm_metadata_bucket = request_data.get("litellm_metadata")
if isinstance(metadata_bucket, dict):
_metadata.update(metadata_bucket)
if isinstance(litellm_metadata_bucket, dict):
# Batch/file routes store proxy tracking in litellm_metadata while
# user-facing metadata stays in metadata; merge both for headers.
_metadata.update(litellm_metadata_bucket)
headers = {}
if "applied_guardrails" in _metadata:
headers["x-litellm-applied-guardrails"] = ",".join(
@@ -452,19 +456,103 @@ def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]:
return headers
def get_metadata_variable_name_from_kwargs(
    kwargs: dict,
) -> Literal["metadata", "litellm_metadata"]:
    """
    Pick the request-data key under which LiteLLM keeps its own metadata.

    Returns ``"litellm_metadata"`` whenever that key is present in ``kwargs``
    (regardless of its value) and ``"metadata"`` otherwise.

    Context:
    - LiteLLM used `metadata` as an internal field for storing metadata
    - OpenAI then started using this field for their metadata
    - LiteLLM is now moving to using `litellm_metadata` for our metadata
    """
    if "litellm_metadata" in kwargs:
        return "litellm_metadata"
    return "metadata"
# Metadata keys treated as proxy-internal. sanitize_openai_provider_metadata
# skips every key listed here, so they are never forwarded in the OpenAI
# ``metadata`` field.
LITELLM_PROXY_INTERNAL_METADATA_KEYS = frozenset(
    (
        # Policy / guardrail tracking.
        "applied_policies",
        "applied_guardrails",
        "policy_sources",
        # Guardrail configuration and pipeline state.
        "guardrails",
        "guardrail_config",
        "_guardrail_pipelines",
        "_pipeline_managed_guardrails",
        "disable_global_guardrails",
        "disable_global_guardrail",
        "opted_out_global_guardrails",
        # ``pillar``-prefixed keys.
        "pillar_response_headers",
        "_pillar_response_headers_trusted",
        "pillar_flagged",
        "pillar_scanners",
        "pillar_evidence",
        "pillar_evidence_truncated",
        "pillar_session_id_response",
        # Logging / request internals.
        "standard_logging_object",
        "proxy_server_request",
        "secret_fields",
    )
)
def _get_or_create_proxy_metadata_bucket(
    request_data: Dict,
) -> tuple[Literal["metadata", "litellm_metadata"], dict]:
    """
    Fetch (creating it if needed) the dict that holds proxy-internal state.

    The key is chosen by ``get_metadata_variable_name_from_kwargs``. Batch/file
    routes store proxy state in ``litellm_metadata`` so the OpenAI ``metadata``
    field can remain provider-safe (string values only).

    Returns:
        ``(key, bucket)``. ``bucket`` is the very dict stored at
        ``request_data[key]``, so mutating it updates the request in place.
        When the stored value is missing or not a dict, a fresh empty dict is
        written to ``request_data[key]`` first.
    """
    bucket_name = get_metadata_variable_name_from_kwargs(request_data)
    current = request_data.get(bucket_name)
    if isinstance(current, dict):
        return bucket_name, current

    # Missing or non-dict value: install a new bucket on the request.
    fresh_bucket: dict = {}
    request_data[bucket_name] = fresh_bucket
    return bucket_name, fresh_bucket
def sanitize_openai_provider_metadata(
    metadata: Optional[Dict[str, Any]],
) -> Optional[Dict[str, str]]:
    """
    Reduce ``metadata`` to the entries that are safe to send to OpenAI.

    An entry is kept only if its key is not listed in
    ``LITELLM_PROXY_INTERNAL_METADATA_KEYS`` and its value is a ``str``.
    Non-string values under other keys are dropped with a debug log line;
    internal keys are dropped silently.

    Returns:
        - ``metadata`` itself, unchanged, when it is falsy (``None`` or ``{}``).
        - A new dict with the surviving entries.
        - ``None`` when nothing survives.
    """
    if not metadata:
        return metadata

    kept: Dict[str, str] = {}
    for name, entry in metadata.items():
        if name in LITELLM_PROXY_INTERNAL_METADATA_KEYS:
            continue
        if not isinstance(entry, str):
            verbose_proxy_logger.debug(
                "sanitize_openai_provider_metadata: dropping key %r with non-string value of type %s",
                name,
                type(entry).__name__,
            )
            continue
        kept[name] = entry

    if kept:
        return kept
    return None
def add_guardrail_to_applied_guardrails_header(
    request_data: Dict, guardrail_name: Optional[str]
):
    """
    Record ``guardrail_name`` in the request's ``applied_guardrails`` list.

    The list lives in the bucket returned by
    ``_get_or_create_proxy_metadata_bucket``. That helper also stores the
    bucket on ``request_data``, so no write-back is needed here. In particular
    nothing is assigned to ``request_data["metadata"]`` directly: when the
    bucket is ``litellm_metadata``, doing so would copy the tracking list into
    the ``metadata`` field as well.

    Args:
        request_data: Mutable request payload; updated in place.
        guardrail_name: Name to record. ``None`` is a no-op.
    """
    if guardrail_name is None:
        return

    _, bucket = _get_or_create_proxy_metadata_bucket(request_data)
    applied = bucket.setdefault("applied_guardrails", [])
    # Keep the list free of duplicates when the same guardrail runs twice.
    if guardrail_name not in applied:
        applied.append(guardrail_name)
def add_policy_to_applied_policies_header(
@@ -478,14 +566,12 @@ def add_policy_to_applied_policies_header(
"""
if policy_name is None:
return
_metadata = request_data.get("metadata", None) or {}
_, _metadata = _get_or_create_proxy_metadata_bucket(request_data)
if "applied_policies" in _metadata:
if policy_name not in _metadata["applied_policies"]:
_metadata["applied_policies"].append(policy_name)
else:
_metadata["applied_policies"] = [policy_name]
# Ensure metadata is set back to request_data (important when metadata didn't exist)
request_data["metadata"] = _metadata
def add_policy_sources_to_metadata(request_data: Dict, policy_sources: Dict[str, str]):
@@ -498,13 +584,12 @@ def add_policy_sources_to_metadata(request_data: Dict, policy_sources: Dict[str,
"""
if not policy_sources:
return
_metadata = request_data.get("metadata", None) or {}
_, _metadata = _get_or_create_proxy_metadata_bucket(request_data)
existing = _metadata.get("policy_sources", {})
if not isinstance(existing, dict):
existing = {}
existing.update(policy_sources)
_metadata["policy_sources"] = existing
request_data["metadata"] = _metadata
def add_guardrail_response_to_standard_logging_object(
@@ -527,23 +612,6 @@ def add_guardrail_response_to_standard_logging_object(
return standard_logging_object
def get_metadata_variable_name_from_kwargs(
    kwargs: dict,
) -> Literal["metadata", "litellm_metadata"]:
    """
    Tell the caller which request-data key carries LiteLLM's metadata.

    - New endpoints return `litellm_metadata`
    - Old endpoints return `metadata`

    The decision is purely key presence: ``"litellm_metadata" in kwargs``.

    Context:
    - LiteLLM used `metadata` as an internal field for storing metadata
    - OpenAI then started using this field for their metadata
    - LiteLLM is now moving to using `litellm_metadata` for our metadata
    """
    uses_new_field = "litellm_metadata" in kwargs
    if uses_new_field:
        return "litellm_metadata"
    return "metadata"
def process_callback(
_callback: str, callback_type: str, environment_variables: dict
) -> dict:
@@ -8,11 +8,14 @@ sys.path.insert(
) # Adds the parent directory to the system path
from litellm.proxy.common_utils.callback_utils import (
add_policy_to_applied_policies_header,
decrypt_callback_vars,
encrypt_callback_vars,
get_logging_caching_headers,
initialize_callbacks_on_proxy,
get_remaining_tokens_and_requests_from_request_data,
normalize_callback_names,
sanitize_openai_provider_metadata,
)
import litellm
@@ -92,6 +95,50 @@ def test_normalize_callback_names_lowercases_strings():
]
def test_add_policy_to_applied_policies_header_uses_litellm_metadata_bucket():
    """Policy tracking is written to ``litellm_metadata`` when that key exists."""
    payload = {"input_file_id": "file-abc123", "litellm_metadata": {}}

    add_policy_to_applied_policies_header(
        request_data=payload, policy_name="global-baseline"
    )

    tracked_policies = payload["litellm_metadata"]["applied_policies"]
    assert tracked_policies == ["global-baseline"]

    # The provider-facing bucket must not pick up the tracking key.
    provider_metadata = payload.get("metadata", {})
    assert "applied_policies" not in provider_metadata
def test_sanitize_openai_provider_metadata_strips_internal_tracking_fields():
    """Internal tracking keys and non-string values are both removed."""
    raw_metadata = dict(
        customer_id="cust-123",
        applied_policies=["global-baseline"],
        applied_guardrails=["pii_blocker"],
        note=42,
    )
    expected = {"customer_id": "cust-123"}

    assert sanitize_openai_provider_metadata(raw_metadata) == expected
def test_get_logging_caching_headers_merges_metadata_and_litellm_metadata():
    """Headers reflect tracking kept in ``litellm_metadata`` alongside user ``metadata``."""
    proxy_tracking = {
        "applied_policies": ["global-baseline"],
        "applied_guardrails": ["pii_blocker"],
        "policy_sources": {"global-baseline": "team_default"},
    }
    headers = get_logging_caching_headers(
        {
            "metadata": {"customer_id": "cust-123"},
            "litellm_metadata": proxy_tracking,
        }
    )

    expected_headers = {
        "x-litellm-applied-policies": "global-baseline",
        "x-litellm-applied-guardrails": "pii_blocker",
        "x-litellm-policy-sources": "global-baseline=team_default",
    }
    for header_name, header_value in expected_headers.items():
        assert headers[header_name] == header_value
def test_initialize_callbacks_on_proxy_instantiates_compression_interception(
monkeypatch,
):
@@ -178,6 +178,77 @@ class TestBatchEndpointTeamOverride:
assert kwargs["output_expires_after"] == TEAM_EXPIRY
class TestBatchEndpointPolicyMetadata:
    """Batch create must not forward LiteLLM policy tracking via OpenAI metadata."""

    def test_create_batch_does_not_forward_applied_policies_metadata(
        self, monkeypatch, llm_router
    ):
        """
        POST /v1/batches with a matching policy attachment and no user metadata.

        Checks that the kwargs reaching ``litellm.acreate_batch`` carry no
        ``metadata`` (``None`` or ``{}``) while ``litellm_metadata`` records
        the applied policy.
        """
        from litellm.proxy.policy_engine.attachment_registry import (
            get_attachment_registry,
        )
        from litellm.proxy.policy_engine.policy_registry import get_policy_registry
        from litellm.types.proxy.policy_engine import (
            Policy,
            PolicyAttachment,
            PolicyGuardrails,
        )

        policy_registry = get_policy_registry()
        attachment_registry = get_attachment_registry()

        # Patch the registries through monkeypatch so their previous state is
        # restored at teardown even if a later setup step raises; a
        # hand-written reset after the request would be skipped in that case
        # and leave a "*"-scoped policy attached for other tests.
        monkeypatch.setattr(
            policy_registry,
            "_policies",
            {
                "global-baseline": Policy(
                    guardrails=PolicyGuardrails(add=["pii_blocker"]),
                ),
            },
        )
        monkeypatch.setattr(policy_registry, "_initialized", True)
        monkeypatch.setattr(
            attachment_registry,
            "_attachments",
            [PolicyAttachment(policy="global-baseline", scope="*")],
        )
        monkeypatch.setattr(attachment_registry, "_initialized", True)

        _setup_proxy(monkeypatch, llm_router)

        user_key = UserAPIKeyAuth(
            api_key="test-key",
            team_alias="batch-team",
            key_alias="batch-key",
        )

        captured_kwargs = {}

        async def mock_acreate_batch(**kwargs):
            # Record what the endpoint hands to the provider call.
            captured_kwargs.update(kwargs)
            return _make_batch_response()

        monkeypatch.setattr(litellm, "acreate_batch", mock_acreate_batch)

        app.dependency_overrides[user_api_key_auth] = lambda: user_key
        try:
            response = client.post(
                "/v1/batches",
                json={
                    "input_file_id": "file-abc123",
                    "endpoint": "/v1/chat/completions",
                    "completion_window": "24h",
                },
                headers={"Authorization": "Bearer test-key"},
            )
            assert response.status_code == 200
        finally:
            app.dependency_overrides.clear()

        assert captured_kwargs.get("metadata") in (None, {})
        assert (
            "global-baseline" in captured_kwargs["litellm_metadata"]["applied_policies"]
        )
class TestBatchEndpointTeamValidation:
"""Verify validation errors for malformed team metadata on batch endpoint."""