mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 16:48:54 +00:00
fix(proxy): strip LiteLLM policy tracking from OpenAI batch metadata (#28425)
* fix(proxy): strip LiteLLM policy tracking from OpenAI batch metadata Batch create was failing with `Invalid type for 'metadata.applied_policies': expected a string, but got an array instead` whenever a policy attachment matched the request. The policy engine helpers wrote `applied_policies`, `applied_guardrails`, and `policy_sources` into `data["metadata"]` unconditionally, and `/v1/batches` forwarded that dict straight to OpenAI, which only accepts string values. - Route proxy-internal tracking into `litellm_metadata` for batch/file routes via a shared `_get_or_create_proxy_metadata_bucket` helper. - Sanitize `data["metadata"]` in `create_batch` to drop known internal keys and non-string values before building the OpenAI request. - Cover both behaviors with unit + endpoint tests. Co-authored-by: Cursor <cursoragent@cursor.com> * fix(proxy): merge metadata buckets for batch policy response headers Ensure get_logging_caching_headers reads both metadata and litellm_metadata so policy/guardrail headers are emitted on batch routes with user metadata, and log dropped non-string OpenAI metadata at debug level. Co-authored-by: Cursor <cursoragent@cursor.com> --------- Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -15,6 +15,9 @@ from litellm.batches.main import CancelBatchRequest, RetrieveBatchRequest
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
sanitize_openai_provider_metadata,
|
||||
)
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.proxy.common_utils.openai_endpoint_utils import (
|
||||
get_custom_llm_provider_from_request_headers,
|
||||
@@ -120,6 +123,9 @@ async def create_batch( # noqa: PLR0915
|
||||
or get_custom_llm_provider_from_request_headers(request=request)
|
||||
or "openai"
|
||||
)
|
||||
if isinstance(data.get("metadata"), dict):
|
||||
data["metadata"] = sanitize_openai_provider_metadata(data["metadata"])
|
||||
|
||||
_create_batch_data = LiteLLMBatchCreateRequest(**data)
|
||||
|
||||
# Apply team-level batch output expiry enforcement
|
||||
|
||||
@@ -409,11 +409,15 @@ def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str,
|
||||
|
||||
|
||||
def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]:
|
||||
_metadata = request_data.get("metadata", None)
|
||||
if not _metadata:
|
||||
_metadata = request_data.get("litellm_metadata", None)
|
||||
if not isinstance(_metadata, dict):
|
||||
_metadata = {}
|
||||
_metadata: Dict = {}
|
||||
metadata_bucket = request_data.get("metadata")
|
||||
litellm_metadata_bucket = request_data.get("litellm_metadata")
|
||||
if isinstance(metadata_bucket, dict):
|
||||
_metadata.update(metadata_bucket)
|
||||
if isinstance(litellm_metadata_bucket, dict):
|
||||
# Batch/file routes store proxy tracking in litellm_metadata while
|
||||
# user-facing metadata stays in metadata; merge both for headers.
|
||||
_metadata.update(litellm_metadata_bucket)
|
||||
headers = {}
|
||||
if "applied_guardrails" in _metadata:
|
||||
headers["x-litellm-applied-guardrails"] = ",".join(
|
||||
@@ -452,19 +456,103 @@ def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]:
|
||||
return headers
|
||||
|
||||
|
||||
def get_metadata_variable_name_from_kwargs(
    kwargs: dict,
) -> Literal["metadata", "litellm_metadata"]:
    """
    Return the name of the request-data field that holds LiteLLM's metadata.

    - New endpoints return `litellm_metadata`
    - Old endpoints return `metadata`

    Context:
    - LiteLLM used `metadata` as an internal field for storing metadata
    - OpenAI then started using this field for their metadata
    - LiteLLM is now moving to using `litellm_metadata` for our metadata

    The decision is based purely on key presence: if `litellm_metadata` is a
    key of ``kwargs`` (whatever its value), that name is returned; otherwise
    `metadata` is returned.
    """
    if "litellm_metadata" in kwargs:
        return "litellm_metadata"
    return "metadata"
|
||||
|
||||
|
||||
# Metadata keys that sanitize_openai_provider_metadata always drops, regardless
# of the value's type, so they are never forwarded in provider `metadata`.
# A frozenset gives O(1) membership checks and prevents accidental mutation.
LITELLM_PROXY_INTERNAL_METADATA_KEYS = frozenset(
    {
        "applied_policies",
        "applied_guardrails",
        "policy_sources",
        "guardrails",
        "guardrail_config",
        "_guardrail_pipelines",
        "_pipeline_managed_guardrails",
        "disable_global_guardrails",
        "disable_global_guardrail",
        "opted_out_global_guardrails",
        "pillar_response_headers",
        "_pillar_response_headers_trusted",
        "pillar_flagged",
        "pillar_scanners",
        "pillar_evidence",
        "pillar_evidence_truncated",
        "pillar_session_id_response",
        "standard_logging_object",
        "proxy_server_request",
        "secret_fields",
    }
)
|
||||
|
||||
|
||||
def _get_or_create_proxy_metadata_bucket(
    request_data: Dict,
) -> tuple[Literal["metadata", "litellm_metadata"], dict]:
    """
    Return the proxy-internal metadata bucket for this request.

    Batch/file routes store proxy state in ``litellm_metadata`` so the OpenAI
    ``metadata`` field can remain provider-safe (string values only).

    Args:
        request_data: The mutable request payload. If the selected bucket is
            missing or is not a dict, a fresh empty dict is stored under the
            selected key.

    Returns:
        A ``(key, bucket)`` pair: the field name chosen by
        ``get_metadata_variable_name_from_kwargs`` and the dict stored under
        that name in ``request_data``. The returned dict is the same object
        held by ``request_data``, so mutating it updates the request in place.
    """
    metadata_key = get_metadata_variable_name_from_kwargs(request_data)
    metadata_bucket = request_data.get(metadata_key)
    if not isinstance(metadata_bucket, dict):
        # Absent (or non-dict, e.g. None) bucket: create one and attach it so
        # callers can write into it without having to store it back themselves.
        metadata_bucket = {}
        request_data[metadata_key] = metadata_bucket
    return metadata_key, metadata_bucket
|
||||
|
||||
|
||||
def sanitize_openai_provider_metadata(
    metadata: Optional[Dict[str, Any]],
) -> Optional[Dict[str, str]]:
    """
    Keep only provider-safe OpenAI metadata entries (string keys -> string values).

    Strips LiteLLM proxy-internal tracking fields that must not be forwarded to
    OpenAI batch/file APIs.

    Args:
        metadata: The user/provider-facing metadata dict, or ``None``.

    Returns:
        - The input unchanged when it is falsy (``None`` or an empty dict).
        - A new dict containing only entries whose key is not in
          ``LITELLM_PROXY_INTERNAL_METADATA_KEYS`` and whose value is a ``str``.
        - ``None`` when no entry survives the filtering.

    The input dict is never mutated.
    """
    if not metadata:
        return metadata
    sanitized: Dict[str, str] = {}
    for key, value in metadata.items():
        # Internal tracking keys are dropped silently, whatever their value.
        if key in LITELLM_PROXY_INTERNAL_METADATA_KEYS:
            continue
        if isinstance(value, str):
            sanitized[key] = value
        else:
            # Non-string values are dropped too; log so the drop is traceable.
            verbose_proxy_logger.debug(
                "sanitize_openai_provider_metadata: dropping key %r with non-string value of type %s",
                key,
                type(value).__name__,
            )
    return sanitized or None
|
||||
|
||||
|
||||
def add_guardrail_to_applied_guardrails_header(
    request_data: Dict, guardrail_name: Optional[str]
):
    """
    Record ``guardrail_name`` in the request's ``applied_guardrails`` list.

    The list lives in the proxy-internal metadata bucket returned by
    ``_get_or_create_proxy_metadata_bucket`` (``litellm_metadata`` when that
    key is present in ``request_data``, otherwise ``metadata``).

    Args:
        request_data: The mutable request payload; updated in place.
        guardrail_name: Name to record. ``None`` is a no-op. A name already
            present in the list is not added again.
    """
    if guardrail_name is None:
        return
    # The helper attaches the bucket to request_data itself, so no write-back
    # is needed. In particular, never assign the bucket to
    # request_data["metadata"] here: when the bucket is litellm_metadata that
    # would copy proxy tracking into the provider-facing metadata field.
    _, _metadata = _get_or_create_proxy_metadata_bucket(request_data)
    if "applied_guardrails" in _metadata:
        if guardrail_name not in _metadata["applied_guardrails"]:
            _metadata["applied_guardrails"].append(guardrail_name)
    else:
        _metadata["applied_guardrails"] = [guardrail_name]
|
||||
|
||||
|
||||
def add_policy_to_applied_policies_header(
|
||||
@@ -478,14 +566,12 @@ def add_policy_to_applied_policies_header(
|
||||
"""
|
||||
if policy_name is None:
|
||||
return
|
||||
_metadata = request_data.get("metadata", None) or {}
|
||||
_, _metadata = _get_or_create_proxy_metadata_bucket(request_data)
|
||||
if "applied_policies" in _metadata:
|
||||
if policy_name not in _metadata["applied_policies"]:
|
||||
_metadata["applied_policies"].append(policy_name)
|
||||
else:
|
||||
_metadata["applied_policies"] = [policy_name]
|
||||
# Ensure metadata is set back to request_data (important when metadata didn't exist)
|
||||
request_data["metadata"] = _metadata
|
||||
|
||||
|
||||
def add_policy_sources_to_metadata(request_data: Dict, policy_sources: Dict[str, str]):
|
||||
@@ -498,13 +584,12 @@ def add_policy_sources_to_metadata(request_data: Dict, policy_sources: Dict[str,
|
||||
"""
|
||||
if not policy_sources:
|
||||
return
|
||||
_metadata = request_data.get("metadata", None) or {}
|
||||
_, _metadata = _get_or_create_proxy_metadata_bucket(request_data)
|
||||
existing = _metadata.get("policy_sources", {})
|
||||
if not isinstance(existing, dict):
|
||||
existing = {}
|
||||
existing.update(policy_sources)
|
||||
_metadata["policy_sources"] = existing
|
||||
request_data["metadata"] = _metadata
|
||||
|
||||
|
||||
def add_guardrail_response_to_standard_logging_object(
|
||||
@@ -527,23 +612,6 @@ def add_guardrail_response_to_standard_logging_object(
|
||||
return standard_logging_object
|
||||
|
||||
|
||||
def get_metadata_variable_name_from_kwargs(
|
||||
kwargs: dict,
|
||||
) -> Literal["metadata", "litellm_metadata"]:
|
||||
"""
|
||||
Helper to return what the "metadata" field should be called in the request data
|
||||
|
||||
- New endpoints return `litellm_metadata`
|
||||
- Old endpoints return `metadata`
|
||||
|
||||
Context:
|
||||
- LiteLLM used `metadata` as an internal field for storing metadata
|
||||
- OpenAI then started using this field for their metadata
|
||||
- LiteLLM is now moving to using `litellm_metadata` for our metadata
|
||||
"""
|
||||
return "litellm_metadata" if "litellm_metadata" in kwargs else "metadata"
|
||||
|
||||
|
||||
def process_callback(
|
||||
_callback: str, callback_type: str, environment_variables: dict
|
||||
) -> dict:
|
||||
|
||||
@@ -8,11 +8,14 @@ sys.path.insert(
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
add_policy_to_applied_policies_header,
|
||||
decrypt_callback_vars,
|
||||
encrypt_callback_vars,
|
||||
get_logging_caching_headers,
|
||||
initialize_callbacks_on_proxy,
|
||||
get_remaining_tokens_and_requests_from_request_data,
|
||||
normalize_callback_names,
|
||||
sanitize_openai_provider_metadata,
|
||||
)
|
||||
import litellm
|
||||
|
||||
@@ -92,6 +95,50 @@ def test_normalize_callback_names_lowercases_strings():
|
||||
]
|
||||
|
||||
|
||||
def test_add_policy_to_applied_policies_header_uses_litellm_metadata_bucket():
    """Policy tracking goes to `litellm_metadata` when that key is present."""
    request_data = {
        "input_file_id": "file-abc123",
        "litellm_metadata": {},
    }

    add_policy_to_applied_policies_header(
        request_data=request_data, policy_name="global-baseline"
    )

    assert request_data["litellm_metadata"]["applied_policies"] == ["global-baseline"]
    # The provider-facing bucket must not receive the tracking key.
    assert "applied_policies" not in request_data.get("metadata", {})
|
||||
|
||||
|
||||
def test_sanitize_openai_provider_metadata_strips_internal_tracking_fields():
    """Internal tracking keys and non-string values are both dropped."""
    metadata = {
        "customer_id": "cust-123",
        "applied_policies": ["global-baseline"],
        "applied_guardrails": ["pii_blocker"],
        "note": 42,  # not an internal key, but not a string either
    }

    sanitized = sanitize_openai_provider_metadata(metadata)

    assert sanitized == {"customer_id": "cust-123"}
|
||||
|
||||
|
||||
def test_get_logging_caching_headers_merges_metadata_and_litellm_metadata():
    """Headers are emitted from `litellm_metadata` even when `metadata` is set."""
    request_data = {
        "metadata": {"customer_id": "cust-123"},
        "litellm_metadata": {
            "applied_policies": ["global-baseline"],
            "applied_guardrails": ["pii_blocker"],
            "policy_sources": {"global-baseline": "team_default"},
        },
    }

    headers = get_logging_caching_headers(request_data)

    assert headers["x-litellm-applied-policies"] == "global-baseline"
    assert headers["x-litellm-applied-guardrails"] == "pii_blocker"
    assert headers["x-litellm-policy-sources"] == "global-baseline=team_default"
|
||||
|
||||
|
||||
def test_initialize_callbacks_on_proxy_instantiates_compression_interception(
|
||||
monkeypatch,
|
||||
):
|
||||
|
||||
@@ -178,6 +178,77 @@ class TestBatchEndpointTeamOverride:
|
||||
assert kwargs["output_expires_after"] == TEAM_EXPIRY
|
||||
|
||||
|
||||
class TestBatchEndpointPolicyMetadata:
    """Batch create must not forward LiteLLM policy tracking via OpenAI metadata."""

    def test_create_batch_does_not_forward_applied_policies_metadata(
        self, monkeypatch, llm_router
    ):
        """POST /v1/batches with a matching policy keeps tracking out of `metadata`."""
        from litellm.proxy.policy_engine.attachment_registry import (
            get_attachment_registry,
        )
        from litellm.proxy.policy_engine.policy_registry import get_policy_registry
        from litellm.types.proxy.policy_engine import (
            Policy,
            PolicyAttachment,
            PolicyGuardrails,
        )

        policy_registry = get_policy_registry()
        attachment_registry = get_attachment_registry()

        captured_kwargs = {}

        async def mock_acreate_batch(**kwargs):
            # Record what the endpoint would have sent to the provider.
            captured_kwargs.update(kwargs)
            return _make_batch_response()

        # All registry mutation and proxy setup happens inside the try so the
        # cleanup in `finally` runs even if setup itself raises.
        try:
            policy_registry._policies = {
                "global-baseline": Policy(
                    guardrails=PolicyGuardrails(add=["pii_blocker"]),
                ),
            }
            policy_registry._initialized = True

            attachment_registry._attachments = [
                PolicyAttachment(policy="global-baseline", scope="*"),
            ]
            attachment_registry._initialized = True

            _setup_proxy(monkeypatch, llm_router)

            user_key = UserAPIKeyAuth(
                api_key="test-key",
                team_alias="batch-team",
                key_alias="batch-key",
            )
            app.dependency_overrides[user_api_key_auth] = lambda: user_key

            monkeypatch.setattr(litellm, "acreate_batch", mock_acreate_batch)

            response = client.post(
                "/v1/batches",
                json={
                    "input_file_id": "file-abc123",
                    "endpoint": "/v1/chat/completions",
                    "completion_window": "24h",
                },
                headers={"Authorization": "Bearer test-key"},
            )
            assert response.status_code == 200
        finally:
            app.dependency_overrides.clear()
            policy_registry._policies = {}
            policy_registry._initialized = False
            attachment_registry._attachments = []
            attachment_registry._initialized = False

        # Provider-facing metadata stays empty; tracking lives in litellm_metadata.
        assert captured_kwargs.get("metadata") in (None, {})
        assert (
            "global-baseline" in captured_kwargs["litellm_metadata"]["applied_policies"]
        )
|
||||
|
||||
|
||||
class TestBatchEndpointTeamValidation:
|
||||
"""Verify validation errors for malformed team metadata on batch endpoint."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user