feat: enhance redaction functionality for EmbeddingResponse (#12088)

This commit is contained in:
Bougou Nisou
2025-06-28 12:30:26 +08:00
committed by GitHub
parent 6578133bb7
commit 58dda44fda
2 changed files with 75 additions and 5 deletions
@@ -69,6 +69,11 @@ def perform_redaction(model_call_details: dict, result):
elif isinstance(choice, litellm.utils.StreamingChoices):
choice.delta.content = "redacted-by-litellm"
return _result
if result is not None and isinstance(result, litellm.EmbeddingResponse):
_result = copy.deepcopy(result)
if hasattr(_result, "data") and _result.data is not None:
_result.data = []
return _result
else:
return {"text": "redacted-by-litellm"}
+70 -5
View File
@@ -604,6 +604,73 @@ def test_redact_msgs_from_logs():
print("Test passed")
def test_redact_embedding_response():
    """
    Tests that EmbeddingResponse redaction preserves critical metadata while clearing sensitive data

    This test ensures that:
    1. usage field is preserved for token/cost tracking
    2. model field is preserved for response structure integrity
    3. data field (containing embeddings) is cleared for privacy
    4. original response object is not modified

    The global ``litellm.turn_off_message_logging`` flag is switched on for the
    duration of the test and is always switched back off in a ``finally``
    clause, so a failing assertion cannot leak the setting into other tests.
    """
    from litellm.litellm_core_utils.litellm_logging import Logging
    from litellm.litellm_core_utils.redact_messages import (
        redact_message_input_output_from_logging,
    )

    # Module-level switch: everything that follows must run inside the
    # try/finally so the flag is reset even when an assertion fails.
    litellm.turn_off_message_logging = True
    try:
        # Create a test EmbeddingResponse with usage data
        original_usage = litellm.Usage(
            prompt_tokens=10, completion_tokens=0, total_tokens=10
        )
        original_data = [
            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3, 0.4, 0.5]},
            {"object": "embedding", "index": 1, "embedding": [0.6, 0.7, 0.8, 0.9, 1.0]},
        ]

        response_obj = litellm.EmbeddingResponse(
            model="text-embedding-ada-002",
            data=original_data,
            usage=original_usage,
            object="list",
        )

        litellm_logging_obj = Logging(
            model="text-embedding-ada-002",
            messages=[{"role": "user", "content": "test input"}],
            stream=False,
            call_type="embedding",
            litellm_call_id="1234",
            start_time=datetime.now(),
            function_id="1234",
        )

        _redacted_response_obj = redact_message_input_output_from_logging(
            result=response_obj,
            model_call_details=litellm_logging_obj.model_call_details,
        )

        # Assert the original response_obj is NOT modified
        assert response_obj.data == original_data
        assert response_obj.usage == original_usage
        assert response_obj.model == "text-embedding-ada-002"
        assert response_obj.object == "list"

        # Assert the redacted response preserves critical metadata
        assert _redacted_response_obj.usage == original_usage  # usage should be preserved
        assert _redacted_response_obj.model == "text-embedding-ada-002"  # model should be preserved
        assert _redacted_response_obj.object == "list"  # object should be preserved

        # Assert sensitive data is cleared
        assert _redacted_response_obj.data == []  # data should be cleared

        # Assert it's still an EmbeddingResponse instance
        assert isinstance(_redacted_response_obj, litellm.EmbeddingResponse)
    finally:
        # Always undo the global toggle, on success and on failure alike.
        litellm.turn_off_message_logging = False

    print("Test passed")
def test_redact_msgs_from_logs_with_dynamic_params():
"""
Tests redaction behavior based on standard_callback_dynamic_params setting:
@@ -1028,7 +1095,7 @@ def test_async_http_handler(mock_async_client):
with mock.patch.object(AsyncHTTPHandler, '_create_async_transport') as mock_create_transport:
mock_transport = mock.MagicMock()
mock_create_transport.return_value = mock_transport
AsyncHTTPHandler(timeout, event_hooks, concurrent_limit)
mock_async_client.assert_called_with(
@@ -2108,10 +2175,10 @@ def test_get_provider_audio_transcription_config():
def test_claude_3_7_sonnet_supports_pdf_input(model, expected_bool):
from litellm.utils import supports_pdf_input
assert supports_pdf_input(model) == expected_bool
def test_get_valid_models_from_provider():
"""
Test that get_valid_models returns the correct models for a given provider
@@ -2163,5 +2230,3 @@ def test_get_valid_models_from_dynamic_api_key():
valid_models = get_valid_models(custom_llm_provider="anthropic", litellm_params=creds, check_provider_endpoint=True)
assert len(valid_models) > 0
assert "anthropic/claude-3-7-sonnet-20250219" in valid_models