diff --git a/litellm/litellm_core_utils/redact_messages.py b/litellm/litellm_core_utils/redact_messages.py index a62031a9c9..79aeeff144 100644 --- a/litellm/litellm_core_utils/redact_messages.py +++ b/litellm/litellm_core_utils/redact_messages.py @@ -69,6 +69,11 @@ def perform_redaction(model_call_details: dict, result): elif isinstance(choice, litellm.utils.StreamingChoices): choice.delta.content = "redacted-by-litellm" return _result + if result is not None and isinstance(result, litellm.EmbeddingResponse): + _result = copy.deepcopy(result) + if hasattr(_result, "data") and _result.data is not None: + _result.data = [] + return _result else: return {"text": "redacted-by-litellm"} diff --git a/tests/litellm_utils_tests/test_utils.py b/tests/litellm_utils_tests/test_utils.py index c053287ce8..b59b95d3f5 100644 --- a/tests/litellm_utils_tests/test_utils.py +++ b/tests/litellm_utils_tests/test_utils.py @@ -604,6 +604,73 @@ def test_redact_msgs_from_logs(): print("Test passed") +def test_redact_embedding_response(): + """ + Tests that EmbeddingResponse redaction preserves critical metadata while clearing sensitive data + + This test ensures that: + 1. usage field is preserved for token/cost tracking + 2. model field is preserved for response structure integrity + 3. data field (containing embeddings) is cleared for privacy + 4. original response object is not modified + """ + from litellm.litellm_core_utils.litellm_logging import Logging + from litellm.litellm_core_utils.redact_messages import ( + redact_message_input_output_from_logging, + ) + + litellm.turn_off_message_logging = True + + # Create a test EmbeddingResponse with usage data + original_usage = litellm.Usage(prompt_tokens=10, completion_tokens=0, total_tokens=10) + original_data = [ + {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3, 0.4, 0.5]}, + {"object": "embedding", "index": 1, "embedding": [0.6, 0.7, 0.8, 0.9, 1.0]} + ] + + response_obj = litellm.EmbeddingResponse( + model="text-embedding-ada-002", + data=original_data, + usage=original_usage, + object="list" + ) + + litellm_logging_obj = Logging( + model="text-embedding-ada-002", + messages=[{"role": "user", "content": "test input"}], + stream=False, + call_type="embedding", + litellm_call_id="1234", + start_time=datetime.now(), + function_id="1234", + ) + + _redacted_response_obj = redact_message_input_output_from_logging( + result=response_obj, + model_call_details=litellm_logging_obj.model_call_details, + ) + + # Assert the original response_obj is NOT modified + assert response_obj.data == original_data + assert response_obj.usage == original_usage + assert response_obj.model == "text-embedding-ada-002" + assert response_obj.object == "list" + + # Assert the redacted response preserves critical metadata + assert _redacted_response_obj.usage == original_usage # usage should be preserved + assert _redacted_response_obj.model == "text-embedding-ada-002" # model should be preserved + assert _redacted_response_obj.object == "list" # object should be preserved + + # Assert sensitive data is cleared + assert _redacted_response_obj.data == [] # data should be cleared + + # Assert it's still an EmbeddingResponse instance + assert isinstance(_redacted_response_obj, litellm.EmbeddingResponse) + + litellm.turn_off_message_logging = False + print("Test passed") + + def test_redact_msgs_from_logs_with_dynamic_params(): """ Tests redaction behavior based on standard_callback_dynamic_params setting: @@ -1028,7 +1095,7 @@ def test_async_http_handler(mock_async_client): with mock.patch.object(AsyncHTTPHandler, '_create_async_transport') as mock_create_transport: mock_transport = mock.MagicMock() mock_create_transport.return_value = mock_transport - + AsyncHTTPHandler(timeout, event_hooks, concurrent_limit) mock_async_client.assert_called_with( @@ -2108,10 +2175,10 @@ def test_get_provider_audio_transcription_config(): def test_claude_3_7_sonnet_supports_pdf_input(model, expected_bool): from litellm.utils import supports_pdf_input - + assert supports_pdf_input(model) == expected_bool - + def test_get_valid_models_from_provider(): """ Test that get_valid_models returns the correct models for a given provider @@ -2163,5 +2230,3 @@ def test_get_valid_models_from_dynamic_api_key(): valid_models = get_valid_models(custom_llm_provider="anthropic", litellm_params=creds, check_provider_endpoint=True) assert len(valid_models) > 0 assert "anthropic/claude-3-7-sonnet-20250219" in valid_models - - \ No newline at end of file