diff --git a/litellm/litellm_core_utils/redact_messages.py b/litellm/litellm_core_utils/redact_messages.py index f3f560b33b..dbc9cabdc7 100644 --- a/litellm/litellm_core_utils/redact_messages.py +++ b/litellm/litellm_core_utils/redact_messages.py @@ -60,6 +60,9 @@ def _redact_choice_content(choice): def _redact_responses_api_output(output_items): """Helper to redact ResponsesAPIResponse output items.""" for output_item in output_items: + if hasattr(output_item, "text"): + output_item.text = "redacted-by-litellm" + if hasattr(output_item, "content") and isinstance(output_item.content, list): for content_part in output_item.content: if hasattr(content_part, "text"): @@ -75,6 +78,28 @@ def _redact_responses_api_output(output_items): summary_item.text = "redacted-by-litellm" +def _redact_responses_api_output_dict(output_items, redacted_str: str): + """Helper to redact ResponsesAPIResponse output items in dict form.""" + for output_item in output_items: + if not isinstance(output_item, dict): + continue + + if "text" in output_item: + output_item["text"] = redacted_str + + if isinstance(output_item.get("content"), list): + for content_item in output_item["content"]: + if isinstance(content_item, dict) and "text" in content_item: + content_item["text"] = redacted_str + + if output_item.get("type") == "reasoning" and isinstance( + output_item.get("summary"), list + ): + for summary_item in output_item["summary"]: + if isinstance(summary_item, dict) and "text" in summary_item: + summary_item["text"] = redacted_str + + def _redact_standard_logging_object(model_call_details: dict): """Redact messages and response inside standard_logging_object if present.""" standard_logging_object = model_call_details.get("standard_logging_object") @@ -93,28 +118,11 @@ def _redact_standard_logging_object(model_call_details: dict): if isinstance(response, dict) and "output" in response: # ResponsesAPIResponse format - redact content in output items if isinstance(response.get("output"), list): - for output_item in response["output"]: - if isinstance(output_item, dict) and "content" in output_item: - if isinstance(output_item["content"], list): - for content_item in output_item["content"]: - if ( - isinstance(content_item, dict) - and "text" in content_item - ): - content_item["text"] = redacted_str + _redact_responses_api_output_dict(response["output"], redacted_str) elif isinstance(response, dict) and "choices" in response: # ModelResponse dict format - redact content in choices if isinstance(response.get("choices"), list): - for choice in response["choices"]: - if isinstance(choice, dict): - if "message" in choice and isinstance(choice["message"], dict): - choice["message"]["content"] = redacted_str - if "audio" in choice["message"]: - choice["message"]["audio"] = None - elif "delta" in choice and isinstance(choice["delta"], dict): - choice["delta"]["content"] = redacted_str - if "audio" in choice["delta"]: - choice["delta"]["audio"] = None + _redact_model_response_dict_choices(response["choices"], redacted_str) elif isinstance(response, str): standard_logging_object["response"] = redacted_str else: @@ -122,6 +130,29 @@ def _redact_standard_logging_object(model_call_details: dict): standard_logging_object["response"] = {"text": redacted_str} +def _redact_model_response_dict_choices(choices, redacted_str: str): + for choice in choices: + if isinstance(choice, dict): + if "message" in choice and isinstance(choice["message"], dict): + choice["message"]["content"] = redacted_str + if "reasoning_content" in choice["message"]: + choice["message"]["reasoning_content"] = redacted_str + if "thinking_blocks" in choice["message"]: + choice["message"]["thinking_blocks"] = None + if "audio" in choice["message"]: + choice["message"]["audio"] = None + elif "delta" in choice and isinstance(choice["delta"], dict): + choice["delta"]["content"] = redacted_str + if "reasoning_content" in choice["delta"]: + choice["delta"]["reasoning_content"] = redacted_str + if "thinking_blocks" in choice["delta"]: + choice["delta"]["thinking_blocks"] = None + if "audio" in choice["delta"]: + choice["delta"]["audio"] = None + else: + _redact_choice_content(choice) + + def perform_redaction(model_call_details: dict, result): """ Performs the actual redaction on the logging object and result. @@ -132,6 +163,7 @@ def perform_redaction(model_call_details: dict, result): ] model_call_details["prompt"] = "" model_call_details["input"] = "" + _redact_standard_logging_object(model_call_details) # Redact streaming response if ( @@ -171,30 +203,14 @@ def perform_redaction(model_call_details: dict, result): elif isinstance(_result, dict) and "choices" in _result: # Handle dict representation of ModelResponse (e.g., from model_dump()) if _result.get("choices") is not None: - for choice in _result["choices"]: - if isinstance(choice, dict): - if "message" in choice and isinstance(choice["message"], dict): - choice["message"]["content"] = "redacted-by-litellm" - if "reasoning_content" in choice["message"]: - choice["message"][ - "reasoning_content" - ] = "redacted-by-litellm" - if "thinking_blocks" in choice["message"]: - choice["message"]["thinking_blocks"] = None - if "audio" in choice["message"]: - choice["message"]["audio"] = None - elif "delta" in choice and isinstance(choice["delta"], dict): - choice["delta"]["content"] = "redacted-by-litellm" - if "reasoning_content" in choice["delta"]: - choice["delta"][ - "reasoning_content" - ] = "redacted-by-litellm" - if "thinking_blocks" in choice["delta"]: - choice["delta"]["thinking_blocks"] = None - if "audio" in choice["delta"]: - choice["delta"]["audio"] = None - else: - _redact_choice_content(choice) + _redact_model_response_dict_choices( + _result["choices"], "redacted-by-litellm" + ) + elif isinstance(_result, dict) and "output" in _result: + if isinstance(_result.get("output"), list): + _redact_responses_api_output_dict( + _result["output"], "redacted-by-litellm" + ) elif isinstance(_result, litellm.ResponsesAPIResponse): if hasattr(_result, "output"): _redact_responses_api_output(_result.output) diff --git a/litellm/proxy/common_utils/callback_utils.py b/litellm/proxy/common_utils/callback_utils.py index 7ddd722a80..198d9503cb 100644 --- a/litellm/proxy/common_utils/callback_utils.py +++ b/litellm/proxy/common_utils/callback_utils.py @@ -14,6 +14,8 @@ from litellm.types.utils import ( blue_color_code = "\033[94m" reset_color_code = "\033[0m" +TRUSTED_PILLAR_RESPONSE_HEADERS_METADATA_KEY = "_pillar_response_headers_trusted" + if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging @@ -417,10 +419,19 @@ def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]: if "semantic-similarity" in _metadata: headers["x-litellm-semantic-similarity"] = str(_metadata["semantic-similarity"]) + is_trusted_pillar_metadata = ( + _metadata.get(TRUSTED_PILLAR_RESPONSE_HEADERS_METADATA_KEY) is True + ) pillar_headers = _metadata.get("pillar_response_headers") - if isinstance(pillar_headers, dict): - headers.update(pillar_headers) - elif "pillar_flagged" in _metadata: + if is_trusted_pillar_metadata and isinstance(pillar_headers, dict): + headers.update( + { + key: str(value) + for key, value in pillar_headers.items() + if isinstance(key, str) and key.lower().startswith("x-pillar-") + } + ) + elif is_trusted_pillar_metadata and "pillar_flagged" in _metadata: headers["x-pillar-flagged"] = str(_metadata["pillar_flagged"]).lower() return headers diff --git a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py index e5df26b7c2..bb1db3d62d 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py +++ b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py @@ -71,6 +71,7 @@ from litellm.types.utils import ( ) GUARDRAIL_NAME = "bedrock" +_BEDROCK_DYNAMIC_BODY_DENYLIST = frozenset({"content", "source"}) class GuardrailMessageFilterResult(NamedTuple): @@ -413,11 +414,18 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): ) api_key: Optional[str] = None if request_data: - bedrock_request_data.update( + dynamic_request_body_params = ( self.get_guardrail_dynamic_request_body_params( request_data=request_data ) ) + bedrock_request_data.update( + { + key: value + for key, value in dynamic_request_body_params.items() + if key not in _BEDROCK_DYNAMIC_BODY_DENYLIST + } + ) if request_data.get("api_key") is not None: api_key = request_data["api_key"] diff --git a/litellm/proxy/guardrails/guardrail_hooks/pillar/pillar.py b/litellm/proxy/guardrails/guardrail_hooks/pillar/pillar.py index 1b3f11e56f..c9c73053a0 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/pillar/pillar.py +++ b/litellm/proxy/guardrails/guardrail_hooks/pillar/pillar.py @@ -29,6 +29,7 @@ from litellm.llms.custom_httpx.http_handler import ( ) from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.common_utils.callback_utils import ( + TRUSTED_PILLAR_RESPONSE_HEADERS_METADATA_KEY, add_guardrail_to_applied_guardrails_header, get_metadata_variable_name_from_kwargs, ) @@ -144,6 +145,7 @@ def build_pillar_response_headers(metadata_store: Dict[str, Any]) -> Dict[str, s if headers: metadata_store["pillar_response_headers"] = headers + metadata_store[TRUSTED_PILLAR_RESPONSE_HEADERS_METADATA_KEY] = True return headers diff --git a/litellm/proxy/hooks/key_management_event_hooks.py b/litellm/proxy/hooks/key_management_event_hooks.py index 5cdd9ddb4b..6bb9c1f507 100644 --- a/litellm/proxy/hooks/key_management_event_hooks.py +++ b/litellm/proxy/hooks/key_management_event_hooks.py @@ -41,6 +41,7 @@ class KeyManagementEventHooks: """ from litellm.proxy.management_helpers.audit_logs import ( create_audit_log_for_update, + get_audit_log_changed_by, ) from litellm.proxy.proxy_server import litellm_proxy_admin_name @@ -61,9 +62,11 @@ class KeyManagementEventHooks: request_data=LiteLLM_AuditLogs( id=str(uuid.uuid4()), updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, + changed_by=get_audit_log_changed_by( + litellm_changed_by=litellm_changed_by, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + ), changed_by_api_key=user_api_key_dict.api_key, table_name=LitellmTableNames.KEY_TABLE_NAME, object_id=response.token_id or "", @@ -102,6 +105,7 @@ class KeyManagementEventHooks: """ from litellm.proxy.management_helpers.audit_logs import ( create_audit_log_for_update, + get_audit_log_changed_by, ) from litellm.proxy.proxy_server import litellm_proxy_admin_name @@ -117,9 +121,11 @@ class KeyManagementEventHooks: request_data=LiteLLM_AuditLogs( id=str(uuid.uuid4()), updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, + changed_by=get_audit_log_changed_by( + litellm_changed_by=litellm_changed_by, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + ), changed_by_api_key=user_api_key_dict.api_key, table_name=LitellmTableNames.KEY_TABLE_NAME, object_id=data.key, @@ -140,6 +146,7 @@ class KeyManagementEventHooks: ): from litellm.proxy.management_helpers.audit_logs import ( create_audit_log_for_update, + get_audit_log_changed_by, ) from litellm.proxy.proxy_server import litellm_proxy_admin_name @@ -189,9 +196,11 @@ class KeyManagementEventHooks: request_data=LiteLLM_AuditLogs( id=str(uuid.uuid4()), updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, + changed_by=get_audit_log_changed_by( + litellm_changed_by=litellm_changed_by, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + ), changed_by_api_key=user_api_key_dict.token, table_name=LitellmTableNames.KEY_TABLE_NAME, object_id=existing_key_row.token, @@ -220,6 +229,7 @@ class KeyManagementEventHooks: """ from litellm.proxy.management_helpers.audit_logs import ( create_audit_log_for_update, + get_audit_log_changed_by, ) from litellm.proxy.proxy_server import litellm_proxy_admin_name @@ -237,9 +247,11 @@ class KeyManagementEventHooks: request_data=LiteLLM_AuditLogs( id=str(uuid.uuid4()), updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, + changed_by=get_audit_log_changed_by( + litellm_changed_by=litellm_changed_by, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + ), changed_by_api_key=user_api_key_dict.token, table_name=LitellmTableNames.KEY_TABLE_NAME, object_id=key.token, diff --git a/litellm/proxy/hooks/user_management_event_hooks.py b/litellm/proxy/hooks/user_management_event_hooks.py index 38623f9209..08fa8d4dfa 100644 --- a/litellm/proxy/hooks/user_management_event_hooks.py +++ b/litellm/proxy/hooks/user_management_event_hooks.py @@ -192,13 +192,19 @@ class UserManagementEventHooks: if not litellm.store_audit_logs: return + from litellm.proxy.management_helpers.audit_logs import ( + get_audit_log_changed_by, + ) + await create_audit_log_for_update( request_data=LiteLLM_AuditLogs( id=str(uuid.uuid4()), updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, + changed_by=get_audit_log_changed_by( + litellm_changed_by=litellm_changed_by, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + ), changed_by_api_key=user_api_key_dict.api_key, table_name=LitellmTableNames.USER_TABLE_NAME, object_id=user_id, diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 5804e3f8d9..e72c2e40b0 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -104,6 +104,112 @@ LITELLM_METADATA_ROUTES = ( "files", ) +_UNTRUSTED_ROOT_CONTROL_FIELDS = ( + "proxy_server_request", + "standard_logging_object", + "secret_fields", + "mock_response", + "mock_tool_calls", + "disable_global_guardrails", + "disable_global_guardrail", + "opted_out_global_guardrails", + "applied_guardrails", + "applied_policies", + "policy_sources", + "pillar_response_headers", + "_guardrail_pipelines", + "_pipeline_managed_guardrails", +) + +_UNTRUSTED_METADATA_CONTROL_FIELDS = ( + "disable_global_guardrails", + "disable_global_guardrail", + "opted_out_global_guardrails", + "pillar_response_headers", + "_pillar_response_headers_trusted", + "pillar_flagged", + "pillar_scanners", + "pillar_evidence", + "pillar_evidence_truncated", + "pillar_session_id_response", + "applied_guardrails", + "applied_policies", + "policy_sources", + "standard_logging_object", + "proxy_server_request", + "secret_fields", + "_guardrail_pipelines", + "_pipeline_managed_guardrails", +) + +_UNTRUSTED_REQUEST_HEADER_CONTROL_FIELDS = frozenset( + { + "litellm-disable-message-redaction", + } +) +_CLIENT_MOCK_CONTROL_FIELDS = frozenset({"mock_response", "mock_tool_calls"}) +_ALLOW_CLIENT_MOCK_RESPONSE_METADATA_KEY = "allow_client_mock_response" +_ALLOW_CLIENT_MESSAGE_REDACTION_OPT_OUT_METADATA_KEY = ( + "allow_client_message_redaction_opt_out" +) + + +def _strip_untrusted_request_header_controls( + headers: Any, + *, + allow_client_message_redaction_opt_out: bool = False, +) -> None: + if not isinstance(headers, dict): + return + + for header_name in list(headers.keys()): + if ( + isinstance(header_name, str) + and header_name.lower() in _UNTRUSTED_REQUEST_HEADER_CONTROL_FIELDS + ): + if allow_client_message_redaction_opt_out: + continue + headers.pop(header_name, None) + + +def _is_false_like(value: Any) -> bool: + if isinstance(value, bool): + return value is False + if isinstance(value, str): + return value.strip().lower() in {"false", "0", "no", "off"} + return False + + +def _key_or_team_metadata_flag_is_true( + user_api_key_dict: UserAPIKeyAuth, + metadata_key: str, +) -> bool: + for admin_metadata in (user_api_key_dict.metadata, user_api_key_dict.team_metadata): + if ( + isinstance(admin_metadata, dict) + and admin_metadata.get(metadata_key) is True + ): + return True + return False + + +def _key_or_team_allows_client_mock_response( + user_api_key_dict: UserAPIKeyAuth, +) -> bool: + return _key_or_team_metadata_flag_is_true( + user_api_key_dict=user_api_key_dict, + metadata_key=_ALLOW_CLIENT_MOCK_RESPONSE_METADATA_KEY, + ) + + +def _key_or_team_allows_client_message_redaction_opt_out( + user_api_key_dict: UserAPIKeyAuth, +) -> bool: + return _key_or_team_metadata_flag_is_true( + user_api_key_dict=user_api_key_dict, + metadata_key=_ALLOW_CLIENT_MESSAGE_REDACTION_OPT_OUT_METADATA_KEY, + ) + def _get_metadata_variable_name(request: Request) -> str: """ @@ -962,11 +1068,15 @@ async def add_litellm_data_to_request( # noqa: PLR0915 # Strip internal-only keys from user input before the proxy sets its own. # These keys are injected by the proxy itself below — user-supplied values # must not be trusted. - for _internal_key in ( - "proxy_server_request", - "standard_logging_object", - "secret_fields", - ): + _allow_client_mock_response = _key_or_team_allows_client_mock_response( + user_api_key_dict + ) + _allow_client_message_redaction_opt_out = ( + _key_or_team_allows_client_message_redaction_opt_out(user_api_key_dict) + ) + for _internal_key in _UNTRUSTED_ROOT_CONTROL_FIELDS: + if _allow_client_mock_response and _internal_key in _CLIENT_MOCK_CONTROL_FIELDS: + continue data.pop(_internal_key, None) # Strip spoofable auth metadata from user-supplied metadata dict _user_metadata = data.get("metadata") @@ -1007,6 +1117,17 @@ async def add_litellm_data_to_request( # noqa: PLR0915 forward_llm_provider_auth_headers=forward_llm_auth, authenticated_with_header=authenticated_with_header, ) + _strip_untrusted_request_header_controls( + _headers, + allow_client_message_redaction_opt_out=_allow_client_message_redaction_opt_out, + ) + if ( + not _allow_client_message_redaction_opt_out + and litellm.turn_off_message_logging is True + and "turn_off_message_logging" in data + and _is_false_like(data["turn_off_message_logging"]) + ): + data.pop("turn_off_message_logging", None) verbose_proxy_logger.debug(f"Request Headers: {_headers}") verbose_proxy_logger.debug(f"Raw Headers: {_raw_headers}") @@ -1144,8 +1265,18 @@ async def add_litellm_data_to_request( # noqa: PLR0915 for _meta_key in ("metadata", "litellm_metadata"): _user_meta = data.get(_meta_key) if isinstance(_user_meta, dict): - _user_meta.pop("_pipeline_managed_guardrails", None) - for _k in [k for k in _user_meta if k.startswith("user_api_key_")]: + _strip_untrusted_request_header_controls( + _user_meta.get("headers"), + allow_client_message_redaction_opt_out=( + _allow_client_message_redaction_opt_out + ), + ) + for _k in [ + k + for k in _user_meta + if k.startswith("user_api_key_") + or k in _UNTRUSTED_METADATA_CONTROL_FIELDS + ]: _user_meta.pop(_k, None) # Strip caller-supplied routing/budget tags unless the admin has opted diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index c6d37ace4f..921d24da04 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -2069,6 +2069,9 @@ async def delete_user( litellm_proxy_admin_name, prisma_client, ) + from litellm.proxy.management_helpers.audit_logs import ( + get_audit_log_changed_by, + ) if prisma_client is None: raise HTTPException(status_code=500, detail={"error": "No db connected"}) @@ -2162,9 +2165,11 @@ async def delete_user( request_data=LiteLLM_AuditLogs( id=str(uuid.uuid4()), updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, + changed_by=get_audit_log_changed_by( + litellm_changed_by=litellm_changed_by, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + ), changed_by_api_key=user_api_key_dict.api_key, table_name=LitellmTableNames.USER_TABLE_NAME, object_id=user_id, diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 8129fb0de5..d1f75e3706 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -5254,6 +5254,9 @@ async def block_key( proxy_logging_obj, user_api_key_cache, ) + from litellm.proxy.management_helpers.audit_logs import ( + get_audit_log_changed_by, + ) if prisma_client is None: raise Exception("{}".format(CommonProxyErrors.db_not_connected_error.value)) @@ -5297,9 +5300,11 @@ async def block_key( request_data=LiteLLM_AuditLogs( id=str(uuid.uuid4()), updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, + changed_by=get_audit_log_changed_by( + litellm_changed_by=litellm_changed_by, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + ), changed_by_api_key=user_api_key_dict.api_key, table_name=LitellmTableNames.KEY_TABLE_NAME, object_id=hashed_token, @@ -5363,6 +5368,9 @@ async def unblock_key( proxy_logging_obj, user_api_key_cache, ) + from litellm.proxy.management_helpers.audit_logs import ( + get_audit_log_changed_by, + ) if prisma_client is None: raise Exception("{}".format(CommonProxyErrors.db_not_connected_error.value)) @@ -5406,9 +5414,11 @@ async def unblock_key( request_data=LiteLLM_AuditLogs( id=str(uuid.uuid4()), updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, + changed_by=get_audit_log_changed_by( + litellm_changed_by=litellm_changed_by, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + ), changed_by_api_key=user_api_key_dict.api_key, table_name=LitellmTableNames.KEY_TABLE_NAME, object_id=hashed_token, @@ -5589,7 +5599,6 @@ async def test_key_logging( "content": "Hello, this is a test from litellm /key/health. No LLM API call was made for this", } ], - "mock_response": "test response", } data = await add_litellm_data_to_request( data=data, @@ -5598,6 +5607,7 @@ async def test_key_logging( general_settings=general_settings, request=request, ) + data["mock_response"] = "test response" await litellm.acompletion( **data ) # make mock completion call to trigger key based callbacks diff --git a/litellm/proxy/management_endpoints/mcp_management_endpoints.py b/litellm/proxy/management_endpoints/mcp_management_endpoints.py index fca08e591f..9c510a568e 100644 --- a/litellm/proxy/management_endpoints/mcp_management_endpoints.py +++ b/litellm/proxy/management_endpoints/mcp_management_endpoints.py @@ -56,6 +56,7 @@ from litellm.proxy.common_utils.encrypt_decrypt_utils import ( decrypt_value_helper, encrypt_value_helper, ) +from litellm.proxy.management_helpers.audit_logs import get_audit_log_changed_by router = APIRouter(prefix="/v1/mcp", tags=["mcp"]) @@ -2230,7 +2231,12 @@ if MCP_AVAILABLE: detail={"error": "Only proxy admins can create MCP toolsets."}, ) touched_by = ( - litellm_changed_by or user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME + get_audit_log_changed_by( + litellm_changed_by=litellm_changed_by, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=LITELLM_PROXY_ADMIN_NAME, + ) + or LITELLM_PROXY_ADMIN_NAME ) try: result = await create_mcp_toolset(prisma_client, payload, touched_by) @@ -2321,7 +2327,12 @@ if MCP_AVAILABLE: detail={"error": "Only proxy admins can update MCP toolsets."}, ) touched_by = ( - litellm_changed_by or user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME + get_audit_log_changed_by( + litellm_changed_by=litellm_changed_by, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=LITELLM_PROXY_ADMIN_NAME, + ) + or LITELLM_PROXY_ADMIN_NAME ) try: result = await update_mcp_toolset(prisma_client, payload, touched_by) diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index bcafd93a4c..61247ce5de 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -906,6 +906,9 @@ async def new_team( # noqa: PLR0915 prisma_client, user_api_key_cache, ) + from litellm.proxy.management_helpers.audit_logs import ( + get_audit_log_changed_by, + ) if prisma_client is None: raise HTTPException(status_code=500, detail={"error": "No db connected"}) @@ -1174,9 +1177,11 @@ async def new_team( # noqa: PLR0915 request_data=LiteLLM_AuditLogs( id=str(uuid.uuid4()), updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, + changed_by=get_audit_log_changed_by( + litellm_changed_by=litellm_changed_by, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + ), changed_by_api_key=user_api_key_dict.api_key, table_name=LitellmTableNames.TEAM_TABLE_NAME, object_id=data.team_id, @@ -1214,7 +1219,10 @@ async def _create_team_update_audit_log( user_api_key_dict: User API key authentication details litellm_proxy_admin_name: Name of the proxy admin """ - from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update + from litellm.proxy.management_helpers.audit_logs import ( + create_audit_log_for_update, + get_audit_log_changed_by, + ) _before_value = existing_team_row.json(exclude_none=True) _before_value = json.dumps(_before_value, default=str) @@ -1225,9 +1233,11 @@ async def _create_team_update_audit_log( request_data=LiteLLM_AuditLogs( id=str(uuid.uuid4()), updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, + changed_by=get_audit_log_changed_by( + litellm_changed_by=litellm_changed_by, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + ), changed_by_api_key=user_api_key_dict.api_key, table_name=LitellmTableNames.TEAM_TABLE_NAME, object_id=team_id, @@ -3037,6 +3047,9 @@ async def delete_team( litellm_proxy_admin_name, prisma_client, ) + from litellm.proxy.management_helpers.audit_logs import ( + get_audit_log_changed_by, + ) if prisma_client is None: raise HTTPException(status_code=500, detail={"error": "No db connected"}) @@ -3096,9 +3109,11 @@ async def delete_team( request_data=LiteLLM_AuditLogs( id=str(uuid.uuid4()), updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, + changed_by=get_audit_log_changed_by( + litellm_changed_by=litellm_changed_by, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + ), changed_by_api_key=user_api_key_dict.api_key, table_name=LitellmTableNames.TEAM_TABLE_NAME, object_id=team_id, diff --git a/litellm/proxy/management_helpers/audit_logs.py b/litellm/proxy/management_helpers/audit_logs.py index 7599e11bde..d3b225e6e4 100644 --- a/litellm/proxy/management_helpers/audit_logs.py +++ b/litellm/proxy/management_helpers/audit_logs.py @@ -21,6 +21,28 @@ from litellm.proxy._types import ( from litellm.types.utils import StandardAuditLogPayload _audit_log_callback_cache: Dict[str, CustomLogger] = {} +ALLOW_LITELLM_CHANGED_BY_HEADER_METADATA_KEY = "allow_litellm_changed_by_header" + + +def _allows_litellm_changed_by_header(user_api_key_dict: UserAPIKeyAuth) -> bool: + for admin_metadata in (user_api_key_dict.metadata, user_api_key_dict.team_metadata): + if ( + isinstance(admin_metadata, dict) + and admin_metadata.get(ALLOW_LITELLM_CHANGED_BY_HEADER_METADATA_KEY) is True + ): + return True + return False + + +def get_audit_log_changed_by( + *, + litellm_changed_by: Optional[str], + user_api_key_dict: UserAPIKeyAuth, + litellm_proxy_admin_name: Optional[str], +) -> Optional[str]: + if litellm_changed_by and _allows_litellm_changed_by_header(user_api_key_dict): + return litellm_changed_by + return user_api_key_dict.user_id or litellm_proxy_admin_name def _resolve_audit_log_callback(name: str) -> Optional[CustomLogger]: @@ -143,8 +165,10 @@ async def create_object_audit_log( if _store_audit_logs is not True: return - _changed_by = ( - litellm_changed_by or user_api_key_dict.user_id or litellm_proxy_admin_name + _changed_by = get_audit_log_changed_by( + litellm_changed_by=litellm_changed_by, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, ) await create_audit_log_for_update( diff --git a/tests/litellm_utils_tests/test_utils.py b/tests/litellm_utils_tests/test_utils.py index 20af6e1023..e4dbe5f9f3 100644 --- a/tests/litellm_utils_tests/test_utils.py +++ b/tests/litellm_utils_tests/test_utils.py @@ -812,7 +812,7 @@ def test_redact_msgs_from_logs_with_dynamic_params(): # Assert redaction occurred assert _redacted_response_obj.choices[0].message.content == "redacted-by-litellm" - # Test Case 3: standard_callback_dynamic_params does not override litellm.turn_off_message_logging + # Test Case 3: standard_callback_dynamic_params does not set turn_off_message_logging # since litellm.turn_off_message_logging is True redaction should occur standard_callback_dynamic_params = StandardCallbackDynamicParams() litellm_logging_obj.model_call_details["standard_callback_dynamic_params"] = ( diff --git a/tests/logging_callback_tests/test_logging_redaction_e2e_test.py b/tests/logging_callback_tests/test_logging_redaction_e2e_test.py index 63ef4bafbb..08d0abd272 100644 --- a/tests/logging_callback_tests/test_logging_redaction_e2e_test.py +++ b/tests/logging_callback_tests/test_logging_redaction_e2e_test.py @@ -13,11 +13,13 @@ import logging import time from unittest.mock import AsyncMock, patch +import httpx import pytest import litellm from litellm._logging import verbose_logger from litellm.integrations.custom_logger import CustomLogger +from litellm.responses.main import mock_responses_api_response from litellm.types.utils import StandardLoggingPayload @@ -126,17 +128,10 @@ async def test_redaction_responses_api(): test_custom_logger = TestCustomLogger(turn_off_message_logging=True) litellm.callbacks = [test_custom_logger] - # Mock a ResponsesAPIResponse-style response - mock_response = { - "output": [{"text": "This is a test response"}], - "model": "gpt-3.5-turbo", - "usage": {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10}, - } - response = await litellm.aresponses( model="gpt-3.5-turbo", input="hi", - mock_response=mock_response, + mock_response="This is a test response", ) await asyncio.sleep(1) @@ -163,6 +158,7 @@ async def test_redaction_responses_api(): assert ( content_item["text"] == "redacted-by-litellm" ), f"Expected redacted text but got: {content_item['text']}" + assert "This is a test response" not in json.dumps(standard_logging_payload) print( "logged standard logging payload for ResponsesAPIResponse", json.dumps(standard_logging_payload, indent=2), @@ -176,29 +172,36 @@ async def test_redaction_responses_api_stream(): test_custom_logger = TestCustomLogger(turn_off_message_logging=True) litellm.callbacks = [test_custom_logger] - # Mock a ResponsesAPIResponse-style response with streaming chunks - mock_response = [ - { - "output": [{"text": "This"}], - "model": "gpt-3.5-turbo", - }, - { - "output": [{"text": " is"}], - "model": "gpt-3.5-turbo", - }, - { - "output": [{"text": " a test response"}], - "model": "gpt-3.5-turbo", - "usage": {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10}, - }, - ] + mocked_response_payload = mock_responses_api_response( + "This is a test response" + ).model_dump() - response = await litellm.aresponses( - model="gpt-3.5-turbo", - input="hi", - mock_response=mock_response, - stream=True, - ) + async def mock_post(self, url, headers, timeout, stream=False, **kwargs): + stream_content = ( + "data: " + + json.dumps( + { + "type": "response.completed", + "response": mocked_response_payload, + } + ) + + "\n\ndata: [DONE]\n\n" + ) + return httpx.Response( + status_code=200, + content=stream_content, + request=httpx.Request("POST", url), + ) + + with patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + new=mock_post, + ): + response = await litellm.aresponses( + model="gpt-3.5-turbo", + input="hi", + stream=True, + ) # Consume the stream chunks = [] @@ -445,18 +448,11 @@ async def test_disable_redaction_header_responses_api(): test_custom_logger = TestCustomLogger() litellm.callbacks = [test_custom_logger] - # Mock a ResponsesAPIResponse-style response - mock_response = { - "output": [{"text": "This is a test response"}], - "model": "gpt-3.5-turbo", - "usage": {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10}, - } - # Pass the header via litellm_metadata (as the proxy does for Responses API) response = await litellm.aresponses( model="gpt-3.5-turbo", input="hi", - mock_response=mock_response, + mock_response="This is a test response", litellm_metadata={"headers": {"litellm-disable-message-redaction": "true"}}, ) @@ -464,14 +460,14 @@ async def test_disable_redaction_header_responses_api(): standard_logging_payload = test_custom_logger.logged_standard_logging_payload assert standard_logging_payload is not None - # Verify that messages are NOT redacted because the header was set + # Verify that the direct SDK path still honors the explicit header. print( "logged standard logging payload for ResponsesAPI with disable header", json.dumps(standard_logging_payload, indent=2, default=str), ) - # The content should NOT be redacted - assert standard_logging_payload["response"] != {"text": "redacted-by-litellm"} + response = standard_logging_payload["response"] + assert response["output"][0]["content"][0]["text"] == "This is a test response" assert standard_logging_payload["messages"][0]["content"] == "hi" diff --git a/tests/proxy_unit_tests/test_audit_logs_proxy.py b/tests/proxy_unit_tests/test_audit_logs_proxy.py index 7637229336..9e2b69176e 100644 --- a/tests/proxy_unit_tests/test_audit_logs_proxy.py +++ b/tests/proxy_unit_tests/test_audit_logs_proxy.py @@ -45,8 +45,11 @@ verbose_proxy_logger.setLevel(level=logging.DEBUG) from starlette.datastructures import URL -from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update -from litellm.proxy._types import LiteLLM_AuditLogs, LitellmTableNames +from litellm.proxy.management_helpers.audit_logs import ( + create_audit_log_for_update, + get_audit_log_changed_by, +) +from litellm.proxy._types import LiteLLM_AuditLogs, LitellmTableNames, UserAPIKeyAuth from litellm.caching.caching import DualCache from unittest.mock import patch, AsyncMock @@ -54,6 +57,119 @@ proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) import json +def test_get_audit_log_changed_by_prefers_authenticated_user(): + user_api_key_dict = UserAPIKeyAuth( + api_key="test-key", + user_id="authenticated-user", + ) + + assert ( + get_audit_log_changed_by( + litellm_changed_by="spoofed-user", + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name="proxy-admin", + ) + == "authenticated-user" + ) + + +def test_get_audit_log_changed_by_honors_header_with_admin_opt_in(): + user_api_key_dict = UserAPIKeyAuth( + api_key="test-key", + user_id="service-account", + metadata={"allow_litellm_changed_by_header": True}, + ) + + assert ( + get_audit_log_changed_by( + litellm_changed_by="delegated-user", + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name="proxy-admin", + ) + == "delegated-user" + ) + + +def test_get_audit_log_changed_by_honors_header_with_team_opt_in(): + user_api_key_dict = UserAPIKeyAuth( + api_key="test-key", + user_id="service-account", + team_metadata={"allow_litellm_changed_by_header": True}, + ) + + assert ( + get_audit_log_changed_by( + litellm_changed_by="delegated-user", + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name="proxy-admin", + ) + == "delegated-user" + ) + + +def test_get_audit_log_changed_by_ignores_header_without_opt_in_when_user_id_missing(): + user_api_key_dict = UserAPIKeyAuth(api_key="test-key") + + assert ( + get_audit_log_changed_by( + litellm_changed_by="spoofed-user", + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name="proxy-admin", + ) + == "proxy-admin" + ) + + +def test_get_audit_log_changed_by_honors_header_with_opt_in_when_user_id_missing(): + user_api_key_dict = UserAPIKeyAuth( + api_key="test-key", + metadata={"allow_litellm_changed_by_header": True}, + ) + + assert ( + get_audit_log_changed_by( + litellm_changed_by="delegated-user", + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name="proxy-admin", + ) + == "delegated-user" + ) + + +@pytest.mark.asyncio +async def test_create_internal_user_audit_log_uses_changed_by_helper(): + from litellm.proxy.hooks.user_management_event_hooks import UserManagementEventHooks + + user_api_key_dict = UserAPIKeyAuth( + api_key="test-key", + user_id="service-account", + metadata={"allow_litellm_changed_by_header": True}, + ) + + with ( + patch("litellm.store_audit_logs", True), + patch( + "litellm.proxy.hooks.user_management_event_hooks.create_audit_log_for_update", + new_callable=AsyncMock, + ) as mock_create_audit_log_for_update, + ): + await UserManagementEventHooks.create_internal_user_audit_log( + user_id="target-user", + action="updated", + litellm_changed_by="delegated-user", + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name="proxy-admin", + before_value='{"before": true}', + after_value='{"after": true}', + ) + + request_data = mock_create_audit_log_for_update.await_args.kwargs["request_data"] + assert request_data.changed_by == "delegated-user" + assert request_data.changed_by_api_key == "test-key" + assert request_data.object_id == "target-user" + assert request_data.action == "updated" + + @pytest.mark.asyncio async def test_create_audit_log_for_update_premium_user(): """ diff --git a/tests/proxy_unit_tests/test_proxy_server.py b/tests/proxy_unit_tests/test_proxy_server.py index c62c3f41b3..86bbc5170e 100644 --- a/tests/proxy_unit_tests/test_proxy_server.py +++ b/tests/proxy_unit_tests/test_proxy_server.py @@ -1553,6 +1553,7 @@ async def test_add_callback_via_key(prisma_client): fastapi_response=Response(), user_api_key_dict=UserAPIKeyAuth( metadata={ + "allow_client_mock_response": True, "logging": [ { "callback_name": "langfuse", # 'otel', 'langfuse', 'lunary' @@ -1563,7 +1564,7 @@ async def test_add_callback_via_key(prisma_client): "langfuse_host": "https://us.cloud.langfuse.com", }, } - ] + ], } ), ) @@ -1657,6 +1658,7 @@ async def test_add_callback_via_key_litellm_pre_call_utils( team_id=None, max_parallel_requests=None, metadata={ + "allow_client_mock_response": True, "logging": [ { "callback_name": "langfuse", @@ -1667,7 +1669,7 @@ async def test_add_callback_via_key_litellm_pre_call_utils( "langfuse_host": "https://us.cloud.langfuse.com", }, } - ] + ], }, tpm_limit=None, rpm_limit=None, @@ -1813,6 +1815,7 @@ async def test_add_callback_via_key_litellm_pre_call_utils_gcs_bucket( team_id=None, max_parallel_requests=None, metadata={ + "allow_client_mock_response": True, "logging": [ { "callback_name": "gcs_bucket", @@ -1822,7 +1825,7 @@ async def test_add_callback_via_key_litellm_pre_call_utils_gcs_bucket( "gcs_path_service_account": "pathrise-convert-1606954137718-a956eef1a2a8.json", }, } - ] + ], }, tpm_limit=None, rpm_limit=None, @@ -1946,6 +1949,7 @@ async def test_add_callback_via_key_litellm_pre_call_utils_langsmith( team_id=None, max_parallel_requests=None, metadata={ + "allow_client_mock_response": True, "logging": [ { "callback_name": "langsmith", @@ -1956,7 +1960,7 @@ async def test_add_callback_via_key_litellm_pre_call_utils_langsmith( "langsmith_base_url": "https://api.smith.langchain.com", }, } - ] + ], }, tpm_limit=None, rpm_limit=None, diff --git a/tests/test_litellm/litellm_core_utils/test_redact_messages.py b/tests/test_litellm/litellm_core_utils/test_redact_messages.py index 109e344d72..60cfff6e4a 100644 --- a/tests/test_litellm/litellm_core_utils/test_redact_messages.py +++ b/tests/test_litellm/litellm_core_utils/test_redact_messages.py @@ -5,10 +5,17 @@ Covers the proxy flow where headers arrive in litellm_params["metadata"]["header but litellm_params["litellm_metadata"] is None. """ +from types import SimpleNamespace + import pytest import litellm -from litellm.litellm_core_utils.redact_messages import should_redact_message_logging +from litellm.litellm_core_utils.redact_messages import ( + _redact_responses_api_output, + perform_redaction, + should_redact_message_logging, +) +from litellm.responses.main import mock_responses_api_response @pytest.fixture(autouse=True) @@ -68,8 +75,7 @@ class TestShouldRedactMessageLogging: assert should_redact_message_logging(details) is True def test_disable_redaction_via_header_proxy_flow(self): - """litellm-disable-message-redaction should suppress redaction - even when global setting is on, and litellm_metadata is None.""" + """Core helper still honors the explicit disable-redaction header.""" litellm.turn_off_message_logging = True details = _make_model_call_details( metadata_headers={"litellm-disable-message-redaction": "true"}, @@ -77,6 +83,14 @@ class TestShouldRedactMessageLogging: ) assert should_redact_message_logging(details) is False + def test_disable_redaction_via_header_when_global_off(self): + """litellm-disable-message-redaction is still honored when global redaction is off.""" + details = _make_model_call_details( + metadata_headers={"litellm-disable-message-redaction": "true"}, + litellm_metadata=None, + ) + assert should_redact_message_logging(details) is False + # ---- SDK direct-call flow: headers in litellm_metadata ---- def test_enable_redaction_via_header_in_litellm_metadata(self): @@ -127,6 +141,16 @@ class TestShouldRedactMessageLogging: ) assert should_redact_message_logging(details) is False + def test_dynamic_param_false_overrides_global_redaction(self): + """Dynamic turn_off_message_logging=False should take precedence.""" + litellm.turn_off_message_logging = True + details = _make_model_call_details( + metadata_headers={}, + litellm_metadata=None, + standard_callback_dynamic_params={"turn_off_message_logging": False}, + ) + assert should_redact_message_logging(details) is False + # ---- non-dict metadata safety ---- def test_both_metadata_fields_none(self): @@ -145,3 +169,183 @@ class TestShouldRedactMessageLogging: litellm_metadata=None, ) assert should_redact_message_logging(details) is True + + +class TestPerformRedaction: + def test_redacts_standard_logging_and_responses_api_dicts(self): + details = { + "messages": [{"role": "user", "content": "sensitive input"}], + "prompt": "sensitive prompt", + "input": "sensitive input", + "standard_logging_object": { + "messages": [{"role": "user", "content": "sensitive input"}], + "response": { + "output": [ + {"text": "top-level text"}, + {"content": [{"text": "nested text"}]}, + {"type": "reasoning", "summary": [{"text": "reasoning"}]}, + ], + "usage": {"total_tokens": 1}, + }, + }, + } + result = { + "output": [ + {"text": "top-level result"}, + {"content": [{"text": "nested result"}]}, + {"type": "reasoning", "summary": [{"text": "reasoning result"}]}, + ], + "usage": {"total_tokens": 1}, + } + + redacted = perform_redaction(details, result) + + assert details["messages"] == [ + {"role": "user", "content": "redacted-by-litellm"} + ] + assert details["prompt"] == "" + assert details["input"] == "" + + logged_response = details["standard_logging_object"]["response"] + assert logged_response["usage"] == {"total_tokens": 1} + assert logged_response["output"][0]["text"] == "redacted-by-litellm" + assert logged_response["output"][1]["content"][0]["text"] == ( + "redacted-by-litellm" + ) + assert logged_response["output"][2]["summary"][0]["text"] == ( + "redacted-by-litellm" + ) + + assert redacted["usage"] == {"total_tokens": 1} + assert redacted["output"][0]["text"] == "redacted-by-litellm" + assert redacted["output"][1]["content"][0]["text"] == "redacted-by-litellm" + assert redacted["output"][2]["summary"][0]["text"] == "redacted-by-litellm" + assert result["output"][0]["text"] == "top-level result" + + def test_redacts_model_response_dict_choices(self): + result = { + "choices": [ + { + "message": { + "content": "message content", + "reasoning_content": "message reasoning", + "thinking_blocks": ["thinking"], + "audio": {"data": "audio"}, + } + }, + { + "delta": { + "content": "delta content", + "reasoning_content": "delta reasoning", + "thinking_blocks": ["delta thinking"], + "audio": {"data": "audio"}, + } + }, + ] + } + + redacted = perform_redaction({}, result) + + message = redacted["choices"][0]["message"] + assert message["content"] == "redacted-by-litellm" + assert message["reasoning_content"] == "redacted-by-litellm" + assert message["thinking_blocks"] is None + assert message["audio"] is None + + delta = redacted["choices"][1]["delta"] + assert delta["content"] == "redacted-by-litellm" + assert delta["reasoning_content"] == "redacted-by-litellm" + assert delta["thinking_blocks"] is None + assert delta["audio"] is None + + def test_redacts_standard_logging_model_response_dict_choices(self): + details = { + "standard_logging_object": { + "response": { + "choices": [ + { + "message": { + "content": "message content", + "reasoning_content": "message reasoning", + "thinking_blocks": ["thinking"], + "audio": {"data": "audio"}, + } + }, + { + "delta": { + "content": "delta content", + "reasoning_content": "delta reasoning", + "thinking_blocks": ["delta thinking"], + "audio": {"data": "audio"}, + } + }, + ] + } + } + } + + perform_redaction(details, None) + + choices = details["standard_logging_object"]["response"]["choices"] + message = choices[0]["message"] + assert message["content"] == "redacted-by-litellm" + assert message["reasoning_content"] == "redacted-by-litellm" + assert message["thinking_blocks"] is None + assert message["audio"] is None + + delta = choices[1]["delta"] + assert delta["content"] == "redacted-by-litellm" + assert delta["reasoning_content"] == "redacted-by-litellm" + assert delta["thinking_blocks"] is None + assert delta["audio"] is None + + def test_redacts_object_choices_inside_model_response_dict(self): + result = { + "choices": [ + litellm.Choices( + message=litellm.Message( + content="message content", + role="assistant", + reasoning_content="message reasoning", + ) + ) + ] + } + + redacted = perform_redaction({}, result) + + choice = redacted["choices"][0] + assert choice.message.content == "redacted-by-litellm" + assert choice.message.reasoning_content == "redacted-by-litellm" + + def test_redacts_response_output_objects_with_top_level_text(self): + output_items = [ + SimpleNamespace(text="top-level output"), + "non-dict output item", + ] + + _redact_responses_api_output(output_items) + + assert output_items[0].text == "redacted-by-litellm" + assert output_items[1] == "non-dict output item" + + def test_skips_non_dict_response_output_items(self): + result = { + "output": [ + "non-dict output item", + {"content": [{"text": "nested result"}]}, + ] + } + + redacted = perform_redaction({}, result) + + assert redacted["output"][0] == "non-dict output item" + assert redacted["output"][1]["content"][0]["text"] == "redacted-by-litellm" + + def test_redacts_responses_api_response_object(self): + response = mock_responses_api_response("sensitive output") + + redacted = perform_redaction({}, response) + + assert redacted.output[0].content[0].text == "redacted-by-litellm" + assert response.output[0].content[0].text == "sensitive output" diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_bedrock_guardrails.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_bedrock_guardrails.py index 7dfdca118c..a3247d2e55 100644 --- a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_bedrock_guardrails.py +++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_bedrock_guardrails.py @@ -1853,6 +1853,60 @@ async def test_make_bedrock_api_request_logging_event_type_for_spend_logs(): assert mock_log.call_args.kwargs["event_type"] == GuardrailEventHooks.pre_call +@pytest.mark.asyncio +async def test_make_bedrock_api_request_filters_dynamic_evaluation_overrides(): + guardrail = BedrockGuardrail( + guardrailIdentifier="test-guardrail", guardrailVersion="DRAFT" + ) + mock_credentials = MagicMock() + mock_credentials.access_key = "test-access-key" + mock_credentials.secret_key = "test-secret-key" + mock_credentials.token = None + + mock_bedrock_response = MagicMock() + mock_bedrock_response.status_code = 200 + mock_bedrock_response.json.return_value = {"action": "NONE", "assessments": []} + + prepared_request = MagicMock() + prepared_request.url = "https://bedrock.test/apply" + prepared_request.body = b"{}" + prepared_request.headers = {} + + with ( + patch.object( + guardrail.async_handler, "post", new_callable=AsyncMock + ) as mock_post, + patch.object( + guardrail, "_load_credentials", return_value=(mock_credentials, "us-east-1") + ), + patch.object( + guardrail, "_prepare_request", return_value=prepared_request + ) as mock_prepare_request, + patch.object( + guardrail, + "get_guardrail_dynamic_request_body_params", + return_value={ + "content": [{"text": {"text": "benign replacement"}}], + "source": "OUTPUT", + "outputScope": "FULL", + }, + ), + ): + mock_post.return_value = mock_bedrock_response + + await guardrail.make_bedrock_api_request( + source="INPUT", + messages=[{"role": "user", "content": "actual prompt"}], + request_data={"model": "gpt-4o"}, + ) + + prepared_data = mock_prepare_request.call_args.kwargs["data"] + assert prepared_data["source"] == "INPUT" + assert "actual prompt" in json.dumps(prepared_data["content"]) + assert "benign replacement" not in json.dumps(prepared_data["content"]) + assert prepared_data["outputScope"] == "FULL" + + @pytest.mark.asyncio async def test_during_call_hook_invokes_bedrock_async_moderation_hook(): """ diff --git a/tests/test_litellm/proxy/guardrails/test_pillar_guardrails.py b/tests/test_litellm/proxy/guardrails/test_pillar_guardrails.py index 6f9e030269..c977223bfa 100644 --- a/tests/test_litellm/proxy/guardrails/test_pillar_guardrails.py +++ b/tests/test_litellm/proxy/guardrails/test_pillar_guardrails.py @@ -505,6 +505,38 @@ def test_get_logging_caching_headers_pillar_metadata(): ) +def test_get_logging_caching_headers_ignores_untrusted_pillar_headers(): + request_data = { + "metadata": { + "pillar_response_headers": { + "set-cookie": "session=evil", + "x-pillar-flagged": "true", + }, + "pillar_flagged": True, + } + } + + headers = get_logging_caching_headers(request_data) + + assert "set-cookie" not in headers + assert "x-pillar-flagged" not in headers + + +def test_get_logging_caching_headers_filters_non_pillar_headers(): + request_data = { + "metadata": { + "pillar_flagged": True, + } + } + build_pillar_response_headers(request_data["metadata"]) + request_data["metadata"]["pillar_response_headers"]["set-cookie"] = "session=evil" + + headers = get_logging_caching_headers(request_data) + + assert headers["x-pillar-flagged"] == "true" + assert "set-cookie" not in headers + + def test_get_logging_caching_headers_truncates_large_evidence(): long_text = "悪" * 6000 # multi-byte unicode to test URL encoding and truncation request_data = { diff --git a/tests/test_litellm/proxy/test_litellm_pre_call_utils.py b/tests/test_litellm/proxy/test_litellm_pre_call_utils.py index 1e5c688e36..8cb9c2bcaa 100644 --- a/tests/test_litellm/proxy/test_litellm_pre_call_utils.py +++ b/tests/test_litellm/proxy/test_litellm_pre_call_utils.py @@ -512,6 +512,247 @@ async def test_add_litellm_data_to_request_strips_string_encoded_admin_injection assert "_pipeline_managed_guardrails" not in other +@pytest.mark.asyncio +async def test_add_litellm_data_to_request_strips_user_control_fields(): + """Strip untrusted proxy-control fields before guardrails, logging, and headers read metadata.""" + request_mock = MagicMock(spec=Request) + request_mock.url.path = "/v1/chat/completions" + request_mock.url = MagicMock() + request_mock.url.__str__.return_value = "http://localhost/v1/chat/completions" + request_mock.method = "POST" + request_mock.query_params = {} + request_mock.headers = {"Content-Type": "application/json"} + request_mock.client = MagicMock() + request_mock.client.host = "127.0.0.1" + + malicious_metadata = { + "disable_global_guardrails": True, + "opted_out_global_guardrails": ["pii"], + "pillar_response_headers": {"set-cookie": "session=evil"}, + "_pillar_response_headers_trusted": True, + "pillar_flagged": True, + "pillar_scanners": {"jailbreak": True}, + "pillar_evidence": [{"evidence": "spoofed"}], + "pillar_session_id_response": "spoofed-session", + "applied_guardrails": ["spoofed"], + "applied_policies": ["spoofed-policy"], + "policy_sources": {"spoofed-policy": "request"}, + "_guardrail_pipelines": [{"name": "spoofed"}], + "_pipeline_managed_guardrails": ["evaded"], + "safe_user_metadata": "kept", + } + data = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "hello"}], + "mock_response": "free response", + "mock_tool_calls": [{"id": "call_1"}], + "disable_global_guardrails": True, + "metadata": copy.deepcopy(malicious_metadata), + "litellm_metadata": copy.deepcopy(malicious_metadata), + } + + updated = await add_litellm_data_to_request( + data=data, + request=request_mock, + user_api_key_dict=UserAPIKeyAuth(api_key="hashed-key"), + proxy_config=MagicMock(), + general_settings={}, + version="test-version", + ) + + assert "mock_response" not in updated + assert "mock_tool_calls" not in updated + assert "disable_global_guardrails" not in updated + + stripped_keys = { + "disable_global_guardrails", + "opted_out_global_guardrails", + "pillar_response_headers", + "_pillar_response_headers_trusted", + "pillar_flagged", + "pillar_scanners", + "pillar_evidence", + "pillar_session_id_response", + "applied_guardrails", + "applied_policies", + "policy_sources", + "_guardrail_pipelines", + "_pipeline_managed_guardrails", + } + for metadata_key in ("metadata", "litellm_metadata"): + cleaned_metadata = updated.get(metadata_key) or {} + for stripped_key in stripped_keys: + assert stripped_key not in cleaned_metadata + assert cleaned_metadata.get("safe_user_metadata") == "kept" + + requester_metadata = updated["metadata"]["requester_metadata"] + for stripped_key in stripped_keys: + assert stripped_key not in requester_metadata + + snapshot_body = updated["proxy_server_request"]["body"] + assert "mock_response" not in snapshot_body + assert "mock_tool_calls" not in snapshot_body + assert "pillar_response_headers" not in snapshot_body["metadata"] + + +@pytest.mark.asyncio +async def test_add_litellm_data_to_request_allows_client_mock_response_with_admin_opt_in(): + request_mock = MagicMock(spec=Request) + request_mock.url.path = "/v1/chat/completions" + request_mock.url = MagicMock() + request_mock.url.__str__.return_value = "http://localhost/v1/chat/completions" + request_mock.method = "POST" + request_mock.query_params = {} + request_mock.headers = {"Content-Type": "application/json"} + request_mock.client = MagicMock() + request_mock.client.host = "127.0.0.1" + + updated = await add_litellm_data_to_request( + data={ + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "hello"}], + "mock_response": "allowed mock", + "mock_tool_calls": [{"id": "call_1"}], + }, + request=request_mock, + user_api_key_dict=UserAPIKeyAuth( + api_key="hashed-key", + metadata={"allow_client_mock_response": True}, + ), + proxy_config=MagicMock(), + general_settings={}, + version="test-version", + ) + + assert updated["mock_response"] == "allowed mock" + assert updated["mock_tool_calls"] == [{"id": "call_1"}] + + +@pytest.mark.asyncio +async def test_add_litellm_data_to_request_strips_client_redaction_bypass_controls(): + request_mock = MagicMock(spec=Request) + request_mock.url.path = "/v1/chat/completions" + request_mock.url = MagicMock() + request_mock.url.__str__.return_value = "http://localhost/v1/chat/completions" + request_mock.method = "POST" + request_mock.query_params = {} + request_mock.headers = { + "Content-Type": "application/json", + "litellm-disable-message-redaction": "true", + } + request_mock.client = MagicMock() + request_mock.client.host = "127.0.0.1" + + original_turn_off_message_logging = litellm.turn_off_message_logging + litellm.turn_off_message_logging = True + try: + updated = await add_litellm_data_to_request( + data={ + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "hello"}], + "turn_off_message_logging": False, + "metadata": {"headers": {"litellm-disable-message-redaction": "true"}}, + "litellm_metadata": json.dumps( + {"headers": {"LiteLLM-Disable-Message-Redaction": "true"}} + ), + }, + request=request_mock, + user_api_key_dict=UserAPIKeyAuth(api_key="hashed-key"), + proxy_config=MagicMock(), + general_settings={}, + version="test-version", + ) + finally: + litellm.turn_off_message_logging = original_turn_off_message_logging + + assert "turn_off_message_logging" not in updated + assert "litellm-disable-message-redaction" not in { + header.lower() for header in updated["metadata"]["headers"] + } + assert "litellm-disable-message-redaction" not in { + header.lower() + for header in updated["metadata"]["requester_metadata"].get("headers", {}) + } + assert "litellm-disable-message-redaction" not in { + header.lower() for header in updated["proxy_server_request"]["headers"] + } + assert "litellm-disable-message-redaction" not in { + header.lower() + for header in updated["proxy_server_request"]["body"]["metadata"]["headers"] + } + assert "litellm-disable-message-redaction" not in { + header.lower() + for header in (updated.get("litellm_metadata") or {}).get("headers", {}) + } + + +@pytest.mark.parametrize( + "auth_kwargs", + [ + {"metadata": {"allow_client_message_redaction_opt_out": True}}, + {"team_metadata": {"allow_client_message_redaction_opt_out": True}}, + ], +) +@pytest.mark.asyncio +async def test_add_litellm_data_to_request_allows_redaction_opt_out_with_admin_opt_in( + auth_kwargs, +): + request_mock = MagicMock(spec=Request) + request_mock.url.path = "/v1/chat/completions" + request_mock.url = MagicMock() + request_mock.url.__str__.return_value = "http://localhost/v1/chat/completions" + request_mock.method = "POST" + request_mock.query_params = {} + request_mock.headers = { + "Content-Type": "application/json", + "litellm-disable-message-redaction": "true", + } + request_mock.client = MagicMock() + request_mock.client.host = "127.0.0.1" + + original_turn_off_message_logging = litellm.turn_off_message_logging + litellm.turn_off_message_logging = True + try: + updated = await add_litellm_data_to_request( + data={ + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "hello"}], + "turn_off_message_logging": False, + "metadata": {"headers": {"litellm-disable-message-redaction": "true"}}, + "litellm_metadata": json.dumps( + {"headers": {"LiteLLM-Disable-Message-Redaction": "true"}} + ), + }, + request=request_mock, + user_api_key_dict=UserAPIKeyAuth(api_key="hashed-key", **auth_kwargs), + proxy_config=MagicMock(), + general_settings={}, + version="test-version", + ) + finally: + litellm.turn_off_message_logging = original_turn_off_message_logging + + assert updated["turn_off_message_logging"] is False + assert "litellm-disable-message-redaction" in { + header.lower() for header in updated["metadata"]["headers"] + } + assert "litellm-disable-message-redaction" in { + header.lower() + for header in updated["metadata"]["requester_metadata"].get("headers", {}) + } + assert "litellm-disable-message-redaction" in { + header.lower() for header in updated["proxy_server_request"]["headers"] + } + assert "litellm-disable-message-redaction" in { + header.lower() + for header in updated["proxy_server_request"]["body"]["metadata"]["headers"] + } + assert "litellm-disable-message-redaction" in { + header.lower() + for header in (updated.get("litellm_metadata") or {}).get("headers", {}) + } + + @pytest.mark.asyncio async def test_add_litellm_data_to_request_ignores_x_litellm_tags_header_without_permission(): """Regression: the `x-litellm-tags` header bypassed the body-metadata @@ -1669,7 +1910,10 @@ async def test_add_litellm_metadata_from_request_headers(): # Create mock user API key dict mock_user_api_key_dict = UserAPIKeyAuth( - api_key="test-key", user_id="test-user", org_id="test-org" + api_key="test-key", + user_id="test-user", + org_id="test-org", + metadata={"allow_client_mock_response": True}, ) # Create mock proxy logging object @@ -1782,7 +2026,9 @@ async def test_anthropic_messages_standard_logging_object_matches_fixture(): mock_fastapi_response = MagicMock(spec=Response) mock_user_api_key_dict = UserAPIKeyAuth( - api_key="test-key", user_id="default_user_id" + api_key="test-key", + user_id="default_user_id", + metadata={"allow_client_mock_response": True}, ) mock_proxy_logging_obj = MagicMock(spec=ProxyLogging) @@ -3326,7 +3572,9 @@ async def test_team_guardrail_merges_with_global_policy(): policy_registry = get_policy_registry() policy_registry._policies = { "global-policy": Policy( - guardrails=PolicyGuardrails(add=["policy-guardrail-1", "policy-guardrail-2"]), + guardrails=PolicyGuardrails( + add=["policy-guardrail-1", "policy-guardrail-2"] + ), ), } policy_registry._initialized = True @@ -3347,14 +3595,18 @@ async def test_team_guardrail_merges_with_global_policy(): guardrails = data["metadata"].get("guardrails", []) - assert "team-direct-guardrail" in guardrails, \ - f"Team guardrail missing from merged list: {guardrails}" - assert "policy-guardrail-1" in guardrails, \ - f"policy-guardrail-1 missing: {guardrails}" - assert "policy-guardrail-2" in guardrails, \ - f"policy-guardrail-2 missing: {guardrails}" - assert len(guardrails) == len(set(guardrails)), \ - f"Duplicates in guardrails list: {guardrails}" + assert ( + "team-direct-guardrail" in guardrails + ), f"Team guardrail missing from merged list: {guardrails}" + assert ( + "policy-guardrail-1" in guardrails + ), f"policy-guardrail-1 missing: {guardrails}" + assert ( + "policy-guardrail-2" in guardrails + ), f"policy-guardrail-2 missing: {guardrails}" + assert len(guardrails) == len( + set(guardrails) + ), f"Duplicates in guardrails list: {guardrails}" # Verify get_guardrail_from_metadata returns the merged list even # when litellm_metadata is present (the bug: it returned [] before fix) @@ -3365,9 +3617,9 @@ async def test_team_guardrail_merges_with_global_policy(): dummy = _DummyGuardrail(guardrail_name="team-direct-guardrail") returned = dummy.get_guardrail_from_metadata(data) - assert "team-direct-guardrail" in returned, ( - f"get_guardrail_from_metadata shadowed by litellm_metadata; got: {returned}" - ) + assert ( + "team-direct-guardrail" in returned + ), f"get_guardrail_from_metadata shadowed by litellm_metadata; got: {returned}" finally: policy_registry._policies = {} @@ -3396,9 +3648,10 @@ async def test_get_guardrail_from_metadata_prefers_metadata_over_litellm_metadat } result = dummy.get_guardrail_from_metadata(data) - assert result == ["my-guardrail", "other-guardrail"], ( - f"Expected guardrails from metadata, got: {result}" - ) + assert result == [ + "my-guardrail", + "other-guardrail", + ], f"Expected guardrails from metadata, got: {result}" def test_get_guardrail_from_metadata_reads_litellm_metadata_when_no_metadata(): @@ -3419,6 +3672,6 @@ def test_get_guardrail_from_metadata_reads_litellm_metadata_when_no_metadata(): } result = dummy.get_guardrail_from_metadata(data) - assert result == ["my-guardrail"], ( - f"Expected guardrails from litellm_metadata fallback, got: {result}" - ) + assert result == [ + "my-guardrail" + ], f"Expected guardrails from litellm_metadata fallback, got: {result}"