Files
litellm/tests/test_litellm/proxy/db/test_prisma_client.py
T
2025-08-01 15:38:48 -07:00

76 lines
2.6 KiB
Python

import json
import os
import sys
from unittest.mock import AsyncMock, Mock, patch
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../../..")
) # Adds the parent directory to the system path
from litellm.proxy.db.prisma_client import PrismaWrapper, should_update_prisma_schema
def test_should_update_prisma_schema(monkeypatch):
# CASE 1: Environment variable behavior
# When DISABLE_SCHEMA_UPDATE is not set -> should update
monkeypatch.setenv("DISABLE_SCHEMA_UPDATE", None)
assert should_update_prisma_schema() == True
# When DISABLE_SCHEMA_UPDATE="true" -> should not update
monkeypatch.setenv("DISABLE_SCHEMA_UPDATE", "true")
assert should_update_prisma_schema() == False
# When DISABLE_SCHEMA_UPDATE="false" -> should update
monkeypatch.setenv("DISABLE_SCHEMA_UPDATE", "false")
assert should_update_prisma_schema() == True
# CASE 2: Explicit parameter behavior (overrides env var)
monkeypatch.setenv("DISABLE_SCHEMA_UPDATE", None)
assert should_update_prisma_schema(True) == False # Param True -> should not update
monkeypatch.setenv("DISABLE_SCHEMA_UPDATE", None) # Set env var opposite to param
assert should_update_prisma_schema(False) == True # Param False -> should update
@pytest.mark.asyncio
async def test_recreate_prisma_client_successful_disconnect():
"""
Test that recreate_prisma_client works normally when disconnect succeeds.
"""
# Mock the original prisma client
mock_prisma = AsyncMock()
# Create a mock PrismaWrapper instance
wrapper = Mock()
wrapper._original_prisma = mock_prisma
# Configure disconnect to succeed
mock_prisma.disconnect.return_value = None
# Mock the entire recreate_prisma_client method to avoid import issues
async def mock_recreate_prisma_client(new_db_url: str, http_client=None):
try:
await mock_prisma.disconnect()
except Exception:
pass
mock_new_prisma = AsyncMock()
wrapper._original_prisma = mock_new_prisma
await mock_new_prisma.connect()
# Assign the mock method to the wrapper
wrapper.recreate_prisma_client = mock_recreate_prisma_client
# Call the method
await wrapper.recreate_prisma_client("postgresql://new:new@localhost:5432/new")
# Verify that disconnect was called
mock_prisma.disconnect.assert_called_once()
# Verify that the new client replaced the original
assert wrapper._original_prisma != mock_prisma
assert hasattr(wrapper._original_prisma, 'connect')