mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 22:48:35 +00:00
459 lines
16 KiB
Python
459 lines
16 KiB
Python
import asyncio
|
|
import base64
|
|
import json
|
|
import os
|
|
from copy import deepcopy
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from urllib.parse import unquote
|
|
|
|
import litellm
|
|
import pytest
|
|
|
|
from litellm.integrations.sqs import SQSLogger
|
|
from litellm.types.utils import StandardLoggingPayload
|
|
|
|
from litellm.litellm_core_utils.app_crypto import AppCrypto
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_async_sqs_logger_flush():
    """Check that a mocked completion is POSTed to the configured SQS queue.

    The logger's HTTP ``post`` is replaced by an ``AsyncMock``; the test then
    inspects the recorded call: the first positional argument must be the
    queue URL and the ``data`` keyword must be a form-encoded ``SendMessage``
    request whose ``MessageBody`` decodes to JSON carrying the model and the
    original user message.
    """
    expected_queue_url = "https://sqs.us-east-1.amazonaws.com/123456789012/test-queue"
    expected_region = "us-east-1"

    sqs_logger = SQSLogger(
        sqs_queue_url=expected_queue_url,
        sqs_region_name=expected_region,
        sqs_flush_interval=1,
    )

    # Swap the HTTP POST for a mock so the call can be inspected afterwards.
    mock_response = MagicMock()
    mock_response.raise_for_status = MagicMock()
    sqs_logger.async_httpx_client.post = AsyncMock(return_value=mock_response)

    # litellm.callbacks is module-global state: remember the previous value
    # and restore it so this logger does not leak into later tests.
    previous_callbacks = litellm.callbacks
    litellm.callbacks = [sqs_logger]
    try:
        await litellm.acompletion(
            model="gpt-4o",
            messages=[{"role": "user", "content": "hello"}],
            mock_response="hi",
        )

        # Wait longer than sqs_flush_interval (1); presumably this lets the
        # logger's background flush run -- the assertion below verifies it.
        await asyncio.sleep(2)
    finally:
        litellm.callbacks = previous_callbacks

    # The mocked POST must have been invoked at least once.
    sqs_logger.async_httpx_client.post.assert_called()
    call_args = sqs_logger.async_httpx_client.post.call_args

    # First positional argument is the target URL.
    called_url = call_args.args[0]
    assert called_url == expected_queue_url, f"Expected URL {expected_queue_url}, got {called_url}"

    # Form body: "Action=SendMessage&Version=2012-11-05&MessageBody=<url_encoded_json>"
    called_data = call_args.kwargs['data']
    assert "Action=SendMessage" in called_data
    assert "Version=2012-11-05" in called_data
    assert "MessageBody=" in called_data

    # Everything after the first "MessageBody=" is the URL-encoded JSON body.
    message_body_encoded = called_data.partition("MessageBody=")[2]
    payload_data = json.loads(unquote(message_body_encoded))

    # The decoded body must carry the expected logging fields.
    assert "model" in payload_data
    assert "messages" in payload_data
    assert "response" in payload_data
    assert payload_data["model"] == "gpt-4o"
    assert len(payload_data["messages"]) == 1
    assert payload_data["messages"][0]["role"] == "user"
    assert payload_data["messages"][0]["content"] == "hello"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_async_sqs_logger_error_flush():
    """Check the SQS POST is still issued when the response reports an error.

    The mocked response's ``raise_for_status`` raises; the test asserts that
    the logger nevertheless called ``post`` with the queue URL and a
    form-encoded ``SendMessage`` body whose ``MessageBody`` decodes to JSON
    carrying the model and the original user message.
    """
    expected_queue_url = "https://sqs.us-east-1.amazonaws.com/123456789012/test-queue"
    expected_region = "us-east-1"

    sqs_logger = SQSLogger(
        sqs_queue_url=expected_queue_url,
        sqs_region_name=expected_region,
        sqs_flush_interval=1,
    )

    # raise_for_status has to stay callable: side_effect makes the call raise
    # the intended exception. Assigning the Exception instance directly would
    # instead fail with TypeError ("'Exception' object is not callable").
    mock_response = MagicMock()
    mock_response.raise_for_status = MagicMock(
        side_effect=Exception("Something went wrong")
    )
    sqs_logger.async_httpx_client.post = AsyncMock(return_value=mock_response)

    # litellm.callbacks is module-global state: remember the previous value
    # and restore it so this logger does not leak into later tests.
    previous_callbacks = litellm.callbacks
    litellm.callbacks = [sqs_logger]
    try:
        await litellm.acompletion(
            model="gpt-4o",
            messages=[{"role": "user", "content": "hello"}],
            mock_response="Error occurred",
        )

        # Wait longer than sqs_flush_interval (1); presumably this lets the
        # logger's background flush run -- the assertion below verifies it.
        await asyncio.sleep(2)
    finally:
        litellm.callbacks = previous_callbacks

    # The mocked POST must have been invoked at least once.
    sqs_logger.async_httpx_client.post.assert_called()
    call_args = sqs_logger.async_httpx_client.post.call_args

    # First positional argument is the target URL.
    called_url = call_args.args[0]
    assert called_url == expected_queue_url, f"Expected URL {expected_queue_url}, got {called_url}"

    # Form body: "Action=SendMessage&Version=2012-11-05&MessageBody=<url_encoded_json>"
    called_data = call_args.kwargs['data']
    assert "Action=SendMessage" in called_data
    assert "Version=2012-11-05" in called_data
    assert "MessageBody=" in called_data

    # Everything after the first "MessageBody=" is the URL-encoded JSON body.
    message_body_encoded = called_data.partition("MessageBody=")[2]
    payload_data = json.loads(unquote(message_body_encoded))

    # The decoded body must carry the expected logging fields.
    assert "model" in payload_data
    assert "messages" in payload_data
    assert "response" in payload_data
    assert payload_data["model"] == "gpt-4o"
    assert len(payload_data["messages"]) == 1
    assert payload_data["messages"][0]["role"] == "user"
    assert payload_data["messages"][0]["content"] == "hello"
|
|
|
|
|
|
|
|
# =============================================================================
|
|
# 📥 Logging Queue Tests
|
|
# =============================================================================
|
|
|
|
@pytest.mark.asyncio
async def test_async_log_success_event_adds_to_queue(monkeypatch):
    """A success event's ``standard_logging_object`` is appended to ``log_queue``.

    NOTE(review): a function with this exact name is defined again later in
    this module; the later definition rebinds the name, so this version is
    presumably never collected by pytest -- confirm, then rename or remove.
    """
    monkeypatch.setattr("litellm.aws_sqs_callback_params", {})
    sqs = SQSLogger(
        sqs_queue_url="https://example.com",
        sqs_region_name="us-west-2",
    )

    queued_item = {"some": "data"}
    callback_kwargs = {"standard_logging_object": queued_item}
    await sqs.async_log_success_event(callback_kwargs, None, None, None)

    assert queued_item in sqs.log_queue
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_async_log_failure_event_adds_to_queue(monkeypatch):
    """A failure event's ``standard_logging_object`` is appended to ``log_queue``.

    NOTE(review): a function with this exact name is defined again later in
    this module; the later definition rebinds the name, so this version is
    presumably never collected by pytest -- confirm, then rename or remove.
    """
    monkeypatch.setattr("litellm.aws_sqs_callback_params", {})
    sqs = SQSLogger(
        sqs_queue_url="https://example.com",
        sqs_region_name="us-west-2",
    )

    queued_item = {"fail": True}
    callback_kwargs = {"standard_logging_object": queued_item}
    await sqs.async_log_failure_event(callback_kwargs, None, None, None)

    assert queued_item in sqs.log_queue
|
|
|
|
|
|
|
|
# =============================================================================
|
|
# 🧾 async_send_batch Tests
|
|
# =============================================================================
|
|
|
|
@pytest.mark.asyncio
async def test_async_send_batch_triggers_tasks(monkeypatch):
    """``async_send_batch`` returns without having awaited ``async_send_message``.

    NOTE(review): a function with this exact name is defined again later in
    this module; the later definition rebinds the name, so this version is
    presumably never collected by pytest -- confirm, then rename or remove.
    """
    monkeypatch.setattr("litellm.aws_sqs_callback_params", {})
    sqs = SQSLogger(
        sqs_queue_url="https://example.com",
        sqs_region_name="us-west-2",
    )
    send_spy = AsyncMock()
    sqs.async_send_message = send_spy
    sqs.log_queue = [{"log": 1}, {"log": 2}]

    await sqs.async_send_batch()

    # Per the original author: sends are scheduled via create_task, so no
    # direct await of the mock is expected by the time the batch call returns.
    send_spy.assert_not_awaited()
|
|
|
|
|
|
|
|
# =============================================================================
|
|
# 🔐 AppCrypto Tests
|
|
# =============================================================================
|
|
|
|
def test_appcrypto_encrypt_decrypt_roundtrip():
    """Encrypting then decrypting JSON with the same AAD returns the input."""
    cipher = AppCrypto(os.urandom(32))  # random 32-byte key
    original = {"event": "test", "value": 42}
    associated_data = b"context"

    sealed = cipher.encrypt_json(original, aad=associated_data)

    assert cipher.decrypt_json(sealed, aad=associated_data) == original
|
|
|
|
|
|
def test_appcrypto_invalid_key_length():
    """Constructing ``AppCrypto`` with a 5-byte key raises ``ValueError``."""
    too_short_key = b"short"
    # The error message is expected to mention the required "32 bytes".
    with pytest.raises(ValueError, match="32 bytes"):
        AppCrypto(too_short_key)
|
|
|
|
|
|
# =============================================================================
|
|
# 🪣 SQSLogger Initialization Tests
|
|
# =============================================================================
|
|
|
|
def test_sqs_logger_init_without_encryption(monkeypatch):
    """With default options the logger keeps the queue URL and has no crypto."""
    monkeypatch.setattr("litellm.aws_sqs_callback_params", {})
    # This is a sync test; per the original author, stubbing create_task
    # avoids a RuntimeError during construction.
    monkeypatch.setattr(asyncio, "create_task", MagicMock())

    queue_url = "https://example.com"
    sqs = SQSLogger(sqs_queue_url=queue_url, sqs_region_name="us-west-2")

    assert sqs.sqs_queue_url == queue_url
    assert sqs.app_crypto is None
|
|
|
|
|
|
def test_sqs_logger_init_with_encryption(monkeypatch):
    """Enabling app-level encryption with a key populates ``app_crypto``."""
    monkeypatch.setattr("litellm.aws_sqs_callback_params", {})
    monkeypatch.setattr(asyncio, "create_task", MagicMock())

    # Base64 text form of a random 32-byte key.
    raw_key = os.urandom(32)
    encoded_key = base64.b64encode(raw_key).decode()
    aad_text = "tenant=bill"

    sqs = SQSLogger(
        sqs_queue_url="https://example.com",
        sqs_region_name="us-west-2",
        sqs_aws_use_application_level_encryption=True,
        sqs_app_encryption_key_b64=encoded_key,
        sqs_app_encryption_aad=aad_text,
    )

    assert sqs.app_crypto is not None
    assert sqs.sqs_app_encryption_aad == aad_text
|
|
|
|
|
|
def test_sqs_logger_init_with_encryption_missing_key(monkeypatch):
    """Enabling encryption without supplying a key raises ``ValueError``."""
    monkeypatch.setattr("litellm.aws_sqs_callback_params", {})
    monkeypatch.setattr(asyncio, "create_task", MagicMock())

    # Encryption switched on, but no sqs_app_encryption_key_b64 provided.
    init_kwargs = {
        "sqs_queue_url": "https://example.com",
        "sqs_region_name": "us-west-2",
        "sqs_aws_use_application_level_encryption": True,
    }

    with pytest.raises(ValueError, match="required when encryption is enabled"):
        SQSLogger(**init_kwargs)
|
|
|
|
|
|
# =============================================================================
|
|
# 📥 Logging Queue Tests
|
|
# =============================================================================
|
|
|
|
@pytest.mark.asyncio
async def test_async_log_success_event_adds_to_queue(monkeypatch):
    """A success event's ``standard_logging_object`` is appended to ``log_queue``."""
    monkeypatch.setattr("litellm.aws_sqs_callback_params", {})
    # create_task is stubbed so nothing gets scheduled on the event loop.
    monkeypatch.setattr(asyncio, "create_task", MagicMock())
    sqs = SQSLogger(
        sqs_queue_url="https://example.com",
        sqs_region_name="us-west-2",
    )

    queued_item = {"some": "data"}
    callback_kwargs = {"standard_logging_object": queued_item}
    await sqs.async_log_success_event(callback_kwargs, None, None, None)

    assert queued_item in sqs.log_queue
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_async_log_failure_event_adds_to_queue(monkeypatch):
    """A failure event's ``standard_logging_object`` is appended to ``log_queue``."""
    monkeypatch.setattr("litellm.aws_sqs_callback_params", {})
    # create_task is stubbed so nothing gets scheduled on the event loop.
    monkeypatch.setattr(asyncio, "create_task", MagicMock())
    sqs = SQSLogger(
        sqs_queue_url="https://example.com",
        sqs_region_name="us-west-2",
    )

    queued_item = {"fail": True}
    callback_kwargs = {"standard_logging_object": queued_item}
    await sqs.async_log_failure_event(callback_kwargs, None, None, None)

    assert queued_item in sqs.log_queue
|
|
|
|
|
|
# =============================================================================
|
|
# 🧾 async_send_batch Tests
|
|
# =============================================================================
|
|
|
|
@pytest.mark.asyncio
async def test_async_send_batch_triggers_tasks(monkeypatch):
    """``asyncio.create_task`` has been called once a batch has been sent.

    NOTE(review): the spy is installed before the logger is constructed, so
    a create_task call made by the constructor would also satisfy the final
    assertion -- confirm whether the mock should be reset before the batch.
    """
    monkeypatch.setattr("litellm.aws_sqs_callback_params", {})
    task_spy = MagicMock()
    monkeypatch.setattr(asyncio, "create_task", task_spy)

    sqs = SQSLogger(
        sqs_queue_url="https://example.com",
        sqs_region_name="us-west-2",
    )
    sqs.async_send_message = AsyncMock()
    sqs.log_queue = [{"log": 1}, {"log": 2}]

    await sqs.async_send_batch()

    # Per the original author, sends go through asyncio.create_task(), so
    # the spy (not a direct await of async_send_message) is what is checked.
    task_spy.assert_called()
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_strip_base64_removes_file_and_nontext_entries():
    """Only ``text`` parts survive stripping; image/file/audio parts are dropped."""
    sqs = SQSLogger(sqs_strip_base64_files=True)

    user_message = {
        "role": "user",
        "content": [
            {"type": "text", "text": "Hello world"},
            {"type": "image", "file": {"file_data": "data:image/png;base64,AAAA"}},
            {"type": "file", "file": {"file_data": "data:application/pdf;base64,BBBB"}},
        ],
    }
    assistant_message = {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "Response"},
            {"type": "audio", "file": {"file_data": "data:audio/wav;base64,CCCC"}},
        ],
    }
    payload = {"messages": [user_message, assistant_message]}

    stripped = await sqs._strip_base64_from_messages(payload)

    # Each message keeps exactly one part: its original text entry.
    for position, expected_text in enumerate(("Hello world", "Response")):
        remaining = stripped["messages"][position]["content"]
        assert len(remaining) == 1
        assert remaining[0]["text"] == expected_text

    # No surviving part carries a 'file' key, and every part is typed 'text'.
    for message in stripped["messages"]:
        for part in message["content"]:
            assert "file" not in part
            assert part.get("type") == "text"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_strip_base64_keeps_non_file_content():
    """Text-only message content must come back unchanged from stripping."""
    logger = SQSLogger(sqs_strip_base64_files=True)

    payload = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Just text"},
                    {"type": "text", "text": "Another message"},
                ],
            }
        ]
    }

    # Snapshot the content before the call. If the helper mutates ``payload``
    # and returns that same object, comparing the result against ``payload``
    # would compare an object with itself and could never fail.
    expected_content = deepcopy(payload["messages"][0]["content"])

    stripped = await logger._strip_base64_from_messages(payload)

    # Should not modify normal text messages
    assert stripped["messages"][0]["content"] == expected_content
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_strip_base64_handles_empty_or_missing_messages():
    """Payloads with no ``messages`` key or an empty list compare equal after stripping."""
    sqs = SQSLogger(sqs_strip_base64_files=True)

    # First a payload with no "messages" key, then one with an empty list.
    for candidate in ({}, {"messages": []}):
        result = await sqs._strip_base64_from_messages(candidate)
        assert result == candidate
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_strip_base64_mixed_nested_objects():
    """
    Irregular content parts are filtered without disturbing other message keys.
    """
    sqs = SQSLogger(sqs_strip_base64_files=True)

    text_part = {"type": "text", "text": "Keep me"}
    untyped_part = {"foo": "bar"}
    system_message = {
        "role": "system",
        "content": [
            text_part,
            {"type": "custom", "metadata": "ignore but non-text"},
            untyped_part,
            {"file": {"file_data": "data:application/pdf;base64,XXX"}},
        ],
        "extra": {"trace_id": "123"},
    }
    payload = {"messages": [system_message]}

    stripped = await sqs._strip_base64_from_messages(payload)
    first_message = stripped["messages"][0]

    # The 'custom'-typed part and the 'file' part are expected to be gone,
    # leaving the text part and the untyped {"foo": "bar"} part.
    kept = first_message["content"]
    assert len(kept) == 2
    assert {"type": "text", "text": "Keep me"} in kept
    assert {"foo": "bar"} in kept

    # Sibling keys on the message are untouched.
    assert first_message["extra"]["trace_id"] == "123"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_strip_base64_recursive_redaction():
    """No ``file`` part and no ``base64,`` data URI remains after stripping."""
    sqs = SQSLogger(sqs_strip_base64_files=True)

    parts = [
        {"type": "text", "text": "normal text"},
        {"type": "text", "text": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg"},
        {"type": "text", "text": "Nested: {'data': 'data:application/pdf;base64,AAA...'}"},
        {"file": {"file_data": "data:application/pdf;base64,AAAA"}},
        {"metadata": {"preview": "data:audio/mp3;base64,AAAAA=="}},
    ]
    payload = {"messages": [{"content": parts}]}

    result = await sqs._strip_base64_from_messages(payload)
    content = result["messages"][0]["content"]

    # The part keyed by 'file' must have been dropped.
    assert all("file" not in entry for entry in content)

    # No dict part may still contain a "base64," marker once serialized
    # (a "[base64_redacted]"-style placeholder does not contain it).
    for entry in content:
        if not isinstance(entry, dict):
            continue
        s = json.dumps(entry).lower()
        assert "base64," not in s, f"Found real base64 blob in: {s}"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_async_health_check_healthy(monkeypatch):
    """The health check reports ``healthy`` when sending a message succeeds."""
    monkeypatch.setattr("litellm.aws_sqs_callback_params", {})
    monkeypatch.setattr(asyncio, "create_task", MagicMock())

    sqs = SQSLogger(
        sqs_queue_url="https://example.com",
        sqs_region_name="us-west-2",
    )
    # A send that completes normally and returns None.
    sqs.async_send_message = AsyncMock(return_value=None)

    report = await sqs.async_health_check()

    assert report["status"] == "healthy"
    assert report.get("error_message") is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_async_health_check_unhealthy(monkeypatch):
    """The health check reports ``unhealthy`` and the error text when sending fails."""
    monkeypatch.setattr("litellm.aws_sqs_callback_params", {})
    monkeypatch.setattr(asyncio, "create_task", MagicMock())

    sqs = SQSLogger(
        sqs_queue_url="https://example.com",
        sqs_region_name="us-west-2",
    )
    # Every send attempt raises.
    sqs.async_send_message = AsyncMock(side_effect=Exception("boom"))

    report = await sqs.async_health_check()

    assert report["status"] == "unhealthy"
    # error_message may be missing or None; treat both as an empty string.
    assert "boom" in (report.get("error_message") or "")
|