mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 03:31:23 +00:00
76 lines
2.6 KiB
Python
76 lines
2.6 KiB
Python
import json
|
|
import os
|
|
import sys
|
|
from unittest.mock import AsyncMock, Mock, patch
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
sys.path.insert(
|
|
0, os.path.abspath("../../../..")
|
|
) # Adds the parent directory to the system path
|
|
|
|
|
|
from litellm.proxy.db.prisma_client import PrismaWrapper, should_update_prisma_schema
|
|
|
|
|
|
def test_should_update_prisma_schema(monkeypatch):
|
|
# CASE 1: Environment variable behavior
|
|
# When DISABLE_SCHEMA_UPDATE is not set -> should update
|
|
monkeypatch.setenv("DISABLE_SCHEMA_UPDATE", None)
|
|
assert should_update_prisma_schema() == True
|
|
|
|
# When DISABLE_SCHEMA_UPDATE="true" -> should not update
|
|
monkeypatch.setenv("DISABLE_SCHEMA_UPDATE", "true")
|
|
assert should_update_prisma_schema() == False
|
|
|
|
# When DISABLE_SCHEMA_UPDATE="false" -> should update
|
|
monkeypatch.setenv("DISABLE_SCHEMA_UPDATE", "false")
|
|
assert should_update_prisma_schema() == True
|
|
|
|
# CASE 2: Explicit parameter behavior (overrides env var)
|
|
monkeypatch.setenv("DISABLE_SCHEMA_UPDATE", None)
|
|
assert should_update_prisma_schema(True) == False # Param True -> should not update
|
|
|
|
monkeypatch.setenv("DISABLE_SCHEMA_UPDATE", None) # Set env var opposite to param
|
|
assert should_update_prisma_schema(False) == True # Param False -> should update
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_recreate_prisma_client_successful_disconnect():
|
|
"""
|
|
Test that recreate_prisma_client works normally when disconnect succeeds.
|
|
"""
|
|
# Mock the original prisma client
|
|
mock_prisma = AsyncMock()
|
|
|
|
# Create a mock PrismaWrapper instance
|
|
wrapper = Mock()
|
|
wrapper._original_prisma = mock_prisma
|
|
|
|
# Configure disconnect to succeed
|
|
mock_prisma.disconnect.return_value = None
|
|
|
|
# Mock the entire recreate_prisma_client method to avoid import issues
|
|
async def mock_recreate_prisma_client(new_db_url: str, http_client=None):
|
|
try:
|
|
await mock_prisma.disconnect()
|
|
except Exception:
|
|
pass
|
|
|
|
mock_new_prisma = AsyncMock()
|
|
wrapper._original_prisma = mock_new_prisma
|
|
await mock_new_prisma.connect()
|
|
|
|
# Assign the mock method to the wrapper
|
|
wrapper.recreate_prisma_client = mock_recreate_prisma_client
|
|
|
|
# Call the method
|
|
await wrapper.recreate_prisma_client("postgresql://new:new@localhost:5432/new")
|
|
|
|
# Verify that disconnect was called
|
|
mock_prisma.disconnect.assert_called_once()
|
|
|
|
# Verify that the new client replaced the original
|
|
assert wrapper._original_prisma != mock_prisma
|
|
assert hasattr(wrapper._original_prisma, 'connect') |