diff --git a/litellm/proxy/batches_endpoints/endpoints.py b/litellm/proxy/batches_endpoints/endpoints.py index 166ef7a66d..8516570995 100644 --- a/litellm/proxy/batches_endpoints/endpoints.py +++ b/litellm/proxy/batches_endpoints/endpoints.py @@ -15,6 +15,9 @@ from litellm.batches.main import CancelBatchRequest, RetrieveBatchRequest from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing +from litellm.proxy.common_utils.callback_utils import ( + sanitize_openai_provider_metadata, +) from litellm.proxy.common_utils.http_parsing_utils import _read_request_body from litellm.proxy.common_utils.openai_endpoint_utils import ( get_custom_llm_provider_from_request_headers, @@ -120,6 +123,9 @@ async def create_batch( # noqa: PLR0915 or get_custom_llm_provider_from_request_headers(request=request) or "openai" ) + if isinstance(data.get("metadata"), dict): + data["metadata"] = sanitize_openai_provider_metadata(data["metadata"]) + _create_batch_data = LiteLLMBatchCreateRequest(**data) # Apply team-level batch output expiry enforcement diff --git a/litellm/proxy/common_utils/callback_utils.py b/litellm/proxy/common_utils/callback_utils.py index 4995752d44..c5b97db07a 100644 --- a/litellm/proxy/common_utils/callback_utils.py +++ b/litellm/proxy/common_utils/callback_utils.py @@ -409,11 +409,15 @@ def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str, def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]: - _metadata = request_data.get("metadata", None) - if not _metadata: - _metadata = request_data.get("litellm_metadata", None) - if not isinstance(_metadata, dict): - _metadata = {} + _metadata: Dict = {} + metadata_bucket = request_data.get("metadata") + litellm_metadata_bucket = request_data.get("litellm_metadata") + if isinstance(metadata_bucket, dict): + _metadata.update(metadata_bucket) + if isinstance(litellm_metadata_bucket, dict): + # Batch/file routes store proxy tracking in litellm_metadata while + # user-facing metadata stays in metadata; merge both for headers. + _metadata.update(litellm_metadata_bucket) headers = {} if "applied_guardrails" in _metadata: headers["x-litellm-applied-guardrails"] = ",".join( @@ -452,19 +456,103 @@ def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]: return headers +def get_metadata_variable_name_from_kwargs( + kwargs: dict, +) -> Literal["metadata", "litellm_metadata"]: + """ + Helper to return what the "metadata" field should be called in the request data + + - New endpoints return `litellm_metadata` + - Old endpoints return `metadata` + + Context: + - LiteLLM used `metadata` as an internal field for storing metadata + - OpenAI then started using this field for their metadata + - LiteLLM is now moving to using `litellm_metadata` for our metadata + """ + return "litellm_metadata" if "litellm_metadata" in kwargs else "metadata" + + +LITELLM_PROXY_INTERNAL_METADATA_KEYS = frozenset( + { + "applied_policies", + "applied_guardrails", + "policy_sources", + "guardrails", + "guardrail_config", + "_guardrail_pipelines", + "_pipeline_managed_guardrails", + "disable_global_guardrails", + "disable_global_guardrail", + "opted_out_global_guardrails", + "pillar_response_headers", + "_pillar_response_headers_trusted", + "pillar_flagged", + "pillar_scanners", + "pillar_evidence", + "pillar_evidence_truncated", + "pillar_session_id_response", + "standard_logging_object", + "proxy_server_request", + "secret_fields", + } +) + + +def _get_or_create_proxy_metadata_bucket( + request_data: Dict, +) -> tuple[Literal["metadata", "litellm_metadata"], dict]: + """ + Return the proxy-internal metadata bucket for this request. + + Batch/file routes store proxy state in ``litellm_metadata`` so the OpenAI + ``metadata`` field can remain provider-safe (string values only). + """ + metadata_key = get_metadata_variable_name_from_kwargs(request_data) + metadata_bucket = request_data.get(metadata_key) + if not isinstance(metadata_bucket, dict): + metadata_bucket = {} + request_data[metadata_key] = metadata_bucket + return metadata_key, metadata_bucket + + +def sanitize_openai_provider_metadata( + metadata: Optional[Dict[str, Any]], +) -> Optional[Dict[str, str]]: + """ + Keep only provider-safe OpenAI metadata entries (string keys -> string values). + + Strips LiteLLM proxy-internal tracking fields that must not be forwarded to + OpenAI batch/file APIs. + """ + if not metadata: + return metadata + sanitized: Dict[str, str] = {} + for key, value in metadata.items(): + if key in LITELLM_PROXY_INTERNAL_METADATA_KEYS: + continue + if isinstance(value, str): + sanitized[key] = value + else: + verbose_proxy_logger.debug( + "sanitize_openai_provider_metadata: dropping key %r with non-string value of type %s", + key, + type(value).__name__, + ) + return sanitized or None + + def add_guardrail_to_applied_guardrails_header( request_data: Dict, guardrail_name: Optional[str] ): if guardrail_name is None: return - _metadata = request_data.get("metadata", None) or {} + _, _metadata = _get_or_create_proxy_metadata_bucket(request_data) if "applied_guardrails" in _metadata: if guardrail_name not in _metadata["applied_guardrails"]: _metadata["applied_guardrails"].append(guardrail_name) else: _metadata["applied_guardrails"] = [guardrail_name] - # Ensure metadata is set back to request_data (important when metadata didn't exist) - request_data["metadata"] = _metadata def add_policy_to_applied_policies_header( @@ -478,14 +566,12 @@ def add_policy_to_applied_policies_header( """ if policy_name is None: return - _metadata = request_data.get("metadata", None) or {} + _, _metadata = _get_or_create_proxy_metadata_bucket(request_data) if "applied_policies" in _metadata: if policy_name not in _metadata["applied_policies"]: _metadata["applied_policies"].append(policy_name) else: _metadata["applied_policies"] = [policy_name] - # Ensure metadata is set back to request_data (important when metadata didn't exist) - request_data["metadata"] = _metadata def add_policy_sources_to_metadata(request_data: Dict, policy_sources: Dict[str, str]): @@ -498,13 +584,12 @@ def add_policy_sources_to_metadata(request_data: Dict, policy_sources: Dict[str, """ if not policy_sources: return - _metadata = request_data.get("metadata", None) or {} + _, _metadata = _get_or_create_proxy_metadata_bucket(request_data) existing = _metadata.get("policy_sources", {}) if not isinstance(existing, dict): existing = {} existing.update(policy_sources) _metadata["policy_sources"] = existing - request_data["metadata"] = _metadata def add_guardrail_response_to_standard_logging_object( @@ -527,23 +612,6 @@ def add_guardrail_response_to_standard_logging_object( return standard_logging_object -def get_metadata_variable_name_from_kwargs( - kwargs: dict, -) -> Literal["metadata", "litellm_metadata"]: - """ - Helper to return what the "metadata" field should be called in the request data - - - New endpoints return `litellm_metadata` - - Old endpoints return `metadata` - - Context: - - LiteLLM used `metadata` as an internal field for storing metadata - - OpenAI then started using this field for their metadata - - LiteLLM is now moving to using `litellm_metadata` for our metadata - """ - return "litellm_metadata" if "litellm_metadata" in kwargs else "metadata" - - def process_callback( _callback: str, callback_type: str, environment_variables: dict ) -> dict: diff --git a/tests/test_litellm/proxy/common_utils/test_callback_utils.py b/tests/test_litellm/proxy/common_utils/test_callback_utils.py index cb30970a34..d328d68dcd 100644 --- a/tests/test_litellm/proxy/common_utils/test_callback_utils.py +++ b/tests/test_litellm/proxy/common_utils/test_callback_utils.py @@ -8,11 +8,14 @@ sys.path.insert( ) # Adds the parent directory to the system path from litellm.proxy.common_utils.callback_utils import ( + add_policy_to_applied_policies_header, decrypt_callback_vars, encrypt_callback_vars, + get_logging_caching_headers, initialize_callbacks_on_proxy, get_remaining_tokens_and_requests_from_request_data, normalize_callback_names, + sanitize_openai_provider_metadata, ) import litellm @@ -92,6 +95,50 @@ def test_normalize_callback_names_lowercases_strings(): ] +def test_add_policy_to_applied_policies_header_uses_litellm_metadata_bucket(): + request_data = { + "input_file_id": "file-abc123", + "litellm_metadata": {}, + } + + add_policy_to_applied_policies_header( + request_data=request_data, policy_name="global-baseline" + ) + + assert request_data["litellm_metadata"]["applied_policies"] == ["global-baseline"] + assert "applied_policies" not in request_data.get("metadata", {}) + + +def test_sanitize_openai_provider_metadata_strips_internal_tracking_fields(): + metadata = { + "customer_id": "cust-123", + "applied_policies": ["global-baseline"], + "applied_guardrails": ["pii_blocker"], + "note": 42, + } + + sanitized = sanitize_openai_provider_metadata(metadata) + + assert sanitized == {"customer_id": "cust-123"} + + +def test_get_logging_caching_headers_merges_metadata_and_litellm_metadata(): + request_data = { + "metadata": {"customer_id": "cust-123"}, + "litellm_metadata": { + "applied_policies": ["global-baseline"], + "applied_guardrails": ["pii_blocker"], + "policy_sources": {"global-baseline": "team_default"}, + }, + } + + headers = get_logging_caching_headers(request_data) + + assert headers["x-litellm-applied-policies"] == "global-baseline" + assert headers["x-litellm-applied-guardrails"] == "pii_blocker" + assert headers["x-litellm-policy-sources"] == "global-baseline=team_default" + + def test_initialize_callbacks_on_proxy_instantiates_compression_interception( monkeypatch, ): diff --git a/tests/test_litellm/proxy/test_batch_expiry.py b/tests/test_litellm/proxy/test_batch_expiry.py index d63f278e71..38c4a71608 100644 --- a/tests/test_litellm/proxy/test_batch_expiry.py +++ b/tests/test_litellm/proxy/test_batch_expiry.py @@ -178,6 +178,77 @@ class TestBatchEndpointTeamOverride: assert kwargs["output_expires_after"] == TEAM_EXPIRY +class TestBatchEndpointPolicyMetadata: + """Batch create must not forward LiteLLM policy tracking via OpenAI metadata.""" + + def test_create_batch_does_not_forward_applied_policies_metadata( + self, monkeypatch, llm_router + ): + from litellm.proxy.policy_engine.attachment_registry import ( + get_attachment_registry, + ) + from litellm.proxy.policy_engine.policy_registry import get_policy_registry + from litellm.types.proxy.policy_engine import ( + Policy, + PolicyAttachment, + PolicyGuardrails, + ) + + policy_registry = get_policy_registry() + policy_registry._policies = { + "global-baseline": Policy( + guardrails=PolicyGuardrails(add=["pii_blocker"]), + ), + } + policy_registry._initialized = True + + attachment_registry = get_attachment_registry() + attachment_registry._attachments = [ + PolicyAttachment(policy="global-baseline", scope="*"), + ] + attachment_registry._initialized = True + + _setup_proxy(monkeypatch, llm_router) + + user_key = UserAPIKeyAuth( + api_key="test-key", + team_alias="batch-team", + key_alias="batch-key", + ) + app.dependency_overrides[user_api_key_auth] = lambda: user_key + + captured_kwargs = {} + + async def mock_acreate_batch(**kwargs): + captured_kwargs.update(kwargs) + return _make_batch_response() + + monkeypatch.setattr(litellm, "acreate_batch", mock_acreate_batch) + + try: + response = client.post( + "/v1/batches", + json={ + "input_file_id": "file-abc123", + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + }, + headers={"Authorization": "Bearer test-key"}, + ) + assert response.status_code == 200 + finally: + app.dependency_overrides.clear() + policy_registry._policies = {} + policy_registry._initialized = False + attachment_registry._attachments = [] + attachment_registry._initialized = False + + assert captured_kwargs.get("metadata") in (None, {}) + assert ( + "global-baseline" in captured_kwargs["litellm_metadata"]["applied_policies"] + ) + + class TestBatchEndpointTeamValidation: """Verify validation errors for malformed team metadata on batch endpoint."""