mirror of
https://github.com/tiennm99/litellm.git
synced 2026-07-04 17:08:48 +00:00
0f449bf038
* Fixes issue with team_endpoints on member budget update * refactored location of budget membership fix * added test for _upsert_budget_membership func
133 lines
4.2 KiB
Python
133 lines
4.2 KiB
Python
# tests/test_budget_endpoints.py
|
|
|
|
import os
|
|
import sys
|
|
import types
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
from fastapi.testclient import TestClient
|
|
|
|
import litellm.proxy.proxy_server as ps
|
|
from litellm.proxy.proxy_server import app
|
|
from litellm.proxy._types import UserAPIKeyAuth, LitellmUserRoles, CommonProxyErrors
|
|
|
|
import litellm.proxy.management_endpoints.budget_management_endpoints as bm
|
|
|
|
# Put the directory three levels above the current working directory at the
# front of the import search path. The relative path is resolved against the
# working directory, not against this file's location.
sys.path.insert(0, os.path.abspath("../../../"))
|
|
|
|
|
|
@pytest.fixture
def client_and_mocks(monkeypatch):
    """Provide a TestClient backed by a mocked Prisma client and a fake user.

    Both patches (the ``prisma_client`` attribute on the proxy server module
    and the auth dependency override on the app) are registered through
    ``monkeypatch``, so they are undone automatically when the test finishes
    and no explicit teardown is needed.

    Returns:
        tuple: ``(client, mock_prisma, mock_table)``. ``mock_table`` backs
        both ``db.litellm_budgettable`` and ``db.litellm_dailyspend``.
    """
    # A single table mock: ``create`` echoes the ``data`` it was given and
    # ``update`` returns ``where`` merged with ``data``, so tests can assert
    # on the values that were passed to the database layer.
    mock_table = MagicMock()
    mock_table.create = AsyncMock(side_effect=lambda *, data: data)
    mock_table.update = AsyncMock(
        side_effect=lambda *, where, data: {**where, **data}
    )

    mock_prisma = MagicMock()
    mock_prisma.db = types.SimpleNamespace(
        litellm_budgettable=mock_table,
        litellm_dailyspend=mock_table,
    )

    # Swap the mocked Prisma client into the server module.
    monkeypatch.setattr(ps, "prisma_client", mock_prisma)

    # Override the auth dependency so every request runs as this user.
    # setitem (rather than a manual assignment plus ``clear()``) removes only
    # this override on teardown and leaves any others untouched.
    fake_user = UserAPIKeyAuth(
        user_id="test_user",
        user_role=LitellmUserRoles.INTERNAL_USER,
    )
    monkeypatch.setitem(
        app.dependency_overrides, ps.user_api_key_auth, lambda: fake_user
    )

    return TestClient(app), mock_prisma, mock_table
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_new_budget_success(client_and_mocks):
    """POST /budget/new returns the submitted fields plus the caller as creator/updater."""
    http, _, budget_table = client_and_mocks

    request_body = {
        "budget_id": "budget_123",
        "max_budget": 42.0,
        "budget_duration": "30d",
    }
    response = http.post("/budget/new", json=request_body)
    assert response.status_code == 200, response.text

    created = response.json()
    # Every submitted field must come back unchanged.
    for field in ("budget_id", "max_budget", "budget_duration"):
        assert created[field] == request_body[field]
    # Audit fields are expected to carry the "test_user" id.
    for audit_field in ("created_by", "updated_by"):
        assert created[audit_field] == "test_user"

    budget_table.create.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_new_budget_db_not_connected(client_and_mocks, monkeypatch):
    """POST /budget/new responds 500 with the db-not-connected error when prisma_client is None."""
    client, _, _ = client_and_mocks

    # Replace the fixture's mock with None on the proxy server module (the
    # module-level ``ps`` import); monkeypatch restores it after the test.
    monkeypatch.setattr(ps, "prisma_client", None)

    resp = client.post(
        "/budget/new", json={"budget_id": "no_db", "max_budget": 1.0}
    )
    assert resp.status_code == 500, resp.text

    detail = resp.json()["detail"]
    assert detail["error"] == CommonProxyErrors.db_not_connected_error.value
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_update_budget_success(client_and_mocks, monkeypatch):
    """POST /budget/update returns the updated fields and records the caller as updater."""
    http, _prisma, _table = client_and_mocks

    request_body = {
        "budget_id": "budget_456",
        "max_budget": 99.0,
        "soft_budget": 50.0,
    }
    response = http.post("/budget/update", json=request_body)
    assert response.status_code == 200, response.text

    updated = response.json()
    # Each field sent in the request must be reflected in the response.
    for field in ("budget_id", "max_budget", "soft_budget"):
        assert updated[field] == request_body[field]
    assert updated["updated_by"] == "test_user"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_update_budget_missing_id(client_and_mocks, monkeypatch):
    """POST /budget/update without a budget_id is rejected with a 400 error."""
    http, _prisma, _table = client_and_mocks

    # The request body deliberately omits "budget_id".
    response = http.post("/budget/update", json={"max_budget": 10.0})
    assert response.status_code == 400, response.text

    error_detail = response.json()["detail"]
    assert error_detail["error"] == "budget_id is required"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_update_budget_db_not_connected(client_and_mocks, monkeypatch):
    """POST /budget/update responds 500 with the db-not-connected error when prisma_client is None."""
    client, _, _ = client_and_mocks

    # Replace the fixture's mock with None on the proxy server module (the
    # module-level ``ps`` import); monkeypatch restores it after the test.
    monkeypatch.setattr(ps, "prisma_client", None)

    payload = {"budget_id": "any", "max_budget": 1.0}
    resp = client.post("/budget/update", json=payload)
    assert resp.status_code == 500, resp.text

    detail = resp.json()["detail"]
    assert detail["error"] == CommonProxyErrors.db_not_connected_error.value
|