feat(litellm): add models and repository layers (#29686)

This commit is contained in:
Yassin Kortam
2026-06-06 20:59:33 -07:00
committed by GitHub
parent 118176f21a
commit 5e2db7eee4
124 changed files with 7846 additions and 1850 deletions
+2
View File
@@ -28,6 +28,8 @@ jobs:
tests/test_litellm/completion_extras
tests/test_litellm/containers
tests/test_litellm/experimental_mcp_client
tests/test_litellm/models
tests/test_litellm/repositories
tests/test_litellm/images
tests/test_litellm/interactions
tests/test_litellm/passthrough
+18
View File
@@ -240,6 +240,24 @@ graph LR
7. `DBSpendUpdateWriter.update_database()` queues spend increments to Redis
8. Background job `update_spend` flushes queued spend to PostgreSQL every 60s
### Data Access Layer (Models & Repositories)
Database entities and the operations on them live in two packages at the root of `litellm/` so both the gateway (`proxy/`) and the SDK can use them without importing proxy internals:
- `litellm/models/` holds the canonical Pydantic definitions for every persisted entity (`LiteLLM_VerificationToken`, `LiteLLM_TeamTable`, `LiteLLM_UserTable`, etc.). `proxy/_types.py` re-exports these for backwards compatibility, so existing imports keep working.
- `litellm/repositories/` holds the data-access layer. `BaseRepository[T]` provides the generic CRUD (`find_by_id`, `find_many`, `create`, `update`, `delete`, `count`, `exists`); entity repositories such as `VerificationTokenRepository`, `TeamRepository`, and `UserRepository` add domain-specific queries and writes on top of it.
Conventions to follow when touching this layer:
| Concern | How it's handled |
|---------|------------------|
| JSON columns | Prisma `Json` columns are stored as JSON strings. Repositories `json.dumps()` on write and `json.loads()` on read (see `_to_model` and the `_build_*_data` helpers). |
| Archive-then-delete | `delete_team` / `delete_token` copy the row into the `LiteLLM_Deleted*` table and delete the original inside a single `prisma_client.db.tx()` transaction. Archive payloads are built explicitly so only columns that exist on the archive table are written. |
| Column vs. field names | Where a model field differs from its DB column (for example `org_id` maps to the `organization_id` column), the repository translates in both directions rather than relying on Pydantic to guess. |
| Array mutations | Adds use Prisma's atomic `push` (`add_member`, `add_admin`, `add_models`) to avoid read-modify-write races. Removals fall back to read-modify-write because Prisma has no atomic array remove. |
To add a new entity, define the model under `litellm/models/`, re-export it from `proxy/_types.py` if existing code imports it from there, and add a repository under `litellm/repositories/` (subclass `BaseRepository` for plain CRUD, or add bespoke methods when the entity needs encryption, archiving, or atomic array updates). Mirror the tests in `tests/test_litellm/repositories/`.
---
## 2. SDK Request Flow
@@ -37,6 +37,8 @@ from litellm.proxy._types import (
VirtualKeyEvent,
WebhookEvent,
)
from litellm.repositories.team_repository import TeamRepository
from litellm.repositories.user_repository import UserRepository
from litellm.types.integrations.slack_alerting import *
from ..email_templates.templates import *
@@ -1231,7 +1233,7 @@ Model Info:
and recipient_user_id is not None
and prisma_client is not None
):
user_row = await prisma_client.db.litellm_usertable.find_unique(
user_row = await UserRepository(prisma_client).table.find_unique(
where={"user_id": recipient_user_id}
)
@@ -1263,7 +1265,7 @@ Model Info:
team_id = webhook_event.team_id
team_name = "Default Team"
if team_id is not None and prisma_client is not None:
team_row = await prisma_client.db.litellm_teamtable.find_unique(
team_row = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": team_id}
)
if team_row is not None:
+2 -1
View File
@@ -7,6 +7,7 @@ from typing import List, Optional
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.proxy._types import WebhookEvent
from litellm.repositories.team_repository import TeamRepository
# we use this for the email header, please send a test email if you change this. verify it looks good on email
LITELLM_LOGO_URL = "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
@@ -24,7 +25,7 @@ async def get_all_team_member_emails(team_id: Optional[str] = None) -> list:
if prisma_client is None:
raise Exception("Not connected to DB!")
team_row = await prisma_client.db.litellm_teamtable.find_unique(
team_row = await TeamRepository(prisma_client).table.find_unique(
where={
"team_id": team_id,
}
+12 -9
View File
@@ -29,13 +29,13 @@ from litellm.exceptions import (
validate_rate_limit_type,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.prometheus_helpers.bounded_prometheus_series_tracker import (
BoundedPrometheusSeriesTracker,
)
from litellm.integrations.prometheus_helpers import (
PrometheusLabelFactoryContext,
_get_cached_end_user_id_for_cost_tracking,
)
from litellm.integrations.prometheus_helpers.bounded_prometheus_series_tracker import (
BoundedPrometheusSeriesTracker,
)
from litellm.litellm_core_utils.core_helpers import (
get_litellm_metadata_from_kwargs,
get_metadata_variable_name_from_kwargs,
@@ -46,6 +46,9 @@ from litellm.proxy._types import (
LiteLLM_UserTable,
UserAPIKeyAuth,
)
from litellm.repositories.organization_repository import OrganizationRepository
from litellm.repositories.team_repository import TeamRepository
from litellm.repositories.user_repository import UserRepository
from litellm.types.integrations.prometheus import *
from litellm.types.integrations.prometheus import (
_sanitize_prometheus_label_name,
@@ -3278,12 +3281,12 @@ class PrometheusLogger(CustomLogger):
page_size: int, page: int
) -> Tuple[List[LiteLLM_UserTable], Optional[int]]:
skip = (page - 1) * page_size
users = await prisma_client.db.litellm_usertable.find_many(
users = await UserRepository(prisma_client).table.find_many(
skip=skip,
take=page_size,
order={"created_at": "desc"},
)
total_count = await prisma_client.db.litellm_usertable.count()
total_count = await UserRepository(prisma_client).table.count()
return users, total_count
await self._initialize_budget_metrics(
@@ -3306,13 +3309,13 @@ class PrometheusLogger(CustomLogger):
async def fetch_orgs(page_size: int, page: int) -> Tuple[list, Optional[int]]:
skip = (page - 1) * page_size
orgs = await prisma_client.db.litellm_organizationtable.find_many(
orgs = await OrganizationRepository(prisma_client).table.find_many(
skip=skip,
take=page_size,
order={"created_at": "desc"},
include={"litellm_budget_table": True},
)
total_count = await prisma_client.db.litellm_organizationtable.count()
total_count = await OrganizationRepository(prisma_client).table.count()
return orgs, total_count
await self._initialize_budget_metrics(
@@ -3380,14 +3383,14 @@ class PrometheusLogger(CustomLogger):
try:
# Get total user count
total_users = await prisma_client.db.litellm_usertable.count()
total_users = await UserRepository(prisma_client).table.count()
self.litellm_total_users_metric.set(total_users)
verbose_logger.debug(
f"Prometheus: set litellm_total_users to {total_users}"
)
# Get total team count
total_teams = await prisma_client.db.litellm_teamtable.count()
total_teams = await TeamRepository(prisma_client).table.count()
self.litellm_teams_count_metric.set(total_teams)
verbose_logger.debug(
f"Prometheus: set litellm_teams_count to {total_teams}"
+5 -4
View File
@@ -17,6 +17,7 @@ from litellm.proxy.common_utils.resource_ownership import (
is_proxy_admin,
user_can_access_resource_owner,
)
from litellm.repositories.table_repositories import SkillsRepository
# Skills are looked up on every chat completion that has skills enabled
# (`SkillsInjectionHook` calls ``fetch_skill_from_db``). 60s LRU/TTL cache
@@ -107,7 +108,7 @@ class LiteLLMSkillsHandler:
f"LiteLLMSkillsHandler: Creating skill {skill_id} with title={data.display_title}"
)
new_skill = await prisma_client.db.litellm_skillstable.create(data=skill_data)
new_skill = await SkillsRepository(prisma_client).table.create(data=skill_data)
return _prisma_skill_to_litellm(new_skill)
@staticmethod
@@ -133,7 +134,7 @@ class LiteLLMSkillsHandler:
return []
find_many_kwargs["where"] = {"created_by": {"in": owner_scopes}}
skills = await prisma_client.db.litellm_skillstable.find_many(
skills = await SkillsRepository(prisma_client).table.find_many(
**find_many_kwargs
)
return [_prisma_skill_to_litellm(s) for s in skills]
@@ -150,7 +151,7 @@ class LiteLLMSkillsHandler:
return cached
prisma_client = await LiteLLMSkillsHandler._get_prisma_client()
skill = await prisma_client.db.litellm_skillstable.find_unique(
skill = await SkillsRepository(prisma_client).table.find_unique(
where={"skill_id": skill_id}
)
_SKILL_CACHE.set_cache(
@@ -189,7 +190,7 @@ class LiteLLMSkillsHandler:
):
raise ValueError(f"Skill not found: {skill_id}")
await prisma_client.db.litellm_skillstable.delete(where={"skill_id": skill_id})
await SkillsRepository(prisma_client).table.delete(where={"skill_id": skill_id})
_SKILL_CACHE.set_cache(skill_id, _NEGATIVE_SKILL_SENTINEL)
return {"id": skill_id, "type": "skill_deleted"}
+66
View File
@@ -0,0 +1,66 @@
"""
Domain models for LiteLLM backend.
"""
from litellm.models.access_group import LiteLLM_AccessGroupTable
from litellm.models.budget import (
LiteLLM_BudgetTable,
LiteLLM_BudgetTableFull,
LiteLLM_TeamMemberTable,
)
from litellm.models.config import LiteLLM_Config
from litellm.models.credentials import (
CreateCredentialItem,
CredentialBase,
CredentialItem,
)
from litellm.models.end_user import LiteLLM_EndUserTable
from litellm.models.managed_files import (
LiteLLM_ManagedFileTable,
LiteLLM_ManagedObjectTable,
LiteLLM_ManagedVectorStoresTable,
LiteLLM_ManagedVectorStoreTable,
)
from litellm.models.mcp_server import LiteLLM_MCPServerTable
from litellm.models.model import LiteLLM_ProxyModelTable
from litellm.models.object_permission import LiteLLM_ObjectPermissionTable
from litellm.models.organization import LiteLLM_OrganizationTable
from litellm.models.organization_membership import LiteLLM_OrganizationMembershipTable
from litellm.models.project import LiteLLM_ProjectTable
from litellm.models.skills import LiteLLM_SkillsTable
from litellm.models.spend_logs import LiteLLM_ErrorLogs, LiteLLM_SpendLogs
from litellm.models.tag import LiteLLM_TagTable
from litellm.models.team import LiteLLM_TeamTable
from litellm.models.team_membership import LiteLLM_TeamMembership
from litellm.models.user import LiteLLM_UserTable
from litellm.models.verification_token import LiteLLM_VerificationToken
__all__ = [
"LiteLLM_AccessGroupTable",
"LiteLLM_BudgetTable",
"LiteLLM_BudgetTableFull",
"LiteLLM_TeamMemberTable",
"LiteLLM_Config",
"CredentialBase",
"CredentialItem",
"CreateCredentialItem",
"LiteLLM_EndUserTable",
"LiteLLM_ManagedFileTable",
"LiteLLM_ManagedObjectTable",
"LiteLLM_ManagedVectorStoreTable",
"LiteLLM_ManagedVectorStoresTable",
"LiteLLM_MCPServerTable",
"LiteLLM_ProxyModelTable",
"LiteLLM_ObjectPermissionTable",
"LiteLLM_OrganizationTable",
"LiteLLM_OrganizationMembershipTable",
"LiteLLM_ProjectTable",
"LiteLLM_SkillsTable",
"LiteLLM_ErrorLogs",
"LiteLLM_SpendLogs",
"LiteLLM_TagTable",
"LiteLLM_TeamTable",
"LiteLLM_TeamMembership",
"LiteLLM_UserTable",
"LiteLLM_VerificationToken",
]
+26
View File
@@ -0,0 +1,26 @@
"""
Access group table model.
Canonical definition for ``litellm_accessgrouptable``. Re-exported from
``litellm.proxy._types`` for backwards compatibility.
"""
from datetime import datetime
from typing import List, Optional
from litellm.types.llms.base import LiteLLMPydanticObjectBase
class LiteLLM_AccessGroupTable(LiteLLMPydanticObjectBase):
access_group_id: str
access_group_name: str
description: Optional[str] = None
access_model_names: List[str] = []
access_mcp_server_ids: List[str] = []
access_agent_ids: List[str] = []
assigned_team_ids: List[str] = []
assigned_key_ids: List[str] = []
created_at: Optional[datetime] = None
created_by: Optional[str] = None
updated_at: Optional[datetime] = None
updated_by: Optional[str] = None
+38
View File
@@ -0,0 +1,38 @@
"""
Base model class for domain models.
"""
from datetime import datetime
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict
class DomainModel(BaseModel):
"""Base class for all domain models."""
model_config = ConfigDict(
from_attributes=True,
protected_namespaces=(),
extra="ignore",
)
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
@classmethod
def from_db_record(cls, record: Any) -> "DomainModel":
"""Create a domain model from a database record."""
if record is None:
raise ValueError("Cannot create domain model from None record")
if isinstance(record, dict):
return cls(**record)
if hasattr(record, "model_dump") and callable(record.model_dump):
return cls(**record.model_dump())
if hasattr(record, "dict") and callable(record.dict):
return cls(**record.dict())
return cls(**dict(record))
def to_db_dict(self, exclude_unset: bool = False) -> Dict[str, Any]:
"""Convert domain model to a dictionary for database operations."""
return self.model_dump(exclude_none=True, exclude_unset=exclude_unset)
+56
View File
@@ -0,0 +1,56 @@
"""
Budget table model.
Canonical definition for ``litellm_budgettable``. Re-exported from
``litellm.proxy._types`` for backwards compatibility.
"""
from datetime import datetime
from typing import List, Optional
from pydantic import ConfigDict
from litellm.types.llms.base import LiteLLMPydanticObjectBase
class LiteLLM_BudgetTable(LiteLLMPydanticObjectBase):
"""Represents user-controllable params for a LiteLLM_BudgetTable record.
Budget-write paths use `model_fields.keys()` on this class as an allowlist
for user input. Keep server-managed fields (e.g. `budget_reset_at`) on
`LiteLLM_BudgetTableFull` so they aren't user-settable.
"""
budget_id: Optional[str] = None
soft_budget: Optional[float] = None
max_budget: Optional[float] = None
max_parallel_requests: Optional[int] = None
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
model_max_budget: Optional[dict] = None
budget_duration: Optional[str] = None
allowed_models: Optional[List[str]] = (
None # per-member model scope; empty = inherit team models
)
model_config = ConfigDict(protected_namespaces=())
class LiteLLM_BudgetTableFull(LiteLLM_BudgetTable):
"""LiteLLM_BudgetTable + server-managed fields returned on API responses."""
budget_reset_at: Optional[datetime] = None
created_at: datetime
class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable):
"""
Used to track spend of a user_id within a team_id
"""
spend: Optional[float] = None
user_id: Optional[str] = None
team_id: Optional[str] = None
budget_id: Optional[str] = None
model_config = ConfigDict(protected_namespaces=())
+15
View File
@@ -0,0 +1,15 @@
"""
Config table model.
Canonical definition for ``litellm_config``. Re-exported from
``litellm.proxy._types`` for backwards compatibility.
"""
from typing import Dict
from litellm.types.llms.base import LiteLLMPydanticObjectBase
class LiteLLM_Config(LiteLLMPydanticObjectBase):
param_name: str
param_value: Dict
+31
View File
@@ -0,0 +1,31 @@
"""
Credential table models.
These are the canonical credential types for the proxy. They live in the model
layer; ``litellm.types.utils`` re-exports them for backwards compatibility.
"""
from typing import Optional
from pydantic import BaseModel, model_validator
class CredentialBase(BaseModel):
credential_name: str
credential_info: dict
class CredentialItem(CredentialBase):
credential_values: dict
class CreateCredentialItem(CredentialBase):
credential_values: Optional[dict] = None
model_id: Optional[str] = None
@model_validator(mode="before")
@classmethod
def check_credential_params(cls, values):
if not values.get("credential_values") and not values.get("model_id"):
raise ValueError("Either credential_values or model_id must be set")
return values
+35
View File
@@ -0,0 +1,35 @@
"""
End-user table model.
Canonical definition for ``litellm_endusertable``. Re-exported from
``litellm.proxy._types`` for backwards compatibility.
"""
from typing import Literal, Optional
from pydantic import ConfigDict, model_validator
from litellm.models.budget import LiteLLM_BudgetTable
from litellm.models.object_permission import LiteLLM_ObjectPermissionTable
from litellm.types.llms.base import LiteLLMPydanticObjectBase
class LiteLLM_EndUserTable(LiteLLMPydanticObjectBase):
user_id: str
blocked: bool
alias: Optional[str] = None
spend: float = 0.0
allowed_model_region: Optional[Literal["eu", "us"]] = None
default_model: Optional[str] = None
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
object_permission_id: Optional[str] = None
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
@model_validator(mode="before")
@classmethod
def set_model_info(cls, values):
if values.get("spend") is None:
values.update({"spend": 0.0})
return values
model_config = ConfigDict(protected_namespaces=())
+62
View File
@@ -0,0 +1,62 @@
"""
Managed file, object, and vector store table models.
Canonical definitions for the ``litellm_managed*`` tables. Re-exported from
``litellm.proxy._types`` for backwards compatibility.
"""
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Union
from litellm.types.llms.base import LiteLLMPydanticObjectBase
from litellm.types.llms.openai import OpenAIFileObject, ResponsesAPIResponse
from litellm.types.utils import LiteLLMBatch, LiteLLMFineTuningJob
class LiteLLM_ManagedFileTable(LiteLLMPydanticObjectBase):
unified_file_id: str
file_object: Optional[OpenAIFileObject] = None
model_mappings: Dict[str, str]
flat_model_file_ids: List[str]
created_by: Optional[str] = None
team_id: Optional[str] = None
updated_by: Optional[str] = None
storage_backend: Optional[str] = None
storage_url: Optional[str] = None
class LiteLLM_ManagedObjectTable(LiteLLMPydanticObjectBase):
unified_object_id: str
model_object_id: str
file_purpose: Literal["batch", "fine-tune", "response", "container"]
file_object: Union[LiteLLMBatch, LiteLLMFineTuningJob, ResponsesAPIResponse]
created_by: Optional[str] = None
team_id: Optional[str] = None
class LiteLLM_ManagedVectorStoreTable(LiteLLMPydanticObjectBase):
"""Table for managing vector stores with target_model_names support."""
unified_resource_id: str
resource_object: Optional[Any] = None
model_mappings: Dict[str, str]
flat_model_resource_ids: List[str]
created_by: Optional[str] = None
team_id: Optional[str] = None
updated_by: Optional[str] = None
storage_backend: Optional[str] = None
storage_url: Optional[str] = None
class LiteLLM_ManagedVectorStoresTable(LiteLLMPydanticObjectBase):
vector_store_id: str
custom_llm_provider: str
vector_store_name: Optional[str]
vector_store_description: Optional[str]
vector_store_metadata: Optional[Dict[str, Any]]
created_at: Optional[datetime]
updated_at: Optional[datetime]
litellm_credential_name: Optional[str]
litellm_params: Optional[Dict[str, Any]]
team_id: Optional[str]
user_id: Optional[str]
+103
View File
@@ -0,0 +1,103 @@
"""
MCP server table model.
Canonical definition for ``litellm_mcpservertable``. Re-exported from
``litellm.proxy._types`` for backwards compatibility.
"""
import enum
from datetime import datetime
from typing import Dict, List, Literal, Optional
from pydantic import Field
from litellm.types.llms.base import LiteLLMPydanticObjectBase
from litellm.types.mcp import MCPAuthType, MCPCredentials, MCPTransportType
from litellm.types.mcp_server.mcp_server_manager import MCPInfo
class MCPEnvVarScope(str, enum.Enum):
"""Scope for an MCP server environment variable.
- ``global``: value is provided by the admin and used for all users.
- ``user``: each user must provide their own value via the per-user
env-var endpoint. The admin-supplied ``value`` is treated as a
placeholder/hint and is not used at request time.
"""
global_ = "global"
user = "user"
class MCPEnvVar(LiteLLMPydanticObjectBase):
"""One environment variable for an MCP server.
Variables can be interpolated into ``static_headers`` using ``${NAME}``
syntax. ``scope=global`` values are stored on the server. ``scope=user``
values are stored per-user in ``LiteLLM_MCPUserEnvVars`` and supplied by
each user.
"""
name: str
value: str = ""
scope: MCPEnvVarScope = MCPEnvVarScope.global_
description: Optional[str] = None
class LiteLLM_MCPServerTable(LiteLLMPydanticObjectBase):
"""Represents a LiteLLM_MCPServerTable record"""
server_id: str
server_name: Optional[str] = None
alias: Optional[str] = None
description: Optional[str] = None
url: Optional[str] = None
spec_path: Optional[str] = None
transport: MCPTransportType
auth_type: Optional[MCPAuthType] = None
credentials: Optional[MCPCredentials] = None
instructions: Optional[str] = None
created_at: Optional[datetime] = None
created_by: Optional[str] = None
updated_at: Optional[datetime] = None
updated_by: Optional[str] = None
teams: List[Dict[str, Optional[str]]] = Field(default_factory=list)
mcp_access_groups: List[str] = Field(default_factory=list)
allowed_tools: List[str] = Field(default_factory=list)
tool_name_to_display_name: Optional[Dict[str, str]] = None
tool_name_to_description: Optional[Dict[str, str]] = None
extra_headers: List[str] = Field(default_factory=list)
mcp_info: Optional[MCPInfo] = None
static_headers: Optional[Dict[str, str]] = None
env_vars: Optional[List[MCPEnvVar]] = None
status: Optional[Literal["healthy", "unhealthy", "unknown"]] = Field(
default="unknown",
description="Health status: 'healthy', 'unhealthy', 'unknown'",
)
last_health_check: Optional[datetime] = None
health_check_error: Optional[str] = None
command: Optional[str] = None
args: List[str] = Field(default_factory=list)
env: Dict[str, str] = Field(default_factory=dict)
authorization_url: Optional[str] = None
token_url: Optional[str] = None
registration_url: Optional[str] = None
oauth2_flow: Optional[Literal["client_credentials", "authorization_code"]] = None
allow_all_keys: bool = False
available_on_public_internet: bool = True
delegate_auth_to_upstream: bool = False
oauth_passthrough: bool = False
is_byok: bool = False
byok_description: List[str] = Field(default_factory=list)
byok_api_key_help_url: Optional[str] = None
has_user_credential: Optional[bool] = None
source_url: Optional[str] = None
timeout: Optional[float] = None
approval_status: Optional[str] = Field(
default="active",
description="Approval status: 'pending_review', 'active', 'rejected'",
)
submitted_by: Optional[str] = None
submitted_at: Optional[datetime] = None
reviewed_at: Optional[datetime] = None
review_notes: Optional[str] = None
+59
View File
@@ -0,0 +1,59 @@
"""
Proxy model table model.
Canonical definition for ``litellm_proxymodeltable``. Re-exported from
``litellm.proxy._types`` for backwards compatibility.
"""
import json
from datetime import datetime
from typing import Optional
from pydantic import ConfigDict, model_validator
from litellm.types.llms.base import LiteLLMPydanticObjectBase
class LiteLLM_ProxyModelTable(LiteLLMPydanticObjectBase):
model_id: str
model_name: str
litellm_params: dict
model_info: Optional[dict] = None
blocked: bool = False
created_at: Optional[datetime] = None
created_by: Optional[str] = None
updated_at: Optional[datetime] = None
updated_by: Optional[str] = None
model_config = ConfigDict(protected_namespaces=())
@model_validator(mode="before")
@classmethod
def check_potential_json_str(cls, values):
if isinstance(values.get("litellm_params"), str):
try:
values["litellm_params"] = json.loads(values["litellm_params"])
except json.JSONDecodeError:
pass
if isinstance(values.get("model_info"), str):
try:
values["model_info"] = json.loads(values["model_info"])
except json.JSONDecodeError:
pass
return values
@property
def is_blocked(self) -> bool:
return self.blocked
@property
def team_id(self) -> Optional[str]:
if self.model_info:
return self.model_info.get("team_id")
return None
@property
def team_public_model_name(self) -> Optional[str]:
if self.model_info:
return self.model_info.get("team_public_model_name")
return None
+26
View File
@@ -0,0 +1,26 @@
"""
Object permission table model.
Canonical definition for ``litellm_objectpermissiontable``. Re-exported from
``litellm.proxy._types`` for backwards compatibility.
"""
from typing import Dict, List, Optional
from litellm.types.llms.base import LiteLLMPydanticObjectBase
class LiteLLM_ObjectPermissionTable(LiteLLMPydanticObjectBase):
"""Represents a LiteLLM_ObjectPermissionTable record"""
object_permission_id: str
mcp_servers: Optional[List[str]] = []
mcp_access_groups: Optional[List[str]] = []
mcp_tool_permissions: Optional[Dict[str, List[str]]] = None
vector_stores: Optional[List[str]] = []
agents: Optional[List[str]] = []
agent_access_groups: Optional[List[str]] = []
models: Optional[List[str]] = []
mcp_toolsets: Optional[List[str]] = None
blocked_tools: Optional[List[str]] = []
search_tools: Optional[List[str]] = []
+31
View File
@@ -0,0 +1,31 @@
"""
Organization table model.
Canonical definition for ``litellm_organizationtable``. Re-exported from
``litellm.proxy._types`` for backwards compatibility.
"""
from typing import List, Optional
from litellm.models.budget import LiteLLM_BudgetTable
from litellm.models.object_permission import LiteLLM_ObjectPermissionTable
from litellm.models.user import LiteLLM_UserTable
from litellm.types.llms.base import LiteLLMPydanticObjectBase
class LiteLLM_OrganizationTable(LiteLLMPydanticObjectBase):
"""Represents user-controllable params for a LiteLLM_OrganizationTable record"""
organization_id: Optional[str] = None
organization_alias: Optional[str] = None
budget_id: str
spend: float = 0.0
metadata: Optional[dict] = None
models: List[str] = []
model_spend: Optional[dict] = {}
created_by: str
updated_by: str
users: Optional[List[LiteLLM_UserTable]] = None
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
object_permission_id: Optional[str] = None
+40
View File
@@ -0,0 +1,40 @@
"""
Organization membership table model.
Canonical definition for ``litellm_organizationmembership``. Re-exported from
``litellm.proxy._types`` for backwards compatibility.
"""
from datetime import datetime
from typing import Any, Optional
from pydantic import ConfigDict, model_validator
from litellm.models.budget import LiteLLM_BudgetTable
from litellm.types.llms.base import LiteLLMPydanticObjectBase
class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase):
"""Tracks which organizations a user belongs to and their spend within it."""
user_id: str
organization_id: str
user_role: Optional[str] = None
spend: float = 0.0
budget_id: Optional[str] = None
created_at: datetime
updated_at: datetime
user: Optional[Any] = None
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
user_email: Optional[str] = None
model_config = ConfigDict(protected_namespaces=())
@model_validator(mode="after")
def populate_user_email(self) -> "LiteLLM_OrganizationMembershipTable":
if self.user_email is None and self.user is not None:
if isinstance(self.user, dict):
self.user_email = self.user.get("user_email")
else:
self.user_email = getattr(self.user, "user_email", None)
return self
+41
View File
@@ -0,0 +1,41 @@
"""
Project table model.
Canonical definition for ``litellm_projecttable``. Re-exported from
``litellm.proxy._types`` for backwards compatibility.
"""
from datetime import datetime
from typing import List, Optional
from litellm.models.budget import LiteLLM_BudgetTable
from litellm.models.object_permission import LiteLLM_ObjectPermissionTable
from litellm.types.llms.base import LiteLLMPydanticObjectBase
class LiteLLM_ProjectTable(LiteLLMPydanticObjectBase):
"""Database model representation for project"""
project_id: str
project_alias: Optional[str] = None
description: Optional[str] = None
team_id: Optional[str] = None
budget_id: Optional[str] = None
metadata: Optional[dict] = None
models: List[str] = []
spend: float = 0.0
model_spend: Optional[dict] = None
model_rpm_limit: Optional[dict] = None
model_tpm_limit: Optional[dict] = None
blocked: bool = False
object_permission_id: Optional[str] = None
created_by: Optional[str] = None
updated_by: Optional[str] = None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
@property
def is_blocked(self) -> bool:
return self.blocked
+30
View File
@@ -0,0 +1,30 @@
"""
Skills table model.
Canonical definition for ``litellm_skillstable``. Re-exported from
``litellm.proxy._types`` for backwards compatibility.
"""
from datetime import datetime
from typing import Any, Dict, Optional
from litellm.types.llms.base import LiteLLMPydanticObjectBase
class LiteLLM_SkillsTable(LiteLLMPydanticObjectBase):
"""Represents a LiteLLM_SkillsTable record"""
skill_id: str
display_title: Optional[str] = None
description: Optional[str] = None
instructions: Optional[str] = None
source: str = "custom"
latest_version: Optional[str] = None
file_content: Optional[bytes] = None
file_name: Optional[str] = None
file_type: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
created_at: Optional[datetime] = None
created_by: Optional[str] = None
updated_at: Optional[datetime] = None
updated_by: Optional[str] = None
+50
View File
@@ -0,0 +1,50 @@
"""
Spend and error log table models.
Canonical definitions for ``litellm_spendlogs`` and ``litellm_errorlogs``.
Re-exported from ``litellm.proxy._types`` for backwards compatibility.
"""
from datetime import datetime
from typing import Optional, Union
from pydantic import Json
from litellm._uuid import uuid
from litellm.types.llms.base import LiteLLMPydanticObjectBase
class LiteLLM_SpendLogs(LiteLLMPydanticObjectBase):
request_id: str
api_key: str
model: Optional[str] = ""
api_base: Optional[str] = ""
call_type: str
spend: Optional[float] = 0.0
total_tokens: Optional[int] = 0
prompt_tokens: Optional[int] = 0
completion_tokens: Optional[int] = 0
startTime: Union[str, datetime, None]
endTime: Union[str, datetime, None]
user: Optional[str] = ""
metadata: Optional[Json] = {}
cache_hit: Optional[str] = "False"
cache_key: Optional[str] = None
request_tags: Optional[Json] = None
requester_ip_address: Optional[str] = None
messages: Optional[Union[str, list, dict]]
response: Optional[Union[str, list, dict]]
class LiteLLM_ErrorLogs(LiteLLMPydanticObjectBase):
request_id: Optional[str] = str(uuid.uuid4())
api_base: Optional[str] = ""
model_group: Optional[str] = ""
litellm_model_name: Optional[str] = ""
model_id: Optional[str] = ""
request_kwargs: Optional[dict] = {}
exception_type: Optional[str] = ""
status_code: Optional[str] = ""
exception_string: Optional[str] = ""
startTime: Union[str, datetime, None]
endTime: Union[str, datetime, None]
+36
View File
@@ -0,0 +1,36 @@
"""
Tag table model.
Canonical definition for ``litellm_tagtable``. Re-exported from
``litellm.proxy._types`` for backwards compatibility.
"""
from datetime import datetime
from typing import List, Optional
from pydantic import model_validator
from litellm.models.budget import LiteLLM_BudgetTable
from litellm.types.llms.base import LiteLLMPydanticObjectBase
class LiteLLM_TagTable(LiteLLMPydanticObjectBase):
tag_name: str
description: Optional[str] = None
models: List[str] = []
model_info: Optional[dict] = None
spend: float = 0.0
budget_id: Optional[str] = None
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
created_at: Optional[datetime] = None
created_by: Optional[str] = None
updated_at: Optional[datetime] = None
@model_validator(mode="before")
@classmethod
def set_model_info(cls, values):
if values.get("spend") is None:
values.update({"spend": 0.0})
if values.get("models") is None:
values.update({"models": []})
return values
+154
View File
@@ -0,0 +1,154 @@
"""
Team table models.
Canonical definitions for ``litellm_teamtable`` (plus the shared Member and
budget-window value types and the team-model alias table). Re-exported from
``litellm.proxy._types`` for backwards compatibility.
"""
import json
from datetime import datetime
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, model_validator
from litellm.models.object_permission import LiteLLM_ObjectPermissionTable
from litellm.types.llms.base import LiteLLMPydanticObjectBase
class MemberBase(LiteLLMPydanticObjectBase):
user_id: Optional[str] = Field(
default=None,
description="The unique ID of the user to add. Either user_id or user_email must be provided",
)
user_email: Optional[str] = Field(
default=None,
description="The email address of the user to add. Either user_id or user_email must be provided",
)
@model_validator(mode="before")
@classmethod
def check_user_info(cls, values):
if not isinstance(values, dict):
raise ValueError("input needs to be a dictionary")
if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided")
return values
class Member(MemberBase):
role: Literal["admin", "user"] = Field(
description="The role of the user within the team. 'admin' users can manage team settings and members, 'user' is a regular team member"
)
class BudgetLimitEntry(LiteLLMPydanticObjectBase):
"""A single budget window with its own limit and independent reset schedule."""
budget_duration: str
max_budget: float
reset_at: Optional[datetime] = None
class LiteLLM_ModelTable(LiteLLMPydanticObjectBase):
id: Optional[int] = None
model_aliases: Optional[Union[str, dict]] = None
created_by: str
updated_by: str
team: Optional["LiteLLM_TeamTable"] = None
model_config = ConfigDict(protected_namespaces=())
class TeamBase(LiteLLMPydanticObjectBase):
team_alias: Optional[str] = None
team_id: Optional[str] = None
organization_id: Optional[str] = None
admins: list = []
members: list = []
members_with_roles: List[Member] = []
team_member_permissions: Optional[List[str]] = None
metadata: Optional[dict] = None
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
max_budget: Optional[float] = None
soft_budget: Optional[float] = None
budget_duration: Optional[str] = None
budget_limits: Optional[List[BudgetLimitEntry]] = None
models: list = []
blocked: bool = False
router_settings: Optional[dict] = None
access_group_ids: Optional[List[str]] = None
default_team_member_models: Optional[List[str]] = None
class LiteLLM_TeamTable(TeamBase):
team_id: str # type: ignore
spend: Optional[float] = None
max_parallel_requests: Optional[int] = None
budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None
model_id: Optional[int] = None
model_spend: Optional[dict] = {}
model_max_budget: Optional[dict] = {}
policies: Optional[List[str]] = None
allow_team_guardrail_config: Optional[bool] = False
litellm_model_table: Optional[LiteLLM_ModelTable] = None
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
object_permission_id: Optional[str] = None
updated_at: Optional[datetime] = None
created_at: Optional[datetime] = None
model_config = ConfigDict(protected_namespaces=())
@model_validator(mode="before")
@classmethod
def set_model_info(cls, values):
dict_fields = [
"metadata",
"aliases",
"config",
"permissions",
"model_max_budget",
"model_aliases",
"router_settings",
"budget_limits",
]
if isinstance(values, BaseModel):
values = values.model_dump()
if (
isinstance(values.get("members_with_roles"), dict)
and not values["members_with_roles"]
):
values["members_with_roles"] = []
for field in dict_fields:
value = values.get(field)
if value is not None and isinstance(value, str):
try:
values[field] = json.loads(value)
except json.JSONDecodeError:
raise ValueError(f"Field {field} should be a valid dictionary")
return values
class LiteLLM_TeamTableCachedObj(LiteLLM_TeamTable):
last_refreshed_at: Optional[float] = None
class LiteLLM_DeletedTeamTable(LiteLLM_TeamTable):
"""Audit record for deleted teams; mirrors the team plus deletion metadata."""
id: Optional[str] = None
deleted_at: Optional[datetime] = None
deleted_by: Optional[str] = None
deleted_by_api_key: Optional[str] = None
litellm_changed_by: Optional[str] = None
model_config = ConfigDict(protected_namespaces=())
LiteLLM_ModelTable.model_rebuild()
+32
View File
@@ -0,0 +1,32 @@
"""
Team membership table model.
Canonical definition for ``litellm_teammembership``. Re-exported from
``litellm.proxy._types`` for backwards compatibility.
"""
from typing import Optional, Union
from litellm.models.budget import LiteLLM_BudgetTable, LiteLLM_BudgetTableFull
from litellm.types.llms.base import LiteLLMPydanticObjectBase
class LiteLLM_TeamMembership(LiteLLMPydanticObjectBase):
user_id: str
team_id: str
budget_id: Optional[str] = None
spend: Optional[float] = 0.0
total_spend: Optional[float] = 0.0
litellm_budget_table: Optional[
Union[LiteLLM_BudgetTableFull, LiteLLM_BudgetTable]
] = None
def safe_get_team_member_rpm_limit(self) -> Optional[int]:
if self.litellm_budget_table is not None:
return self.litellm_budget_table.rpm_limit
return None
def safe_get_team_member_tpm_limit(self) -> Optional[int]:
if self.litellm_budget_table is not None:
return self.litellm_budget_table.tpm_limit
return None
+70
View File
@@ -0,0 +1,70 @@
"""
User table model.
Canonical definition for ``litellm_usertable``. Re-exported from
``litellm.proxy._types`` for backwards compatibility.
"""
from datetime import datetime
from typing import Dict, List, Optional
from pydantic import ConfigDict, Field, model_validator
from litellm.models.object_permission import LiteLLM_ObjectPermissionTable
from litellm.models.organization_membership import (
LiteLLM_OrganizationMembershipTable,
)
from litellm.types.llms.base import LiteLLMPydanticObjectBase
class LiteLLM_UserTable(LiteLLMPydanticObjectBase):
user_id: str
user_alias: Optional[str] = None
team_id: Optional[str] = None
sso_user_id: Optional[str] = None
organization_id: Optional[str] = None
object_permission_id: Optional[str] = None
password: Optional[str] = Field(default=None, exclude=True)
teams: List[str] = []
user_role: Optional[str] = None
max_budget: Optional[float] = None
spend: float = 0.0
user_email: Optional[str] = None
models: list = []
metadata: Optional[dict] = None
max_parallel_requests: Optional[int] = None
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None
allowed_cache_controls: List[str] = []
policies: List[str] = []
model_spend: Optional[Dict] = {}
model_max_budget: Optional[Dict] = {}
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
organization_memberships: Optional[List[LiteLLM_OrganizationMembershipTable]] = None
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
model_config = ConfigDict(protected_namespaces=())
@model_validator(mode="before")
@classmethod
def set_model_info(cls, values):
if values.get("spend") is None:
values.update({"spend": 0.0})
if values.get("models") is None:
values.update({"models": []})
if values.get("teams") is None:
values.update({"teams": []})
return values
def is_over_budget(self) -> bool:
if self.max_budget is None:
return False
return self.spend >= self.max_budget
def has_model_access(self, model_name: str) -> bool:
if not self.models:
return True
return model_name in self.models
+74
View File
@@ -0,0 +1,74 @@
"""
Verification token table model.
Canonical definition for ``litellm_verificationtoken``. Re-exported from
``litellm.proxy._types`` for backwards compatibility.
"""
from datetime import datetime
from typing import Dict, List, Optional, Union
from pydantic import ConfigDict
from litellm.models.object_permission import LiteLLM_ObjectPermissionTable
from litellm.types.llms.base import LiteLLMPydanticObjectBase
class LiteLLM_VerificationToken(LiteLLMPydanticObjectBase):
token: Optional[str] = None
key_name: Optional[str] = None
key_alias: Optional[str] = None
spend: float = 0.0
max_budget: Optional[float] = None
expires: Optional[Union[str, datetime]] = None
models: List = []
aliases: Dict = {}
config: Dict = {}
user_id: Optional[str] = None
team_id: Optional[str] = None
agent_id: Optional[str] = None
project_id: Optional[str] = None
max_parallel_requests: Optional[int] = None
metadata: Dict = {}
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None
allowed_cache_controls: Optional[list] = []
allowed_routes: Optional[list] = []
permissions: Dict = {}
model_spend: Dict = {}
model_max_budget: Dict = {}
soft_budget_cooldown: bool = False
blocked: Optional[bool] = None
litellm_budget_table: Optional[dict] = None
budget_id: Optional[str] = None
org_id: Optional[str] = None # org id for a given key
created_at: Optional[datetime] = None
created_by: Optional[str] = None
updated_at: Optional[datetime] = None
updated_by: Optional[str] = None
last_active: Optional[datetime] = None
object_permission_id: Optional[str] = None
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
access_group_ids: Optional[List[str]] = None
rotation_count: Optional[int] = 0
auto_rotate: Optional[bool] = False
rotation_interval: Optional[str] = None
last_rotation_at: Optional[datetime] = None
key_rotation_at: Optional[datetime] = None
router_settings: Optional[dict] = None
budget_limits: Optional[List[dict]] = None
model_config = ConfigDict(protected_namespaces=())
class LiteLLM_DeletedVerificationToken(LiteLLM_VerificationToken):
"""Audit record for deleted keys; mirrors the token plus deletion metadata."""
id: Optional[str] = None
deleted_at: Optional[datetime] = None
deleted_by: Optional[str] = None
deleted_by_api_key: Optional[str] = None
litellm_changed_by: Optional[str] = None
model_config = ConfigDict(protected_namespaces=())
@@ -14,8 +14,12 @@ from litellm.proxy._types import (
SpecialHeaders,
UserAPIKeyAuth,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.auth.ip_address_utils import IPAddressUtils
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.repositories.table_repositories import (
AgentsRepository,
MCPServerRepository,
)
def _parse_mcp_server_names_from_path(
@@ -1445,7 +1449,7 @@ class MCPRequestHandler:
return None
if object_permission_id is None:
agent_row = await prisma_client.db.litellm_agentstable.find_unique(
agent_row = await AgentsRepository(prisma_client).table.find_unique(
where={"agent_id": agent_id},
)
object_permission_id = (
@@ -1600,7 +1604,7 @@ class MCPRequestHandler:
server_ids: Set[str] = set()
if access_groups and prisma_client is not None:
try:
mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many(
mcp_servers = await MCPServerRepository(prisma_client).table.find_many(
where={"mcp_access_groups": {"hasSome": access_groups}}
)
for server in mcp_servers:
+68 -57
View File
@@ -8,6 +8,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.constants import MCP_PER_USER_TOKEN_EXPIRY_BUFFER_SECONDS
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.proxy._types import (
LiteLLM_MCPServerTable,
LiteLLM_ObjectPermissionTable,
@@ -25,8 +26,16 @@ from litellm.proxy.common_utils.encrypt_decrypt_utils import (
decrypt_value_helper,
encrypt_value_helper,
)
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.proxy.utils import PrismaClient
from litellm.repositories.object_permission_repository import ObjectPermissionRepository
from litellm.repositories.table_repositories import (
MCPServerRepository,
MCPUserCredentialsRepository,
)
from litellm.repositories.team_repository import TeamRepository
from litellm.repositories.verification_token_repository import (
VerificationTokenRepository,
)
from litellm.types.llms.custom_http import httpxSpecialProvider
from litellm.types.mcp import MCPCredentials
@@ -354,7 +363,7 @@ async def get_all_mcp_servers(
where: Dict[str, Any] = {}
if approval_status is not None:
where["approval_status"] = approval_status
mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many(
mcp_servers = await MCPServerRepository(prisma_client).table.find_many(
where=where if where else {}
)
@@ -380,7 +389,9 @@ async def get_mcp_server(
"""
Returns the matching mcp server from the db iff exists
"""
mcp_server = await prisma_client.db.litellm_mcpservertable.find_unique(
mcp_server: Optional[LiteLLM_MCPServerTable] = await MCPServerRepository(
prisma_client
).table.find_unique(
where={
"server_id": server_id,
}
@@ -398,12 +409,12 @@ async def get_mcp_servers(
"""
Returns the matching mcp servers from the db with the server_ids
"""
_mcp_servers: List[LiteLLM_MCPServerTable] = (
await prisma_client.db.litellm_mcpservertable.find_many(
where={
"server_id": {"in": server_ids},
}
)
_mcp_servers: List[LiteLLM_MCPServerTable] = await MCPServerRepository(
prisma_client
).table.find_many(
where={
"server_id": {"in": server_ids},
}
)
final_mcp_servers: List[LiteLLM_MCPServerTable] = []
for _mcp_server in _mcp_servers:
@@ -420,15 +431,15 @@ async def get_mcp_servers_by_verificationtoken(
"""
Returns the mcp servers from the db for the verification token
"""
verification_token_record: LiteLLM_TeamTable = (
await prisma_client.db.litellm_verificationtoken.find_unique(
where={
"token": token,
},
include={
"object_permission": True,
},
)
verification_token_record: LiteLLM_TeamTable = await VerificationTokenRepository(
prisma_client
).table.find_unique(
where={
"token": token,
},
include={
"object_permission": True,
},
)
mcp_servers: Optional[List[str]] = []
@@ -446,15 +457,15 @@ async def get_mcp_servers_by_team(
"""
Returns the mcp servers from the db for the team id
"""
team_record: LiteLLM_TeamTable = (
await prisma_client.db.litellm_teamtable.find_unique(
where={
"team_id": team_id,
},
include={
"object_permission": True,
},
)
team_record: LiteLLM_TeamTable = await TeamRepository(
prisma_client
).table.find_unique(
where={
"team_id": team_id,
},
include={
"object_permission": True,
},
)
mcp_servers: Optional[List[str]] = []
@@ -505,16 +516,16 @@ async def get_objectpermissions_for_mcp_server(
"""
Get all the object permissions records and the associated team and verficiationtoken records that have access to the mcp server
"""
object_permission_records = (
await prisma_client.db.litellm_objectpermissiontable.find_many(
where={
"mcp_servers": {"has": mcp_server_id},
},
include={
"teams": True,
"verification_tokens": True,
},
)
object_permission_records = await ObjectPermissionRepository(
prisma_client
).table.find_many(
where={
"mcp_servers": {"has": mcp_server_id},
},
include={
"teams": True,
"verification_tokens": True,
},
)
return object_permission_records
@@ -526,7 +537,7 @@ async def get_virtualkeys_for_mcp_server(
"""
Get all the virtual keys that have access to the mcp server
"""
virtual_keys = await prisma_client.db.litellm_verificationtoken.find_many(
virtual_keys = await VerificationTokenRepository(prisma_client).table.find_many(
where={
"mcp_servers": {"has": server_id},
},
@@ -564,7 +575,7 @@ async def delete_mcp_server(
Returns the deleted mcp server record if it exists, otherwise None
"""
deleted_server = await prisma_client.db.litellm_mcpservertable.delete(
deleted_server = await MCPServerRepository(prisma_client).table.delete(
where={
"server_id": server_id,
},
@@ -600,7 +611,7 @@ async def create_mcp_server(
data_dict["created_by"] = touched_by
data_dict["updated_by"] = touched_by
new_mcp_server = await prisma_client.db.litellm_mcpservertable.create(
new_mcp_server = await MCPServerRepository(prisma_client).table.create(
data=data_dict # type: ignore
)
@@ -635,7 +646,7 @@ async def update_mcp_server(
"credentials" in data_dict and data_dict["credentials"] is not None
)
if data.auth_type or has_credentials:
existing = await prisma_client.db.litellm_mcpservertable.find_unique(
existing = await MCPServerRepository(prisma_client).table.find_unique(
where={"server_id": data.server_id}
)
@@ -678,7 +689,7 @@ async def update_mcp_server(
# Add audit fields
data_dict["updated_by"] = touched_by
updated_mcp_server = await prisma_client.db.litellm_mcpservertable.update(
updated_mcp_server = await MCPServerRepository(prisma_client).table.update(
where={"server_id": data.server_id}, data=data_dict # type: ignore
)
@@ -691,7 +702,7 @@ async def rotate_mcp_server_credentials_master_key(
):
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many()
mcp_servers = await MCPServerRepository(prisma_client).table.find_many()
updated = 0
for mcp_server in mcp_servers:
@@ -719,7 +730,7 @@ async def rotate_mcp_server_credentials_master_key(
continue
update_data["updated_by"] = touched_by
await prisma_client.db.litellm_mcpservertable.update(
await MCPServerRepository(prisma_client).table.update(
where={"server_id": mcp_server.server_id},
data=update_data,
)
@@ -781,7 +792,7 @@ async def rotate_mcp_user_credentials_master_key(
under the new master key. Rows that are unreadable under both paths
are logged and skipped so one corrupt row does not abort the rotation.
"""
rows = await prisma_client.db.litellm_mcpusercredentials.find_many()
rows = await MCPUserCredentialsRepository(prisma_client).table.find_many()
rotated = 0
skipped = 0
for row in rows:
@@ -798,7 +809,7 @@ async def rotate_mcp_user_credentials_master_key(
re_encrypted = encrypt_value_helper(
plaintext, new_encryption_key=new_master_key
)
await prisma_client.db.litellm_mcpusercredentials.update(
await MCPUserCredentialsRepository(prisma_client).table.update(
where={
"user_id_server_id": {
"user_id": row.user_id,
@@ -873,7 +884,7 @@ async def store_user_credential(
"""Store a user credential for a BYOK MCP server."""
encoded = encrypt_value_helper(credential)
await prisma_client.db.litellm_mcpusercredentials.upsert(
await MCPUserCredentialsRepository(prisma_client).table.upsert(
where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}},
data={
"create": {
@@ -893,7 +904,7 @@ async def get_user_credential(
) -> Optional[str]:
"""Return credential for a user+server pair, or None."""
row = await prisma_client.db.litellm_mcpusercredentials.find_unique(
row = await MCPUserCredentialsRepository(prisma_client).table.find_unique(
where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
)
if row is None:
@@ -907,7 +918,7 @@ async def has_user_credential(
server_id: str,
) -> bool:
"""Return True if the user has a stored credential for this server."""
row = await prisma_client.db.litellm_mcpusercredentials.find_unique(
row = await MCPUserCredentialsRepository(prisma_client).table.find_unique(
where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
)
return row is not None
@@ -919,7 +930,7 @@ async def delete_user_credential(
server_id: str,
) -> None:
"""Delete the user's stored credential for a BYOK MCP server."""
await prisma_client.db.litellm_mcpusercredentials.delete(
await MCPUserCredentialsRepository(prisma_client).table.delete(
where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
)
@@ -966,7 +977,7 @@ async def store_user_oauth_credential(
# Skip the guard when the caller knows the row is already an OAuth2 credential
# (e.g. during token refresh), saving an extra DB round-trip.
if not skip_byok_guard:
existing = await prisma_client.db.litellm_mcpusercredentials.find_unique(
existing = await MCPUserCredentialsRepository(prisma_client).table.find_unique(
where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
)
if (
@@ -984,7 +995,7 @@ async def store_user_oauth_credential(
)
encoded = encrypt_value_helper(json.dumps(payload))
await prisma_client.db.litellm_mcpusercredentials.upsert(
await MCPUserCredentialsRepository(prisma_client).table.upsert(
where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}},
data={
"create": {
@@ -1025,7 +1036,7 @@ async def get_user_oauth_credential(
) -> Optional[Dict[str, Any]]:
"""Return the decoded OAuth2 payload dict for a user+server pair, or None."""
row = await prisma_client.db.litellm_mcpusercredentials.find_unique(
row = await MCPUserCredentialsRepository(prisma_client).table.find_unique(
where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
)
if row is None:
@@ -1039,7 +1050,7 @@ async def list_user_oauth_credentials(
) -> List[Dict[str, Any]]:
"""Return all OAuth2 credential payloads for a user, tagged with server_id."""
rows = await prisma_client.db.litellm_mcpusercredentials.find_many(
rows = await MCPUserCredentialsRepository(prisma_client).table.find_many(
where={"user_id": user_id}
)
results: List[Dict[str, Any]] = []
@@ -1212,7 +1223,7 @@ async def approve_mcp_server(
) -> LiteLLM_MCPServerTable:
"""Set approval_status=active and record reviewed_at."""
now = datetime.now(timezone.utc)
updated = await prisma_client.db.litellm_mcpservertable.update(
updated = await MCPServerRepository(prisma_client).table.update(
where={"server_id": server_id},
data={
"approval_status": MCPApprovalStatus.active,
@@ -1240,7 +1251,7 @@ async def reject_mcp_server(
}
if review_notes is not None:
data["review_notes"] = review_notes
updated = await prisma_client.db.litellm_mcpservertable.update(
updated = await MCPServerRepository(prisma_client).table.update(
where={"server_id": server_id},
data=data,
)
@@ -1257,7 +1268,7 @@ async def get_mcp_submissions(
along with a summary count breakdown by approval_status.
Mirrors get_guardrail_submissions() from guardrail_endpoints.py.
"""
rows = await prisma_client.db.litellm_mcpservertable.find_many(
rows = await MCPServerRepository(prisma_client).table.find_many(
where={"submitted_at": {"not": None}},
order={"submitted_at": "desc"},
take=500, # safety cap; paginate if needed in a future iteration
@@ -42,8 +42,8 @@ from litellm.constants import (
MCP_TOOL_LISTING_TIMEOUT,
)
from litellm.exceptions import BlockedPiiEntityError, GuardrailRaisedException
from litellm.litellm_core_utils.url_utils import SSRFError, async_safe_get
from litellm.experimental_mcp_client.client import MCPClient, MCPSigV4Auth
from litellm.litellm_core_utils.url_utils import SSRFError, async_safe_get
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
MCPRequestHandler,
@@ -85,6 +85,7 @@ from litellm.proxy._types import (
from litellm.proxy.auth.ip_address_utils import IPAddressUtils
from litellm.proxy.common_utils.encrypt_decrypt_utils import decrypt_value_helper
from litellm.proxy.utils import ProxyLogging
from litellm.repositories.table_repositories import MCPServerRepository
from litellm.types.llms.custom_http import httpxSpecialProvider
from litellm.types.mcp import MCPAuth, MCPStdioConfig
from litellm.types.mcp_server.mcp_server_manager import (
@@ -3817,7 +3818,7 @@ class MCPServerManager:
# Pending/rejected servers are excluded at the DB level so we never load them.
from litellm.proxy._experimental.mcp_server.db import LiteLLM_MCPServerTable
raw_rows = await prisma_client.db.litellm_mcpservertable.find_many(
raw_rows = await MCPServerRepository(prisma_client).table.find_many(
where={
"OR": [
{"approval_status": None},
@@ -4,6 +4,7 @@ from typing import List, Optional
from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.proxy.utils import PrismaClient
from litellm.repositories.table_repositories import MCPToolsetRepository
from litellm.types.mcp_server.mcp_toolset import (
MCPToolset,
NewMCPToolsetRequest,
@@ -30,7 +31,7 @@ async def create_mcp_toolset(
data_dict["tools"] = json.dumps(data_dict.get("tools", []))
data_dict["created_by"] = touched_by
data_dict["updated_by"] = touched_by
row = await prisma_client.db.litellm_mcptoolsettable.create(data=data_dict)
row = await MCPToolsetRepository(prisma_client).table.create(data=data_dict)
return _toolset_from_row(row)
@@ -38,7 +39,7 @@ async def get_mcp_toolset(
prisma_client: PrismaClient,
toolset_id: str,
) -> Optional[MCPToolset]:
row = await prisma_client.db.litellm_mcptoolsettable.find_unique(
row = await MCPToolsetRepository(prisma_client).table.find_unique(
where={"toolset_id": toolset_id}
)
if row is None:
@@ -54,7 +55,7 @@ async def list_mcp_toolsets(
where = {}
if toolset_ids is not None:
where = {"toolset_id": {"in": toolset_ids}}
rows = await prisma_client.db.litellm_mcptoolsettable.find_many(where=where)
rows = await MCPToolsetRepository(prisma_client).table.find_many(where=where)
return [_toolset_from_row(r) for r in rows]
except Exception as e:
verbose_proxy_logger.warning(
@@ -69,7 +70,7 @@ async def get_mcp_toolset_by_name(
prisma_client: PrismaClient,
toolset_name: str,
) -> Optional[MCPToolset]:
row = await prisma_client.db.litellm_mcptoolsettable.find_first(
row = await MCPToolsetRepository(prisma_client).table.find_first(
where={"toolset_name": toolset_name}
)
if row is None:
@@ -87,7 +88,7 @@ async def update_mcp_toolset(
data_dict["tools"] = json.dumps(data_dict["tools"])
data_dict["updated_by"] = touched_by
try:
row = await prisma_client.db.litellm_mcptoolsettable.update(
row = await MCPToolsetRepository(prisma_client).table.update(
where={"toolset_id": data.toolset_id},
data=data_dict,
)
@@ -105,7 +106,7 @@ async def delete_mcp_toolset(
toolset_id: str,
) -> Optional[MCPToolset]:
try:
row = await prisma_client.db.litellm_mcptoolsettable.delete(
row = await MCPToolsetRepository(prisma_client).table.delete(
where={"toolset_id": toolset_id}
)
except Exception as e:
+84 -671
View File
@@ -23,8 +23,6 @@ from litellm.litellm_core_utils.initialize_dynamic_callback_params import (
from litellm.types.integrations.slack_alerting import AlertType
from litellm.types.llms.openai import (
AllMessageValues,
OpenAIFileObject,
ResponsesAPIResponse,
)
from litellm.types.mcp import (
MCPAuthType,
@@ -41,8 +39,6 @@ from litellm.types.utils import (
EmbeddingResponse,
GenericBudgetConfigType,
ImageResponse,
LiteLLMBatch,
LiteLLMFineTuningJob,
LiteLLMPydanticObjectBase,
ModelResponse,
ProviderField,
@@ -1014,12 +1010,7 @@ class LiteLLM_ObjectPermissionBase(LiteLLMPydanticObjectBase):
search_tools: Optional[List[str]] = None
class BudgetLimitEntry(LiteLLMPydanticObjectBase):
"""A single budget window with its own limit and independent reset schedule."""
budget_duration: str # e.g. "24h", "7d", "30d"
max_budget: float # max spend in USD for this window
reset_at: Optional[datetime] = None # populated at creation/reset time
from litellm.models.team import BudgetLimitEntry as BudgetLimitEntry # noqa: E402
class GenerateRequestBase(LiteLLMPydanticObjectBase):
@@ -1217,40 +1208,10 @@ class KeyRequest(LiteLLMPydanticObjectBase):
return values
class LiteLLM_ModelTable(LiteLLMPydanticObjectBase):
id: Optional[int] = None
model_aliases: Optional[Union[str, dict]] = None # json dump the dict
created_by: str
updated_by: str
team: Optional["LiteLLM_TeamTable"] = None
model_config = ConfigDict(protected_namespaces=())
class LiteLLM_ProxyModelTable(LiteLLMPydanticObjectBase):
model_id: str
model_name: str
litellm_params: dict
model_info: dict
created_at: Optional[datetime] = None
created_by: str
updated_at: Optional[datetime] = None
updated_by: str
@model_validator(mode="before")
@classmethod
def check_potential_json_str(cls, values):
if isinstance(values.get("litellm_params"), str):
try:
values["litellm_params"] = json.loads(values["litellm_params"])
except json.JSONDecodeError:
pass
if isinstance(values.get("model_info"), str):
try:
values["model_info"] = json.loads(values["model_info"])
except json.JSONDecodeError:
pass
return values
from litellm.models.model import ( # noqa: E402
LiteLLM_ProxyModelTable as LiteLLM_ProxyModelTable,
)
from litellm.models.team import LiteLLM_ModelTable as LiteLLM_ModelTable # noqa: E402
# MCP Types
@@ -1265,32 +1226,12 @@ class MCPApprovalStatus(str, enum.Enum):
rejected = "rejected"
class MCPEnvVarScope(str, enum.Enum):
"""Scope for an MCP server environment variable.
- ``global``: value is provided by the admin and used for all users.
- ``user``: each user must provide their own value via the per-user
env-var endpoint. The admin-supplied ``value`` is treated as a
placeholder/hint and is not used at request time.
"""
global_ = "global"
user = "user"
class MCPEnvVar(LiteLLMPydanticObjectBase):
"""One environment variable for an MCP server.
Variables can be interpolated into ``static_headers`` using ``${NAME}``
syntax. ``scope=global`` values are stored on the server. ``scope=user``
values are stored per-user in ``LiteLLM_MCPUserEnvVars`` and supplied by
each user.
"""
name: str
value: str = ""
scope: MCPEnvVarScope = MCPEnvVarScope.global_
description: Optional[str] = None
from litellm.models.mcp_server import ( # noqa: E402
MCPEnvVar as MCPEnvVar,
)
from litellm.models.mcp_server import ( # noqa: E402
MCPEnvVarScope as MCPEnvVarScope,
)
# MCP Proxy Request Types
@@ -1443,66 +1384,9 @@ class UpdateMCPServerRequest(LiteLLMPydanticObjectBase):
return values
class LiteLLM_MCPServerTable(LiteLLMPydanticObjectBase):
"""Represents a LiteLLM_MCPServerTable record"""
server_id: str
server_name: Optional[str] = None
alias: Optional[str] = None
description: Optional[str] = None
url: Optional[str] = None
spec_path: Optional[str] = None
transport: MCPTransportType
auth_type: Optional[MCPAuthType] = None
credentials: Optional[MCPCredentials] = None
instructions: Optional[str] = None
created_at: Optional[datetime] = None
created_by: Optional[str] = None
updated_at: Optional[datetime] = None
updated_by: Optional[str] = None
teams: List[Dict[str, Optional[str]]] = Field(default_factory=list)
mcp_access_groups: List[str] = Field(default_factory=list)
allowed_tools: List[str] = Field(default_factory=list)
tool_name_to_display_name: Optional[Dict[str, str]] = None
tool_name_to_description: Optional[Dict[str, str]] = None
extra_headers: List[str] = Field(default_factory=list)
mcp_info: Optional[MCPInfo] = None
static_headers: Optional[Dict[str, str]] = None
env_vars: Optional[List[MCPEnvVar]] = None
# Health check status
status: Optional[Literal["healthy", "unhealthy", "unknown"]] = Field(
default="unknown",
description="Health status: 'healthy', 'unhealthy', 'unknown'",
)
last_health_check: Optional[datetime] = None
health_check_error: Optional[str] = None
# Stdio-specific fields
command: Optional[str] = None
args: List[str] = Field(default_factory=list)
env: Dict[str, str] = Field(default_factory=dict)
authorization_url: Optional[str] = None
token_url: Optional[str] = None
registration_url: Optional[str] = None
oauth2_flow: Optional[Literal["client_credentials", "authorization_code"]] = None
allow_all_keys: bool = False
available_on_public_internet: bool = True
delegate_auth_to_upstream: bool = False
oauth_passthrough: bool = False
is_byok: bool = False
byok_description: List[str] = Field(default_factory=list)
byok_api_key_help_url: Optional[str] = None
has_user_credential: Optional[bool] = None
source_url: Optional[str] = None
timeout: Optional[float] = None
# BYOM submission fields
approval_status: Optional[str] = Field(
default="active",
description="Approval status: 'pending_review', 'active', 'rejected'",
)
submitted_by: Optional[str] = None
submitted_at: Optional[datetime] = None
reviewed_at: Optional[datetime] = None
review_notes: Optional[str] = None
from litellm.models.mcp_server import ( # noqa: E402
LiteLLM_MCPServerTable as LiteLLM_MCPServerTable,
)
class MakeMCPServersPublicRequest(LiteLLMPydanticObjectBase):
@@ -1622,23 +1506,9 @@ class UpdateSkillRequest(LiteLLMPydanticObjectBase):
metadata: Optional[Dict[str, Any]] = None
class LiteLLM_SkillsTable(LiteLLMPydanticObjectBase):
"""Represents a LiteLLM_SkillsTable record"""
skill_id: str
display_title: Optional[str] = None
description: Optional[str] = None
instructions: Optional[str] = None
source: str = "custom"
latest_version: Optional[str] = None
file_content: Optional[bytes] = None # Binary content of skill files (zip)
file_name: Optional[str] = None # Original filename
file_type: Optional[str] = None # MIME type
metadata: Optional[Dict[str, Any]] = None
created_at: Optional[datetime] = None
created_by: Optional[str] = None
updated_at: Optional[datetime] = None
updated_by: Optional[str] = None
from litellm.models.skills import ( # noqa: E402
LiteLLM_SkillsTable as LiteLLM_SkillsTable,
)
class ListSkillsRequest(LiteLLMPydanticObjectBase):
@@ -1839,33 +1709,8 @@ class DeleteCustomerRequest(LiteLLMPydanticObjectBase):
user_ids: List[str]
class MemberBase(LiteLLMPydanticObjectBase):
user_id: Optional[str] = Field(
default=None,
description="The unique ID of the user to add. Either user_id or user_email must be provided",
)
user_email: Optional[str] = Field(
default=None,
description="The email address of the user to add. Either user_id or user_email must be provided",
)
@model_validator(mode="before")
@classmethod
def check_user_info(cls, values):
if not isinstance(values, dict):
raise ValueError("input needs to be a dictionary")
if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided")
return values
class Member(MemberBase):
role: Literal[
"admin",
"user",
] = Field(
description="The role of the user within the team. 'admin' users can manage team settings and members, 'user' is a regular team member"
)
from litellm.models.team import Member as Member # noqa: E402
from litellm.models.team import MemberBase as MemberBase # noqa: E402
class OrgMember(MemberBase):
@@ -1876,33 +1721,7 @@ class OrgMember(MemberBase):
]
class TeamBase(LiteLLMPydanticObjectBase):
team_alias: Optional[str] = None
team_id: Optional[str] = None
organization_id: Optional[str] = None
admins: list = []
members: list = []
members_with_roles: List[Member] = []
team_member_permissions: Optional[List[str]] = None
metadata: Optional[dict] = None
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
# Budget fields
max_budget: Optional[float] = None
soft_budget: Optional[float] = None
budget_duration: Optional[str] = None
budget_limits: Optional[List[BudgetLimitEntry]] = (
None # multiple concurrent budget windows
)
models: list = []
blocked: bool = False
router_settings: Optional[dict] = None
access_group_ids: Optional[List[str]] = None
default_team_member_models: Optional[List[str]] = (
None # default allowed_models seeded onto new team members
)
from litellm.models.team import TeamBase as TeamBase # noqa: E402
class NewTeamRequest(TeamBase):
@@ -2100,147 +1919,31 @@ class TeamCallbackMetadata(LiteLLMPydanticObjectBase):
return values
class LiteLLM_ObjectPermissionTable(LiteLLMPydanticObjectBase):
"""Represents a LiteLLM_ObjectPermissionTable record"""
object_permission_id: str
mcp_servers: Optional[List[str]] = []
mcp_access_groups: Optional[List[str]] = []
mcp_tool_permissions: Optional[Dict[str, List[str]]] = None
"""
Mapping - server_id -> list of tools
Enforces allowed tools for a specific key/team/organization
{
"1234567890": ["tool_name_1", "tool_name_2"]
}
"""
vector_stores: Optional[List[str]] = []
agents: Optional[List[str]] = []
agent_access_groups: Optional[List[str]] = []
mcp_toolsets: Optional[List[str]] = None
blocked_tools: Optional[List[str]] = []
search_tools: Optional[List[str]] = []
class LiteLLM_TeamTable(TeamBase):
team_id: str # type: ignore
spend: Optional[float] = None
max_parallel_requests: Optional[int] = None
budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None
model_id: Optional[int] = None
litellm_model_table: Optional[LiteLLM_ModelTable] = None
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
updated_at: Optional[datetime] = None
created_at: Optional[datetime] = None
#########################################################
# Object Permission - MCP, Vector Stores etc.
#########################################################
object_permission_id: Optional[str] = None
model_config = ConfigDict(protected_namespaces=())
@model_validator(mode="before")
@classmethod
def set_model_info(cls, values):
dict_fields = [
"metadata",
"aliases",
"config",
"permissions",
"model_max_budget",
"model_aliases",
"router_settings",
"budget_limits",
]
if isinstance(values, BaseModel):
values = values.model_dump()
if (
isinstance(values.get("members_with_roles"), dict)
and not values["members_with_roles"]
):
values["members_with_roles"] = []
for field in dict_fields:
value = values.get(field)
if value is not None and isinstance(value, str):
try:
values[field] = json.loads(value)
except json.JSONDecodeError:
raise ValueError(f"Field {field} should be a valid dictionary")
return values
class LiteLLM_TeamTableCachedObj(LiteLLM_TeamTable):
last_refreshed_at: Optional[float] = None
class LiteLLM_DeletedTeamTable(LiteLLM_TeamTable):
"""
Recording of deleted teams for audit purposes. Mirrors LiteLLM_TeamTable
plus metadata captured at deletion time.
"""
id: Optional[str] = None
deleted_at: Optional[datetime] = None
deleted_by: Optional[str] = None
deleted_by_api_key: Optional[str] = None
litellm_changed_by: Optional[str] = None
model_config = ConfigDict(protected_namespaces=())
from litellm.models.object_permission import ( # noqa: E402
LiteLLM_ObjectPermissionTable as LiteLLM_ObjectPermissionTable,
)
from litellm.models.team import ( # noqa: E402
LiteLLM_DeletedTeamTable as LiteLLM_DeletedTeamTable,
)
from litellm.models.team import LiteLLM_TeamTable as LiteLLM_TeamTable # noqa: E402
from litellm.models.team import ( # noqa: E402
LiteLLM_TeamTableCachedObj as LiteLLM_TeamTableCachedObj,
)
class TeamRequest(LiteLLMPydanticObjectBase):
teams: List[str]
class LiteLLM_BudgetTable(LiteLLMPydanticObjectBase):
"""Represents user-controllable params for a LiteLLM_BudgetTable record.
Budget-write paths use `model_fields.keys()` on this class as an allowlist
for user input. Keep server-managed fields (e.g. `budget_reset_at`) on
`LiteLLM_BudgetTableFull` so they aren't user-settable.
"""
budget_id: Optional[str] = None
soft_budget: Optional[float] = None
max_budget: Optional[float] = None
max_parallel_requests: Optional[int] = None
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
model_max_budget: Optional[dict] = None
budget_duration: Optional[str] = None
allowed_models: Optional[List[str]] = (
None # per-member model scope; empty = inherit team models
)
model_config = ConfigDict(protected_namespaces=())
class LiteLLM_BudgetTableFull(LiteLLM_BudgetTable):
"""LiteLLM_BudgetTable + server-managed fields returned on API responses."""
budget_reset_at: Optional[datetime] = None
created_at: datetime
class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable):
"""
Used to track spend of a user_id within a team_id
"""
spend: Optional[float] = None
user_id: Optional[str] = None
team_id: Optional[str] = None
budget_id: Optional[str] = None
model_config = ConfigDict(protected_namespaces=())
from litellm.models.budget import ( # noqa: E402
LiteLLM_BudgetTable as LiteLLM_BudgetTable,
)
from litellm.models.budget import ( # noqa: E402
LiteLLM_BudgetTableFull as LiteLLM_BudgetTableFull,
)
from litellm.models.budget import ( # noqa: E402
LiteLLM_TeamMemberTable as LiteLLM_TeamMemberTable,
)
class NewOrganizationRequest(LiteLLM_BudgetTable):
@@ -2637,66 +2340,12 @@ class ConfigYAML(LiteLLMPydanticObjectBase):
model_config = ConfigDict(protected_namespaces=())
class LiteLLM_VerificationToken(LiteLLMPydanticObjectBase):
token: Optional[str] = None
key_name: Optional[str] = None
key_alias: Optional[str] = None
spend: float = 0.0
max_budget: Optional[float] = None
expires: Optional[Union[str, datetime]] = None
models: List = []
aliases: Dict = {}
config: Dict = {}
user_id: Optional[str] = None
team_id: Optional[str] = None
agent_id: Optional[str] = None
project_id: Optional[str] = None
max_parallel_requests: Optional[int] = None
metadata: Dict = {}
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None
allowed_cache_controls: Optional[list] = []
allowed_routes: Optional[list] = []
permissions: Dict = {}
model_spend: Dict = {}
model_max_budget: Dict = {}
soft_budget_cooldown: bool = False
blocked: Optional[bool] = None
litellm_budget_table: Optional[dict] = None
org_id: Optional[str] = None # org id for a given key
created_at: Optional[datetime] = None
created_by: Optional[str] = None
updated_at: Optional[datetime] = None
updated_by: Optional[str] = None
last_active: Optional[datetime] = None
object_permission_id: Optional[str] = None
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
access_group_ids: Optional[List[str]] = None
rotation_count: Optional[int] = 0 # Number of times key has been rotated
auto_rotate: Optional[bool] = False # Whether this key should be auto-rotated
rotation_interval: Optional[str] = None # How often to rotate (e.g., "30d", "90d")
last_rotation_at: Optional[datetime] = None # When this key was last rotated
key_rotation_at: Optional[datetime] = None # When this key should next be rotated
router_settings: Optional[dict] = None
budget_limits: Optional[List[dict]] = None # multiple concurrent budget windows
model_config = ConfigDict(protected_namespaces=())
class LiteLLM_DeletedVerificationToken(LiteLLM_VerificationToken):
"""
Recording of deleted keys for audit purposes. Mirrors LiteLLM_VerificationToken
plus metadata captured at deletion time.
"""
id: Optional[str] = None
deleted_at: Optional[datetime] = None
deleted_by: Optional[str] = None
deleted_by_api_key: Optional[str] = None
litellm_changed_by: Optional[str] = None
model_config = ConfigDict(protected_namespaces=())
from litellm.models.verification_token import ( # noqa: E402
LiteLLM_DeletedVerificationToken as LiteLLM_DeletedVerificationToken,
)
from litellm.models.verification_token import ( # noqa: E402
LiteLLM_VerificationToken as LiteLLM_VerificationToken,
)
class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
@@ -2935,39 +2584,10 @@ class UserInfoV2Response(LiteLLMPydanticObjectBase):
teams: List[str] = [] # Just team IDs, not full team objects
class LiteLLM_Config(LiteLLMPydanticObjectBase):
param_name: str
param_value: Dict
class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase):
"""
This is the table that track what organizations a user belongs to and users spend within the organization
"""
user_id: str
organization_id: str
user_role: Optional[str] = None
spend: float = 0.0
budget_id: Optional[str] = None
created_at: datetime
updated_at: datetime
user: Optional[Any] = (
None # You might want to replace 'Any' with a more specific type if available
)
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
user_email: Optional[str] = None
model_config = ConfigDict(protected_namespaces=())
@model_validator(mode="after")
def populate_user_email(self) -> "LiteLLM_OrganizationMembershipTable":
if self.user_email is None and self.user is not None:
if isinstance(self.user, dict):
self.user_email = self.user.get("user_email")
else:
self.user_email = getattr(self.user, "user_email", None)
return self
from litellm.models.config import LiteLLM_Config as LiteLLM_Config # noqa: E402
from litellm.models.organization_membership import ( # noqa: E402
LiteLLM_OrganizationMembershipTable as LiteLLM_OrganizationMembershipTable,
)
class LiteLLM_OrganizationTableUpdate(LiteLLM_BudgetTable):
@@ -2997,61 +2617,10 @@ class LiteLLM_OrganizationTableUpdate(LiteLLM_BudgetTable):
return values
class LiteLLM_UserTable(LiteLLMPydanticObjectBase):
user_id: str
max_budget: Optional[float] = None
spend: float = 0.0
model_max_budget: Optional[Dict] = {}
model_spend: Optional[Dict] = {}
user_email: Optional[str] = None
user_alias: Optional[str] = None
models: list = []
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
user_role: Optional[str] = None
organization_memberships: Optional[List[LiteLLM_OrganizationMembershipTable]] = None
teams: List[str] = []
sso_user_id: Optional[str] = None
budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None
metadata: Optional[dict] = None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
@model_validator(mode="before")
@classmethod
def set_model_info(cls, values):
if values.get("spend") is None:
values.update({"spend": 0.0})
if values.get("models") is None:
values.update({"models": []})
if values.get("teams") is None:
values.update({"teams": []})
return values
model_config = ConfigDict(protected_namespaces=())
class LiteLLM_OrganizationTable(LiteLLMPydanticObjectBase):
"""Represents user-controllable params for a LiteLLM_OrganizationTable record"""
organization_id: Optional[str] = None
organization_alias: Optional[str] = None
budget_id: str
spend: float = 0.0
metadata: Optional[dict] = None
models: List[str]
created_by: str
updated_by: str
users: Optional[List[LiteLLM_UserTable]] = None
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
#########################################################
# Object Permission - MCP, Vector Stores etc.
#########################################################
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
object_permission_id: Optional[str] = None
from litellm.models.organization import ( # noqa: E402
LiteLLM_OrganizationTable as LiteLLM_OrganizationTable,
)
from litellm.models.user import LiteLLM_UserTable as LiteLLM_UserTable # noqa: E402
class LiteLLM_OrganizationTableWithMembers(LiteLLM_OrganizationTable):
@@ -3160,28 +2729,9 @@ class DeleteProjectRequest(LiteLLMPydanticObjectBase):
project_ids: List[str]
class LiteLLM_ProjectTable(LiteLLMPydanticObjectBase):
"""Database model representation for project"""
project_id: str
project_alias: Optional[str] = None
description: Optional[str] = None
team_id: Optional[str] = None
budget_id: Optional[str] = None
metadata: Optional[dict] = None
models: List[str] = []
spend: float = 0.0
model_spend: Optional[dict] = None
model_rpm_limit: Optional[dict] = None
model_tpm_limit: Optional[dict] = None
blocked: bool = False
object_permission_id: Optional[str] = None
created_by: str
updated_by: str
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
from litellm.models.project import ( # noqa: E402
LiteLLM_ProjectTable as LiteLLM_ProjectTable,
)
class NewProjectResponse(LiteLLM_ProjectTable):
@@ -3207,101 +2757,19 @@ class LiteLLM_UserTableWithKeyCount(LiteLLM_UserTable):
key_count: int = 0
class LiteLLM_EndUserTable(LiteLLMPydanticObjectBase):
user_id: str
blocked: bool
alias: Optional[str] = None
spend: float = 0.0
allowed_model_region: Optional[AllowedModelRegion] = None
default_model: Optional[str] = None
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
object_permission_id: Optional[str] = None
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
@model_validator(mode="before")
@classmethod
def set_model_info(cls, values):
if values.get("spend") is None:
values.update({"spend": 0.0})
return values
model_config = ConfigDict(protected_namespaces=())
class LiteLLM_TagTable(LiteLLMPydanticObjectBase):
tag_name: str
description: Optional[str] = None
models: List[str] = []
model_info: Optional[dict] = None
spend: float = 0.0
budget_id: Optional[str] = None
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
created_at: Optional[datetime] = None
created_by: Optional[str] = None
updated_at: Optional[datetime] = None
@model_validator(mode="before")
@classmethod
def set_model_info(cls, values):
if values.get("spend") is None:
values.update({"spend": 0.0})
if values.get("models") is None:
values.update({"models": []})
return values
model_config = ConfigDict(protected_namespaces=())
class LiteLLM_AccessGroupTable(LiteLLMPydanticObjectBase):
access_group_id: str
access_group_name: str
description: Optional[str] = None
access_model_names: List[str] = []
access_mcp_server_ids: List[str] = []
access_agent_ids: List[str] = []
assigned_team_ids: List[str] = []
assigned_key_ids: List[str] = []
created_at: Optional[datetime] = None
created_by: Optional[str] = None
updated_at: Optional[datetime] = None
updated_by: Optional[str] = None
class LiteLLM_SpendLogs(LiteLLMPydanticObjectBase):
request_id: str
api_key: str
model: Optional[str] = ""
api_base: Optional[str] = ""
call_type: str
spend: Optional[float] = 0.0
total_tokens: Optional[int] = 0
prompt_tokens: Optional[int] = 0
completion_tokens: Optional[int] = 0
startTime: Union[str, datetime, None]
endTime: Union[str, datetime, None]
user: Optional[str] = ""
metadata: Optional[Json] = {}
cache_hit: Optional[str] = "False"
cache_key: Optional[str] = None
request_tags: Optional[Json] = None
requester_ip_address: Optional[str] = None
messages: Optional[Union[str, list, dict]]
response: Optional[Union[str, list, dict]]
class LiteLLM_ErrorLogs(LiteLLMPydanticObjectBase):
request_id: Optional[str] = str(uuid.uuid4())
api_base: Optional[str] = ""
model_group: Optional[str] = ""
litellm_model_name: Optional[str] = ""
model_id: Optional[str] = ""
request_kwargs: Optional[dict] = {}
exception_type: Optional[str] = ""
status_code: Optional[str] = ""
exception_string: Optional[str] = ""
startTime: Union[str, datetime, None]
endTime: Union[str, datetime, None]
from litellm.models.access_group import ( # noqa: E402
LiteLLM_AccessGroupTable as LiteLLM_AccessGroupTable,
)
from litellm.models.end_user import ( # noqa: E402
LiteLLM_EndUserTable as LiteLLM_EndUserTable,
)
from litellm.models.spend_logs import ( # noqa: E402
LiteLLM_ErrorLogs as LiteLLM_ErrorLogs,
)
from litellm.models.spend_logs import ( # noqa: E402
LiteLLM_SpendLogs as LiteLLM_SpendLogs,
)
from litellm.models.tag import LiteLLM_TagTable as LiteLLM_TagTable # noqa: E402
AUDIT_ACTIONS = Literal[
"created", "updated", "deleted", "blocked", "unblocked", "rotated"
@@ -3982,29 +3450,9 @@ class CreatePassThroughEndpoint(LiteLLMPydanticObjectBase):
headers: dict
class LiteLLM_TeamMembership(LiteLLMPydanticObjectBase):
user_id: str
team_id: str
budget_id: Optional[str] = None
spend: Optional[float] = 0.0
total_spend: Optional[float] = 0.0
# Union so Pydantic picks Full when data has server-managed fields
# (/team/info) and Base when callers/tests construct with only
# user-settable fields.
litellm_budget_table: Optional[
Union[LiteLLM_BudgetTableFull, LiteLLM_BudgetTable]
] = None
def safe_get_team_member_rpm_limit(self) -> Optional[int]:
if self.litellm_budget_table is not None:
return self.litellm_budget_table.rpm_limit
return None
def safe_get_team_member_tpm_limit(self) -> Optional[int]:
if self.litellm_budget_table is not None:
return self.litellm_budget_table.tpm_limit
return None
from litellm.models.team_membership import ( # noqa: E402
LiteLLM_TeamMembership as LiteLLM_TeamMembership,
)
#### Organization / Team Member Requests ####
@@ -4922,39 +4370,18 @@ class ToolDiscoveryQueueItem(TypedDict, total=False):
user_agent: Optional[str] # HTTP User-Agent of the caller
class LiteLLM_ManagedFileTable(LiteLLMPydanticObjectBase):
unified_file_id: str
file_object: Optional[OpenAIFileObject] = None
model_mappings: Dict[str, str]
flat_model_file_ids: List[str]
created_by: Optional[str] = None
team_id: Optional[str] = None
updated_by: Optional[str] = None
storage_backend: Optional[str] = None
storage_url: Optional[str] = None
class LiteLLM_ManagedObjectTable(LiteLLMPydanticObjectBase):
unified_object_id: str
model_object_id: str
file_purpose: Literal["batch", "fine-tune", "response", "container"]
file_object: Union[LiteLLMBatch, LiteLLMFineTuningJob, ResponsesAPIResponse]
created_by: Optional[str] = None
team_id: Optional[str] = None
class LiteLLM_ManagedVectorStoreTable(LiteLLMPydanticObjectBase):
"""Table for managing vector stores with target_model_names support."""
unified_resource_id: str
resource_object: Optional[Any] = None # VectorStoreCreateResponse
model_mappings: Dict[str, str]
flat_model_resource_ids: List[str]
created_by: Optional[str] = None
team_id: Optional[str] = None
updated_by: Optional[str] = None
storage_backend: Optional[str] = None
storage_url: Optional[str] = None
from litellm.models.managed_files import ( # noqa: E402
LiteLLM_ManagedFileTable as LiteLLM_ManagedFileTable,
)
from litellm.models.managed_files import ( # noqa: E402
LiteLLM_ManagedObjectTable as LiteLLM_ManagedObjectTable,
)
from litellm.models.managed_files import ( # noqa: E402
LiteLLM_ManagedVectorStoresTable as LiteLLM_ManagedVectorStoresTable,
)
from litellm.models.managed_files import ( # noqa: E402
LiteLLM_ManagedVectorStoreTable as LiteLLM_ManagedVectorStoreTable,
)
class EnterpriseLicenseData(TypedDict, total=False):
@@ -4965,20 +4392,6 @@ class EnterpriseLicenseData(TypedDict, total=False):
max_teams: int
class LiteLLM_ManagedVectorStoresTable(LiteLLMPydanticObjectBase):
vector_store_id: str
custom_llm_provider: str
vector_store_name: Optional[str]
vector_store_description: Optional[str]
vector_store_metadata: Optional[Dict[str, Any]]
created_at: Optional[datetime]
updated_at: Optional[datetime]
litellm_credential_name: Optional[str]
litellm_params: Optional[Dict[str, Any]]
team_id: Optional[str]
user_id: Optional[str]
class ResponseLiteLLM_ManagedVectorStore(TypedDict, total=False):
vector_store: LiteLLM_ManagedVectorStoresTable
@@ -9,6 +9,7 @@ from litellm.proxy.management_helpers.object_permission_utils import (
handle_update_object_permission_common,
)
from litellm.proxy.utils import PrismaClient
from litellm.repositories.table_repositories import AgentsRepository
from litellm.types.agents import AgentConfig, AgentResponse, PatchAgentRequest
@@ -174,7 +175,7 @@ class AgentRegistry:
create_data[rate_field] = _val
# Create agent in DB
created_agent = await prisma_client.db.litellm_agentstable.create(
created_agent = await AgentsRepository(prisma_client).table.create(
data=create_data,
include={"object_permission": True},
)
@@ -200,7 +201,7 @@ class AgentRegistry:
Delete an agent from the database
"""
try:
deleted_agent = await prisma_client.db.litellm_agentstable.delete(
deleted_agent = await AgentsRepository(prisma_client).table.delete(
where={"agent_id": agent_id}
)
return dict(deleted_agent)
@@ -229,7 +230,7 @@ class AgentRegistry:
The patched agent
"""
try:
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
existing_agent = await AgentsRepository(prisma_client).table.find_unique(
where={"agent_id": agent_id}
)
if existing_agent is not None:
@@ -282,7 +283,7 @@ class AgentRegistry:
if object_permission_id is not None:
update_data["object_permission_id"] = object_permission_id
# Patch agent in DB
patched_agent = await prisma_client.db.litellm_agentstable.update(
patched_agent = await AgentsRepository(prisma_client).table.update(
where={"agent_id": agent_id},
data={
**update_data,
@@ -368,9 +369,9 @@ class AgentRegistry:
update_data[rate_field] = _val
if agent.get("object_permission") is not None:
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
where={"agent_id": agent_id}
)
existing_agent = await AgentsRepository(
prisma_client
).table.find_unique(where={"agent_id": agent_id})
existing_object_permission_id = (
existing_agent.object_permission_id
if existing_agent is not None
@@ -386,7 +387,7 @@ class AgentRegistry:
update_data["object_permission_id"] = object_permission_id
# Update agent in DB
updated_agent = await prisma_client.db.litellm_agentstable.update(
updated_agent = await AgentsRepository(prisma_client).table.update(
where={"agent_id": agent_id},
data=update_data,
include={"object_permission": True},
@@ -414,7 +415,7 @@ class AgentRegistry:
Get all agents from the database
"""
try:
agents_from_db = await prisma_client.db.litellm_agentstable.find_many(
agents_from_db = await AgentsRepository(prisma_client).table.find_many(
order={"created_at": "desc"},
include={"object_permission": True},
)
@@ -9,11 +9,12 @@ from typing import List, Optional, Set
from litellm._logging import verbose_logger
from litellm.proxy._types import (
UI_TEAM_ID,
LiteLLM_ObjectPermissionTable,
LiteLLM_TeamTable,
UI_TEAM_ID,
UserAPIKeyAuth,
)
from litellm.repositories.table_repositories import AgentsRepository
class AgentRequestHandler:
@@ -298,7 +299,7 @@ class AgentRequestHandler:
agent_ids: Set[str] = set()
if access_groups and prisma_client is not None:
try:
agents = await prisma_client.db.litellm_agentstable.find_many(
agents = await AgentsRepository(prisma_client).table.find_many(
where={"agent_access_groups": {"hasSome": access_groups}}
)
for agent in agents:
+12 -11
View File
@@ -17,6 +17,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query, Request
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import _get_masked_values
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
from litellm.proxy.a2a.agent_card import merge_agent_card
@@ -30,7 +31,6 @@ from litellm.types.agents import (
MakeAgentsPublicRequest,
PatchAgentRequest,
)
from litellm.litellm_core_utils.litellm_logging import _get_masked_values
from litellm.types.llms.custom_http import httpxSpecialProvider
from litellm.types.proxy.management_endpoints.common_daily_activity import (
DailySpendMetadata,
@@ -219,7 +219,7 @@ async def get_agents(
if prisma_client is not None:
agent_ids = [agent.agent_id for agent in returned_agents]
if agent_ids:
db_agents = await prisma_client.db.litellm_agentstable.find_many(
db_agents = await AgentsRepository(prisma_client).table.find_many(
where={"agent_id": {"in": agent_ids}},
)
spend_map = {a.agent_id: a.spend for a in db_agents}
@@ -301,6 +301,7 @@ async def get_agents(
from litellm.proxy.agent_endpoints.agent_registry import (
global_agent_registry as AGENT_REGISTRY,
)
from litellm.repositories.table_repositories import AgentsRepository
@router.post(
@@ -471,7 +472,7 @@ async def get_agent_by_id(
try:
agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id)
if agent is None:
agent_row = await prisma_client.db.litellm_agentstable.find_unique(
agent_row = await AgentsRepository(prisma_client).table.find_unique(
where={"agent_id": agent_id},
include={"object_permission": True},
)
@@ -489,7 +490,7 @@ async def get_agent_by_id(
agent = AgentResponse(**agent_dict) # type: ignore
else:
# Agent found in memory — refresh spend from DB
db_row = await prisma_client.db.litellm_agentstable.find_unique(
db_row = await AgentsRepository(prisma_client).table.find_unique(
where={"agent_id": agent_id}
)
if db_row is not None:
@@ -570,7 +571,7 @@ async def update_agent(
try:
# Check if agent exists
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
existing_agent = await AgentsRepository(prisma_client).table.find_unique(
where={"agent_id": agent_id}
)
if existing_agent is not None:
@@ -678,7 +679,7 @@ async def patch_agent(
try:
# Check if agent exists
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
existing_agent = await AgentsRepository(prisma_client).table.find_unique(
where={"agent_id": agent_id}
)
if existing_agent is not None:
@@ -769,7 +770,7 @@ async def delete_agent(
try:
# Check if agent exists
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
existing_agent = await AgentsRepository(prisma_client).table.find_unique(
where={"agent_id": agent_id}
)
if existing_agent is not None:
@@ -859,7 +860,7 @@ async def make_agent_public(
agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id)
if agent is None:
# check if agent exists in DB
agent = await prisma_client.db.litellm_agentstable.find_unique(
agent = await AgentsRepository(prisma_client).table.find_unique(
where={"agent_id": agent_id}
)
if agent is not None:
@@ -982,7 +983,7 @@ async def make_agents_public(
agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id)
if agent is None:
# check if agent exists in DB
agent = await prisma_client.db.litellm_agentstable.find_unique(
agent = await AgentsRepository(prisma_client).table.find_unique(
where={"agent_id": agent_id}
)
if agent is not None:
@@ -1082,7 +1083,7 @@ async def get_agent_daily_activity(
if user_api_key_dict.user_id is None:
permitted_agent_ids = []
else:
owned_records = await prisma_client.db.litellm_agentstable.find_many(
owned_records = await AgentsRepository(prisma_client).table.find_many(
where={"created_by": user_api_key_dict.user_id}
)
permitted_agent_ids = [a.agent_id for a in owned_records]
@@ -1118,7 +1119,7 @@ async def get_agent_daily_activity(
if agent_ids_list:
where_condition["agent_id"] = {"in": list(agent_ids_list)}
agent_records = await prisma_client.db.litellm_agentstable.find_many(
agent_records = await AgentsRepository(prisma_client).table.find_many(
where=where_condition
)
agent_metadata = {
@@ -26,6 +26,7 @@ from fastapi.responses import JSONResponse
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.repositories.table_repositories import ClaudeCodePluginRepository
from litellm.types.proxy.claude_code_endpoints import (
ListPluginsResponse,
PluginListItem,
@@ -71,7 +72,7 @@ async def get_marketplace():
try:
prisma_client = await _get_prisma_client()
plugins = await prisma_client.db.litellm_claudecodeplugintable.find_many(
plugins = await ClaudeCodePluginRepository(prisma_client).table.find_many(
where={"enabled": True}
)
@@ -268,12 +269,12 @@ async def register_plugin(
manifest["namespace"] = request.namespace
# Check if plugin exists
existing = await prisma_client.db.litellm_claudecodeplugintable.find_unique(
existing = await ClaudeCodePluginRepository(prisma_client).table.find_unique(
where={"name": request.name}
)
if existing:
plugin = await prisma_client.db.litellm_claudecodeplugintable.update(
plugin = await ClaudeCodePluginRepository(prisma_client).table.update(
where={"name": request.name},
data={
"version": request.version,
@@ -285,7 +286,7 @@ async def register_plugin(
)
action = "updated"
else:
plugin = await prisma_client.db.litellm_claudecodeplugintable.create(
plugin = await ClaudeCodePluginRepository(prisma_client).table.create(
data={
"name": request.name,
"version": request.version,
@@ -348,7 +349,7 @@ async def list_plugins(
prisma_client = await _get_prisma_client()
where = {"enabled": True} if enabled_only else {}
plugins = await prisma_client.db.litellm_claudecodeplugintable.find_many(
plugins = await ClaudeCodePluginRepository(prisma_client).table.find_many(
where=where
)
@@ -415,7 +416,7 @@ async def get_plugin(
try:
prisma_client = await _get_prisma_client()
plugin = await prisma_client.db.litellm_claudecodeplugintable.find_unique(
plugin = await ClaudeCodePluginRepository(prisma_client).table.find_unique(
where={"name": plugin_name}
)
@@ -471,7 +472,7 @@ async def enable_plugin(
try:
prisma_client = await _get_prisma_client()
plugin = await prisma_client.db.litellm_claudecodeplugintable.find_unique(
plugin = await ClaudeCodePluginRepository(prisma_client).table.find_unique(
where={"name": plugin_name}
)
if not plugin:
@@ -480,7 +481,7 @@ async def enable_plugin(
detail={"error": f"Plugin '{plugin_name}' not found"},
)
await prisma_client.db.litellm_claudecodeplugintable.update(
await ClaudeCodePluginRepository(prisma_client).table.update(
where={"name": plugin_name},
data={"enabled": True, "updated_at": datetime.now(timezone.utc)},
)
@@ -516,7 +517,7 @@ async def disable_plugin(
try:
prisma_client = await _get_prisma_client()
plugin = await prisma_client.db.litellm_claudecodeplugintable.find_unique(
plugin = await ClaudeCodePluginRepository(prisma_client).table.find_unique(
where={"name": plugin_name}
)
if not plugin:
@@ -525,7 +526,7 @@ async def disable_plugin(
detail={"error": f"Plugin '{plugin_name}' not found"},
)
await prisma_client.db.litellm_claudecodeplugintable.update(
await ClaudeCodePluginRepository(prisma_client).table.update(
where={"name": plugin_name},
data={"enabled": False, "updated_at": datetime.now(timezone.utc)},
)
@@ -561,7 +562,7 @@ async def delete_plugin(
try:
prisma_client = await _get_prisma_client()
plugin = await prisma_client.db.litellm_claudecodeplugintable.find_unique(
plugin = await ClaudeCodePluginRepository(prisma_client).table.find_unique(
where={"name": plugin_name}
)
if not plugin:
@@ -570,7 +571,7 @@ async def delete_plugin(
detail={"error": f"Plugin '{plugin_name}' not found"},
)
await prisma_client.db.litellm_claudecodeplugintable.delete(
await ClaudeCodePluginRepository(prisma_client).table.delete(
where={"name": plugin_name}
)
+44 -30
View File
@@ -61,19 +61,33 @@ from litellm.proxy._types import (
UserAPIKeyAuth,
)
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.common_utils.cache_pydantic_utils import CacheCodec
from litellm.proxy.common_utils.http_parsing_utils import (
_safe_get_request_headers,
_safe_get_request_query_params,
)
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler
from litellm.proxy.guardrails.tool_name_extraction import (
TOOL_CAPABLE_CALL_TYPES,
extract_request_tool_names,
)
from litellm.proxy.common_utils.cache_pydantic_utils import CacheCodec
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
from litellm.proxy.route_llm_request import route_request
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
from litellm.repositories.budget_repository import BudgetRepository
from litellm.repositories.object_permission_repository import ObjectPermissionRepository
from litellm.repositories.organization_repository import OrganizationRepository
from litellm.repositories.project_repository import ProjectRepository
from litellm.repositories.table_repositories import (
AccessGroupRepository,
EndUserRepository,
JWTKeyMappingRepository,
ManagedVectorStoresRepository,
TagRepository,
TeamMembershipRepository,
)
from litellm.repositories.team_repository import TeamRepository
from litellm.repositories.user_repository import UserRepository
from litellm.router import Router
from litellm.utils import get_utc_datetime
@@ -957,7 +971,7 @@ async def get_default_end_user_budget(
# Fetch from database
try:
budget_record = await prisma_client.db.litellm_budgettable.find_unique(
budget_record = await BudgetRepository(prisma_client).table.find_unique(
where={"budget_id": litellm.max_end_user_budget_id}
)
@@ -1016,7 +1030,7 @@ async def get_team_member_default_budget(
return LiteLLM_BudgetTable(**cached_budget)
try:
budget_record = await prisma_client.db.litellm_budgettable.find_unique(
budget_record = await BudgetRepository(prisma_client).table.find_unique(
where={"budget_id": budget_id}
)
@@ -1175,7 +1189,7 @@ async def get_end_user_object(
# Fetch from database
try:
response = await prisma_client.db.litellm_endusertable.find_unique(
response = await EndUserRepository(prisma_client).table.find_unique(
where={"user_id": end_user_id},
include={"litellm_budget_table": True, "object_permission": True},
)
@@ -1375,7 +1389,7 @@ async def get_tag_objects_batch(
# Batch fetch uncached tags from DB in one query
if uncached_tags:
try:
db_tags = await prisma_client.db.litellm_tagtable.find_many(
db_tags = await TagRepository(prisma_client).table.find_many(
where={"tag_name": {"in": uncached_tags}},
include={"litellm_budget_table": True},
)
@@ -1469,7 +1483,7 @@ async def get_team_membership(
# else, check db
try:
response = await prisma_client.db.litellm_teammembership.find_unique(
response = await TeamMembershipRepository(prisma_client).table.find_unique(
where={"user_id_team_id": {"user_id": user_id, "team_id": team_id}},
include={"litellm_budget_table": True},
)
@@ -1622,7 +1636,7 @@ async def _get_fuzzy_user_object(
response = None
if sso_user_id is not None:
response = await prisma_client.db.litellm_usertable.find_unique(
response = await UserRepository(prisma_client).table.find_unique(
where={"sso_user_id": sso_user_id},
include={"organization_memberships": True},
)
@@ -1630,14 +1644,14 @@ async def _get_fuzzy_user_object(
if response is None and user_email is not None:
# Use case-insensitive query to handle emails with different casing
# This matches the pattern used in _check_duplicate_user_email
response = await prisma_client.db.litellm_usertable.find_first(
response = await UserRepository(prisma_client).table.find_first(
where={"user_email": {"equals": user_email, "mode": "insensitive"}},
include={"organization_memberships": True},
)
if response is not None and sso_user_id is not None: # update sso_user_id
asyncio.create_task( # background task to update user with sso id
prisma_client.db.litellm_usertable.update(
UserRepository(prisma_client).table.update(
where={"user_id": response.user_id},
data={"sso_user_id": sso_user_id},
)
@@ -1687,7 +1701,7 @@ async def get_user_object(
)
if should_check_db:
response = await prisma_client.db.litellm_usertable.find_unique(
response = await UserRepository(prisma_client).table.find_unique(
where={"user_id": user_id}, include={"organization_memberships": True}
)
@@ -1711,7 +1725,7 @@ async def get_user_object(
if litellm.default_internal_user_params is not None:
new_user_params.update(litellm.default_internal_user_params)
response = await prisma_client.db.litellm_usertable.create(
response = await UserRepository(prisma_client).table.create(
data=new_user_params,
include={"organization_memberships": True},
)
@@ -1860,7 +1874,7 @@ async def _delete_cache_key_object(
async def _get_team_db_check(
team_id: str, prisma_client: PrismaClient, team_id_upsert: Optional[bool] = None
):
response = await prisma_client.db.litellm_teamtable.find_unique(
response = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": team_id}
)
@@ -1882,7 +1896,7 @@ async def _get_team_db_check(
async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient):
return await prisma_client.db.litellm_teamtable.find_unique(
return await TeamRepository(prisma_client).table.find_unique(
where={"team_id": team_id}
)
@@ -2111,7 +2125,7 @@ async def get_access_object(
# Not in cache - fetch from DB
try:
response = await prisma_client.db.litellm_accessgrouptable.find_unique(
response = await AccessGroupRepository(prisma_client).table.find_unique(
where={"access_group_id": access_group_id}
)
@@ -2193,7 +2207,7 @@ async def get_team_object_by_alias(
# Query database by team_alias
try:
teams = await prisma_client.db.litellm_teamtable.find_many(
teams = await TeamRepository(prisma_client).table.find_many(
where={"team_alias": team_alias}
)
@@ -2301,7 +2315,7 @@ async def get_org_object_by_alias(
# Query database by organization_alias
try:
orgs = await prisma_client.db.litellm_organizationtable.find_many(
orgs = await OrganizationRepository(prisma_client).table.find_many(
where={"organization_alias": org_alias}
)
@@ -2526,7 +2540,7 @@ async def get_jwt_key_mapping_object(
Returns the hashed token (str) if a matching active mapping is found, else None.
"""
mapping = await prisma_client.db.litellm_jwtkeymapping.find_first(
mapping = await JWTKeyMappingRepository(prisma_client).table.find_first(
where={
"jwt_claim_name": jwt_claim_name,
"jwt_claim_value": jwt_claim_value,
@@ -2659,7 +2673,7 @@ async def get_object_permission(
# else, check db
try:
response = await prisma_client.db.litellm_objectpermissiontable.find_unique(
response = await ObjectPermissionRepository(prisma_client).table.find_unique(
where={"object_permission_id": object_permission_id}
)
@@ -2715,7 +2729,7 @@ async def get_managed_vector_store_rows_by_uuids(
if not cache_misses:
return result
rows = await prisma_client.db.litellm_managedvectorstorestable.find_many(
rows = await ManagedVectorStoresRepository(prisma_client).table.find_many(
where={"vector_store_id": {"in": cache_misses}},
take=len(cache_misses),
)
@@ -2790,7 +2804,7 @@ async def get_org_object(
if include_budget_table:
query_kwargs["include"] = {"litellm_budget_table": True}
response = await prisma_client.db.litellm_organizationtable.find_unique(
response = await OrganizationRepository(prisma_client).table.find_unique(
**query_kwargs
)
@@ -4180,7 +4194,7 @@ async def get_project_object(
return deserialized_project
# Fetch from DB
project_row = await prisma_client.db.litellm_projecttable.find_unique(
project_row = await ProjectRepository(prisma_client).table.find_unique(
where={"project_id": project_id},
include={"litellm_budget_table": True},
)
@@ -4480,10 +4494,10 @@ async def vector_store_access_check(
#########################################################
# Check if the key can access the vector store
if valid_token is not None and valid_token.object_permission_id is not None:
key_object_permission = (
await prisma_client.db.litellm_objectpermissiontable.find_unique(
where={"object_permission_id": valid_token.object_permission_id},
)
key_object_permission = await ObjectPermissionRepository(
prisma_client
).table.find_unique(
where={"object_permission_id": valid_token.object_permission_id},
)
if key_object_permission is not None:
_can_object_call_vector_stores(
@@ -4494,10 +4508,10 @@ async def vector_store_access_check(
# Check if the team can access the vector store
if team_object is not None and team_object.object_permission_id is not None:
team_object_permission = (
await prisma_client.db.litellm_objectpermissiontable.find_unique(
where={"object_permission_id": team_object.object_permission_id},
)
team_object_permission = await ObjectPermissionRepository(
prisma_client
).table.find_unique(
where={"object_permission_id": team_object.object_permission_id},
)
if team_object_permission is not None:
_can_object_call_vector_stores(
+3 -2
View File
@@ -14,11 +14,11 @@ import os
import re
from typing import Any, List, Literal, Optional, Set, Tuple, Union, cast
import jwt
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from fastapi import HTTPException, status
import jwt
from jwt.api_jwk import PyJWK
from litellm._logging import verbose_proxy_logger
@@ -50,6 +50,7 @@ from litellm.proxy.auth.auth_checks import can_team_access_model
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
from litellm.proxy.utils import PrismaClient, ProxyLogging
from litellm.repositories.user_repository import UserRepository
from .auth_checks import (
_allowed_routes_check,
@@ -1790,7 +1791,7 @@ class JWTAuthManager:
# Update user role
new_role = jwt_handler.map_jwt_role_to_litellm_role(jwt_valid_token)
if new_role and user_object.user_role != new_role.value:
await prisma_client.db.litellm_usertable.update(
await UserRepository(prisma_client).table.update(
where={"user_id": user_object.user_id},
data={"user_role": new_role.value},
)
+3 -2
View File
@@ -34,6 +34,7 @@ from litellm.proxy.utils import (
hash_password,
verify_password,
)
from litellm.repositories.user_repository import UserRepository
from litellm.secret_managers.main import get_secret_bool
from litellm.types.proxy.ui_sso import ReturnedUITokenObject
@@ -45,7 +46,7 @@ async def _rehash_password_if_needed(user_id: str, password: str, stored: str) -
from litellm.proxy.proxy_server import prisma_client
if prisma_client is not None:
await prisma_client.db.litellm_usertable.update(
await UserRepository(prisma_client).table.update(
where={"user_id": user_id},
data={"password": hash_password(password)},
)
@@ -151,7 +152,7 @@ async def authenticate_user( # noqa: PLR0915
if prisma_client is not None:
_user_row = cast(
Optional[LiteLLM_UserTable],
await prisma_client.db.litellm_usertable.find_first(
await UserRepository(prisma_client).table.find_first(
where={"user_email": {"equals": username, "mode": "insensitive"}}
),
)
+2 -1
View File
@@ -6,6 +6,7 @@ import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.credential_accessor import CredentialAccessor
from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth
from litellm.repositories.object_permission_repository import ObjectPermissionRepository
from litellm.router import Router
from litellm.router_utils.fallback_event_handlers import get_fallback_model_group
from litellm.types.router import CredentialLiteLLMParams, LiteLLM_Params
@@ -86,7 +87,7 @@ async def get_mcp_server_ids(
# Make a direct SQL query to get just the mcp_servers
try:
result = await prisma_client.db.litellm_objectpermissiontable.find_unique(
result = await ObjectPermissionRepository(prisma_client).table.find_unique(
where={"object_permission_id": user_api_key_dict.object_permission_id},
)
if result and result.mcp_servers:
+6 -3
View File
@@ -21,9 +21,9 @@ from fastapi.security.api_key import APIKeyHeader
import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm._service_logger import ServiceLogging
from litellm.constants import LITELLM_PROXY_MASTER_KEY_ALIAS
from litellm.integrations.otel.model.config import is_otel_v2_enabled
from litellm.integrations.otel.runtime import phase_span, seed_request_identity
from litellm.constants import LITELLM_PROXY_MASTER_KEY_ALIAS
from litellm.litellm_core_utils.dd_tracing import tracer
from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value
from litellm.proxy._types import *
@@ -65,7 +65,6 @@ from litellm.proxy.auth.oauth2_check import Oauth2Handler
from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.common_utils.cache_coordinator import EventDrivenCacheCoordinator
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
from litellm.proxy.common_utils.http_parsing_utils import (
_read_request_body,
_safe_get_request_headers,
@@ -73,12 +72,14 @@ from litellm.proxy.common_utils.http_parsing_utils import (
populate_request_with_path_params,
)
from litellm.proxy.common_utils.realtime_utils import _realtime_request_body
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm.proxy.utils import (
PrismaClient,
ProxyLogging,
normalize_route_for_root_path,
)
from litellm.repositories.table_repositories import TeamMembershipRepository
from litellm.secret_managers.main import get_secret_bool
from litellm.types.services import ServiceTypes
@@ -1797,7 +1798,9 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
_team_id = valid_token.team_id
if _user_id is not None and _team_id is not None:
_db_member = await prisma_client.db.litellm_teammembership.find_first(
_db_member = await TeamMembershipRepository(
prisma_client
).table.find_first(
where={
"user_id": _user_id,
"team_id": _team_id,
@@ -8,7 +8,6 @@ from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
from litellm.constants import (
EXPIRED_UI_SESSION_KEY_CLEANUP_JOB_NAME,
LITELLM_EXPIRED_UI_SESSION_KEY_CLEANUP_BATCH_SIZE,
@@ -16,11 +15,15 @@ from litellm.constants import (
UI_SESSION_TOKEN_TEAM_ID,
)
from litellm.proxy._types import KeyRequest, LiteLLM_VerificationToken, UserAPIKeyAuth
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
from litellm.proxy.management_endpoints.key_management_endpoints import (
delete_verification_tokens,
)
from litellm.proxy.utils import PrismaClient
from litellm.repositories.verification_token_repository import (
VerificationTokenRepository,
)
class ExpiredUISessionKeyCleanupManager:
@@ -147,7 +150,7 @@ class ExpiredUISessionKeyCleanupManager:
Find expired LiteLLM dashboard session keys.
"""
now = datetime.now(timezone.utc)
return await self.prisma_client.db.litellm_verificationtoken.find_many(
return await VerificationTokenRepository(self.prisma_client).table.find_many(
where={
"team_id": UI_SESSION_TOKEN_TEAM_ID,
"expires": {"lt": now},
@@ -24,6 +24,12 @@ from litellm.proxy.management_endpoints.key_management_endpoints import (
regenerate_key_fn,
)
from litellm.proxy.utils import PrismaClient
from litellm.repositories.table_repositories import (
DeprecatedVerificationTokenRepository,
)
from litellm.repositories.verification_token_repository import (
VerificationTokenRepository,
)
class KeyRotationManager:
@@ -124,20 +130,20 @@ class KeyRotationManager:
"""
now = datetime.now(timezone.utc)
keys_with_rotation = (
await self.prisma_client.db.litellm_verificationtoken.find_many(
where={
"auto_rotate": True, # Only keys marked for auto rotation
"OR": [
{
"key_rotation_at": None
}, # Keys that need initial rotation time setup
{
"key_rotation_at": {"lte": now}
}, # Keys where rotation time has passed
],
}
)
keys_with_rotation = await VerificationTokenRepository(
self.prisma_client
).table.find_many(
where={
"auto_rotate": True, # Only keys marked for auto rotation
"OR": [
{
"key_rotation_at": None
}, # Keys that need initial rotation time setup
{
"key_rotation_at": {"lte": now}
}, # Keys where rotation time has passed
],
}
)
return keys_with_rotation
@@ -148,9 +154,9 @@ class KeyRotationManager:
"""
try:
now = datetime.now(timezone.utc)
result = await self.prisma_client.db.litellm_deprecatedverificationtoken.delete_many(
where={"revoke_at": {"lt": now}}
)
result = await DeprecatedVerificationTokenRepository(
self.prisma_client
).table.delete_many(where={"revoke_at": {"lt": now}})
if result > 0:
verbose_proxy_logger.debug(
"Cleaned up %s expired deprecated key(s)", result
@@ -206,7 +212,7 @@ class KeyRotationManager:
# Calculate next rotation time using helper function
now = datetime.now(timezone.utc)
next_rotation_time = _calculate_key_rotation_time(key.rotation_interval)
await self.prisma_client.db.litellm_verificationtoken.update(
await VerificationTokenRepository(self.prisma_client).table.update(
where={"token": response.token_id},
data={
"rotation_count": (key.rotation_count or 0) + 1,
+17 -7
View File
@@ -14,6 +14,16 @@ from litellm.proxy._types import (
LiteLLM_VerificationToken,
)
from litellm.proxy.utils import PrismaClient, ProxyLogging
from litellm.repositories.organization_repository import OrganizationRepository
from litellm.repositories.table_repositories import (
EndUserRepository,
TagRepository,
TeamMembershipRepository,
)
from litellm.repositories.team_repository import TeamRepository
from litellm.repositories.verification_token_repository import (
VerificationTokenRepository,
)
from litellm.types.services import ServiceTypes
@@ -159,7 +169,7 @@ class ResetBudgetJob:
"""
return await self._cascade_reset_spend_for_budget_link(
budgets_to_reset=budgets_to_reset,
table=self.prisma_client.db.litellm_teammembership,
table=TeamMembershipRepository(self.prisma_client).table,
counter_key_fn=lambda m: f"spend:team_member:{m.user_id}:{m.team_id}",
log_subject="team memberships",
cache_key_fn=lambda m: f"{m.team_id}_{m.user_id}",
@@ -176,7 +186,7 @@ class ResetBudgetJob:
"""
return await self._cascade_reset_spend_for_budget_link(
budgets_to_reset=budgets_to_reset,
table=self.prisma_client.db.litellm_verificationtoken,
table=VerificationTokenRepository(self.prisma_client).table,
counter_key_fn=lambda k: f"spend:key:{k.token}",
log_subject="keys",
extra_where={"budget_duration": None, "spend": {"gt": 0}},
@@ -191,7 +201,7 @@ class ResetBudgetJob:
"""
return await self._cascade_reset_spend_for_budget_link(
budgets_to_reset=budgets_to_reset,
table=self.prisma_client.db.litellm_organizationtable,
table=OrganizationRepository(self.prisma_client).table,
counter_key_fn=lambda o: f"spend:org:{o.organization_id}",
log_subject="orgs",
extra_where={"spend": {"gt": 0}},
@@ -217,7 +227,7 @@ class ResetBudgetJob:
"""
return await self._cascade_reset_spend_for_budget_link(
budgets_to_reset=budgets_to_reset,
table=self.prisma_client.db.litellm_tagtable,
table=TagRepository(self.prisma_client).table,
counter_key_fn=lambda t: f"spend:tag:{t.tag_name}",
log_subject="tags",
extra_where={"spend": {"gt": 0}},
@@ -406,7 +416,7 @@ class ResetBudgetJob:
rely on the default budget (litellm.max_end_user_budget_id) applied
in-memory during auth checks.
"""
rows = await self.prisma_client.db.litellm_endusertable.find_many(
rows = await EndUserRepository(self.prisma_client).table.find_many(
where={
"budget_id": None,
"spend": {"gt": 0},
@@ -824,7 +834,7 @@ class ResetBudgetJob:
):
changed = True
if changed:
await self.prisma_client.db.litellm_verificationtoken.update(
await VerificationTokenRepository(self.prisma_client).table.update(
where={"token": row["token"]},
data={"budget_limits": json.dumps(windows)}, # type: ignore[arg-type]
)
@@ -852,7 +862,7 @@ class ResetBudgetJob:
):
changed = True
if changed:
await self.prisma_client.db.litellm_teamtable.update(
await TeamRepository(self.prisma_client).table.update(
where={"team_id": row["team_id"]},
data={"budget_limits": json.dumps(windows)}, # type: ignore[arg-type]
)
@@ -12,6 +12,7 @@ from litellm.proxy.common_utils.resource_ownership import (
is_proxy_admin,
user_can_access_resource_owner,
)
from litellm.repositories.table_repositories import ManagedObjectRepository
from litellm.responses.utils import ResponsesAPIRequestUtils
CONTAINER_OBJECT_PURPOSE = "container"
@@ -213,7 +214,7 @@ async def record_container_owner(
)
return response
table = prisma_client.db.litellm_managedobjecttable
table = ManagedObjectRepository(prisma_client).table
existing = await table.find_unique(where={"model_object_id": model_object_id})
if existing is not None:
if getattr(existing, "file_purpose", None) != CONTAINER_OBJECT_PURPOSE:
@@ -273,7 +274,7 @@ async def _get_container_owner(
if prisma_client is None:
return None
row = await prisma_client.db.litellm_managedobjecttable.find_first(
row = await ManagedObjectRepository(prisma_client).table.find_first(
where={
"model_object_id": model_object_id,
"file_purpose": CONTAINER_OBJECT_PURPOSE,
@@ -319,7 +320,7 @@ async def _get_stored_container_id(
if prisma_client is None:
return None
row = await prisma_client.db.litellm_managedobjecttable.find_first(
row = await ManagedObjectRepository(prisma_client).table.find_first(
where={
"model_object_id": model_object_id,
"file_purpose": CONTAINER_OBJECT_PURPOSE,
@@ -411,7 +412,7 @@ async def _get_allowed_container_ids(
if prisma_client is None:
return set()
rows = await prisma_client.db.litellm_managedobjecttable.find_many(
rows = await ManagedObjectRepository(prisma_client).table.find_many(
where={
"file_purpose": CONTAINER_OBJECT_PURPOSE,
"created_by": {"in": owner_scopes},
@@ -14,6 +14,7 @@ from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper
from litellm.proxy.utils import handle_exception_on_proxy, jsonify_object
from litellm.repositories.credentials_repository import CredentialsRepository
from litellm.types.utils import CreateCredentialItem, CredentialItem
router = APIRouter()
@@ -96,7 +97,7 @@ async def create_credential(
)
credentials_dict = encrypted_credential.model_dump()
credentials_dict_jsonified = jsonify_object(credentials_dict)
await prisma_client.db.litellm_credentialstable.create(
await CredentialsRepository(prisma_client).create(
data={
**credentials_dict_jsonified,
"created_by": user_api_key_dict.user_id,
@@ -245,9 +246,7 @@ async def delete_credential(
status_code=500,
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
await prisma_client.db.litellm_credentialstable.delete(
where={"credential_name": credential_name}
)
await CredentialsRepository(prisma_client).delete_by_name(credential_name)
## DELETE FROM LITELLM ##
litellm.credential_list = [
@@ -326,15 +325,14 @@ async def update_credential(
status_code=500,
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
db_credential = await prisma_client.db.litellm_credentialstable.find_unique(
where={"credential_name": credential_name},
)
credentials_repository = CredentialsRepository(prisma_client)
db_credential = await credentials_repository.find_by_name(credential_name)
if db_credential is None:
raise HTTPException(status_code=404, detail="Credential not found in DB.")
merged_credential = update_db_credential(db_credential, credential)
credential_object_jsonified = jsonify_object(merged_credential.model_dump())
await prisma_client.db.litellm_credentialstable.update(
where={"credential_name": credential_name},
await credentials_repository.update_by_name(
credential_name,
data={
**credential_object_jsonified,
"updated_by": user_api_key_dict.user_id,
+18 -8
View File
@@ -20,6 +20,16 @@ from typing import TYPE_CHECKING, ClassVar, Optional
from litellm._logging import verbose_proxy_logger
from litellm.constants import SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.repositories.organization_repository import OrganizationRepository
from litellm.repositories.table_repositories import (
SpendLogsRepository,
TeamMembershipRepository,
)
from litellm.repositories.team_repository import TeamRepository
from litellm.repositories.user_repository import UserRepository
from litellm.repositories.verification_token_repository import (
VerificationTokenRepository,
)
if TYPE_CHECKING:
from litellm.caching.dual_cache import DualCache
@@ -83,25 +93,25 @@ class SpendCounterReseed:
try:
if counter_key.startswith("spend:key:"):
token = counter_key[len("spend:key:") :]
row = await prisma_client.db.litellm_verificationtoken.find_unique(
where={"token": token}
)
row = await VerificationTokenRepository(
prisma_client
).table.find_unique(where={"token": token})
elif counter_key.startswith("spend:team_member:"):
suffix = counter_key[len("spend:team_member:") :]
if ":" not in suffix:
return None
user_id, team_id = suffix.rsplit(":", 1)
row = await prisma_client.db.litellm_teammembership.find_unique(
row = await TeamMembershipRepository(prisma_client).table.find_unique(
where={"user_id_team_id": {"user_id": user_id, "team_id": team_id}}
)
elif counter_key.startswith("spend:team:"):
team_id = counter_key[len("spend:team:") :]
row = await prisma_client.db.litellm_teamtable.find_unique(
row = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": team_id}
)
elif counter_key.startswith("spend:user:"):
user_id = counter_key[len("spend:user:") :]
row = await prisma_client.db.litellm_usertable.find_unique(
row = await UserRepository(prisma_client).table.find_unique(
where={"user_id": user_id}
)
elif counter_key.startswith("spend:end_user:"):
@@ -110,7 +120,7 @@ class SpendCounterReseed:
return None
elif counter_key.startswith("spend:org:"):
org_id = counter_key[len("spend:org:") :]
row = await prisma_client.db.litellm_organizationtable.find_unique(
row = await OrganizationRepository(prisma_client).table.find_unique(
where={"organization_id": org_id}
)
else:
@@ -243,7 +253,7 @@ class SpendCounterReseed:
return None
try:
response = await prisma_client.db.litellm_spendlogs.group_by(
response = await SpendLogsRepository(prisma_client).table.group_by(
by=[group_field],
where=where, # type: ignore[arg-type]
sum={"spend": True},
+2 -1
View File
@@ -10,6 +10,7 @@ from typing import Any, Dict, List, Set
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
from litellm.proxy.utils import PrismaClient
from litellm.repositories.table_repositories import SpendLogToolIndexRepository
def _add_tool_calls_to_set(tool_calls: Any, out: Set[str]) -> None:
@@ -141,7 +142,7 @@ async def process_spend_logs_tool_usage(
}
)
if index_data:
await prisma_client.db.litellm_spendlogtoolindex.create_many(
await SpendLogToolIndexRepository(prisma_client).table.create_many(
data=index_data,
skip_duplicates=True,
)
+14 -12
View File
@@ -11,6 +11,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ToolDiscoveryQueueItem
from litellm.repositories.object_permission_repository import ObjectPermissionRepository
from litellm.repositories.table_repositories import ToolRepository
from litellm.types.tool_management import (
LiteLLM_ToolTableRow,
ToolPolicyOverrideRow,
@@ -84,7 +86,7 @@ async def batch_upsert_tools(
if not data:
return
now = datetime.now(timezone.utc)
table = prisma_client.db.litellm_tooltable
table = ToolRepository(prisma_client).table
for item in data:
tool_name = item.get("tool_name", "")
origin = item.get("origin") or "user_defined"
@@ -134,7 +136,7 @@ async def list_tools(
"""Return all tools, optionally filtered by input_policy."""
try:
where = {"input_policy": input_policy} if input_policy is not None else {}
rows = await prisma_client.db.litellm_tooltable.find_many(
rows = await ToolRepository(prisma_client).table.find_many(
where=where,
order={"created_at": "desc"},
)
@@ -150,7 +152,7 @@ async def get_tool(
) -> Optional[LiteLLM_ToolTableRow]:
"""Return a single tool row by tool_name."""
try:
row = await prisma_client.db.litellm_tooltable.find_unique(
row = await ToolRepository(prisma_client).table.find_unique(
where={"tool_name": tool_name},
)
if row is None:
@@ -192,7 +194,7 @@ async def update_tool_policy(
if output_policy is not None:
update_data["output_policy"] = output_policy
await prisma_client.db.litellm_tooltable.upsert(
await ToolRepository(prisma_client).table.upsert(
where={"tool_name": tool_name},
data={
"create": create_data,
@@ -217,7 +219,7 @@ async def get_tools_by_names(
if not tool_names:
return {}
try:
rows = await prisma_client.db.litellm_tooltable.find_many(
rows = await ToolRepository(prisma_client).table.find_many(
where={"tool_name": {"in": tool_names}},
)
return {
@@ -244,7 +246,7 @@ async def list_overrides_for_tool(
"""
out: List[ToolPolicyOverrideRow] = []
try:
perms = await prisma_client.db.litellm_objectpermissiontable.find_many(
perms = await ObjectPermissionRepository(prisma_client).table.find_many(
where={"blocked_tools": {"has": tool_name}},
include={
"verification_tokens": True,
@@ -307,7 +309,7 @@ class ToolPolicyRegistry:
async def sync_tool_policy_from_db(self, prisma_client: "PrismaClient") -> None:
"""Load all tool policies and object-permission blocked_tools from DB."""
try:
tools = await prisma_client.db.litellm_tooltable.find_many()
tools = await ToolRepository(prisma_client).table.find_many()
self._tool_input_policies = {
row.tool_name: getattr(row, "input_policy", "untrusted") or "untrusted"
for row in tools
@@ -317,7 +319,7 @@ class ToolPolicyRegistry:
for row in tools
}
perms = await prisma_client.db.litellm_objectpermissiontable.find_many()
perms = await ObjectPermissionRepository(prisma_client).table.find_many()
self._blocked_tools_by_op_id = {}
for row in perms:
op_id = getattr(row, "object_permission_id", None)
@@ -388,7 +390,7 @@ async def add_tool_to_object_permission_blocked(
if not object_permission_id or not tool_name:
return False
try:
row = await prisma_client.db.litellm_objectpermissiontable.find_unique(
row = await ObjectPermissionRepository(prisma_client).table.find_unique(
where={"object_permission_id": object_permission_id},
)
if row is None:
@@ -397,7 +399,7 @@ async def add_tool_to_object_permission_blocked(
if tool_name in current:
return True
current.append(tool_name)
await prisma_client.db.litellm_objectpermissiontable.update(
await ObjectPermissionRepository(prisma_client).table.update(
where={"object_permission_id": object_permission_id},
data={"blocked_tools": current},
)
@@ -418,7 +420,7 @@ async def remove_tool_from_object_permission_blocked(
if not object_permission_id or not tool_name:
return False
try:
row = await prisma_client.db.litellm_objectpermissiontable.find_unique(
row = await ObjectPermissionRepository(prisma_client).table.find_unique(
where={"object_permission_id": object_permission_id},
)
if row is None:
@@ -427,7 +429,7 @@ async def remove_tool_from_object_permission_blocked(
if tool_name not in current:
return False
current = [t for t in current if t != tool_name]
await prisma_client.db.litellm_objectpermissiontable.update(
await ObjectPermissionRepository(prisma_client).table.update(
where={"object_permission_id": object_permission_id},
data={"blocked_tools": current},
)
+13 -13
View File
@@ -13,21 +13,21 @@ from urllib.parse import urlparse
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel
from litellm.proxy.common_utils.path_utils import safe_join
from litellm._logging import verbose_proxy_logger
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
from litellm.proxy.common_utils.path_utils import safe_join
from litellm.proxy.guardrails.guardrail_hooks.custom_code.sandbox import (
build_sandbox_globals,
compile_sandboxed,
)
from litellm.proxy.guardrails.guardrail_registry import GuardrailRegistry
from litellm.proxy.guardrails.usage_endpoints import router as guardrails_usage_router
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
from litellm.repositories.table_repositories import GuardrailsRepository
from litellm.types.guardrails import (
PII_ENTITY_CATEGORIES_MAP,
ApplyGuardrailRequest,
@@ -373,7 +373,7 @@ async def create_guardrail(
# Configuration error — roll back the DB write so the guardrail isn't orphaned
if prisma_client is not None:
try:
await prisma_client.db.litellm_guardrailstable.delete(
await GuardrailsRepository(prisma_client).table.delete(
where={"guardrail_id": guardrail_id}
)
except Exception as rollback_err:
@@ -705,7 +705,7 @@ async def register_guardrail(
)
try:
existing = await prisma_client.db.litellm_guardrailstable.find_unique(
existing = await GuardrailsRepository(prisma_client).table.find_unique(
where={"guardrail_name": request.guardrail_name}
)
if existing is not None:
@@ -732,7 +732,7 @@ async def register_guardrail(
guardrail_info_str = safe_dumps(guardrail_info)
try:
created = await prisma_client.db.litellm_guardrailstable.create(
created = await GuardrailsRepository(prisma_client).table.create(
data={
"guardrail_name": request.guardrail_name,
"litellm_params": litellm_params_str,
@@ -874,7 +874,7 @@ async def list_guardrail_submissions(
where_clause["team_id"] = {"in": visible_team_ids}
# Single query: fetch team guardrails visible to the caller
all_team_rows = await prisma_client.db.litellm_guardrailstable.find_many(
all_team_rows = await GuardrailsRepository(prisma_client).table.find_many(
where=where_clause,
order={"created_at": "desc"},
)
@@ -945,7 +945,7 @@ async def get_guardrail_submission(
is_admin = user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
try:
row = await prisma_client.db.litellm_guardrailstable.find_unique(
row = await GuardrailsRepository(prisma_client).table.find_unique(
where={"guardrail_id": guardrail_id}
)
if row is None:
@@ -986,7 +986,7 @@ async def approve_guardrail_submission(
raise HTTPException(status_code=500, detail="Prisma client not initialized")
try:
row = await prisma_client.db.litellm_guardrailstable.find_unique(
row = await GuardrailsRepository(prisma_client).table.find_unique(
where={"guardrail_id": guardrail_id}
)
if row is None:
@@ -1000,7 +1000,7 @@ async def approve_guardrail_submission(
)
now = datetime.now(timezone.utc)
await prisma_client.db.litellm_guardrailstable.update(
await GuardrailsRepository(prisma_client).table.update(
where={"guardrail_id": guardrail_id},
data={"status": "active", "reviewed_at": now, "updated_at": now},
)
@@ -1072,7 +1072,7 @@ async def reject_guardrail_submission(
raise HTTPException(status_code=500, detail="Prisma client not initialized")
try:
row = await prisma_client.db.litellm_guardrailstable.find_unique(
row = await GuardrailsRepository(prisma_client).table.find_unique(
where={"guardrail_id": guardrail_id}
)
if row is None:
@@ -1086,7 +1086,7 @@ async def reject_guardrail_submission(
)
now = datetime.now(timezone.utc)
await prisma_client.db.litellm_guardrailstable.update(
await GuardrailsRepository(prisma_client).table.update(
where={"guardrail_id": guardrail_id},
data={"status": "rejected", "reviewed_at": now, "updated_at": now},
)
@@ -2288,10 +2288,10 @@ async def apply_guardrail(
"""
import traceback
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.litellm_core_utils.thread_pool_executor import (
executor as thread_pool_executor,
)
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.proxy.proxy_server import (
general_settings,
proxy_config,
+17 -14
View File
@@ -11,12 +11,15 @@ from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy.guardrails.guardrail_hooks.grayswan import GraySwanGuardrail
from litellm.proxy.guardrails.guardrail_hooks.grayswan import (
GraySwanGuardrail,
)
from litellm.proxy.guardrails.guardrail_hooks.grayswan import (
initialize_guardrail as initialize_grayswan,
)
from litellm.proxy.types_utils.utils import get_instance_fn
from litellm.proxy.utils import PrismaClient
from litellm.repositories.table_repositories import GuardrailsRepository
from litellm.secret_managers.main import get_secret
from litellm.types.guardrails import (
Guardrail,
@@ -26,6 +29,9 @@ from litellm.types.guardrails import (
SupportedGuardrailIntegrations,
)
from .guardrail_hooks.llm_as_a_judge import (
initialize_guardrail as initialize_llm_as_a_judge,
)
from .guardrail_initializers import (
initialize_bedrock,
initialize_hide_secrets,
@@ -34,9 +40,6 @@ from .guardrail_initializers import (
initialize_presidio,
initialize_tool_permission,
)
from .guardrail_hooks.llm_as_a_judge import (
initialize_guardrail as initialize_llm_as_a_judge,
)
guardrail_initializer_registry = {
SupportedGuardrailIntegrations.BEDROCK.value: initialize_bedrock,
@@ -257,7 +260,7 @@ class GuardrailRegistry:
guardrail_info: str = safe_dumps(guardrail.get("guardrail_info", {}))
# Create guardrail in DB
created_guardrail = await prisma_client.db.litellm_guardrailstable.create(
created_guardrail = await GuardrailsRepository(prisma_client).table.create(
data={
"guardrail_name": guardrail_name,
"litellm_params": litellm_params,
@@ -283,7 +286,7 @@ class GuardrailRegistry:
"""
try:
# Delete from DB
await prisma_client.db.litellm_guardrailstable.delete(
await GuardrailsRepository(prisma_client).table.delete(
where={"guardrail_id": guardrail_id}
)
@@ -311,7 +314,7 @@ class GuardrailRegistry:
guardrail_info: str = safe_dumps(guardrail.get("guardrail_info", {}))
# Update in DB
updated_guardrail = await prisma_client.db.litellm_guardrailstable.update(
updated_guardrail = await GuardrailsRepository(prisma_client).table.update(
where={"guardrail_id": guardrail_id},
data={
"guardrail_name": guardrail_name,
@@ -335,11 +338,11 @@ class GuardrailRegistry:
Only rows with status == "active" are returned (pending_review and rejected are excluded).
"""
try:
guardrails_from_db = (
await prisma_client.db.litellm_guardrailstable.find_many(
where={"status": "active"},
order={"created_at": "desc"},
)
guardrails_from_db = await GuardrailsRepository(
prisma_client
).table.find_many(
where={"status": "active"},
order={"created_at": "desc"},
)
guardrails: List[Guardrail] = []
@@ -357,7 +360,7 @@ class GuardrailRegistry:
Get a guardrail by its ID from the database
"""
try:
guardrail = await prisma_client.db.litellm_guardrailstable.find_unique(
guardrail = await GuardrailsRepository(prisma_client).table.find_unique(
where={"guardrail_id": guardrail_id}
)
@@ -375,7 +378,7 @@ class GuardrailRegistry:
Get a guardrail by its name from the database
"""
try:
guardrail = await prisma_client.db.litellm_guardrailstable.find_unique(
guardrail = await GuardrailsRepository(prisma_client).table.find_unique(
where={"guardrail_name": guardrail_name}
)
+29 -15
View File
@@ -12,6 +12,14 @@ from pydantic import BaseModel
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.repositories.table_repositories import (
DailyGuardrailMetricsRepository,
DailyPolicyMetricsRepository,
GuardrailsRepository,
PolicyRepository,
SpendLogGuardrailIndexRepository,
SpendLogsRepository,
)
router = APIRouter()
@@ -272,10 +280,10 @@ async def guardrails_usage_overview(
try:
# Guardrails from DB
guardrails = await prisma_client.db.litellm_guardrailstable.find_many()
guardrails = await GuardrailsRepository(prisma_client).table.find_many()
# Daily metrics in range
metrics = await prisma_client.db.litellm_dailyguardrailmetrics.find_many(
metrics = await DailyGuardrailMetricsRepository(prisma_client).table.find_many(
where={"date": {"gte": start, "lte": end}}
)
@@ -283,9 +291,9 @@ async def guardrails_usage_overview(
start_prev = (
datetime.strptime(start, "%Y-%m-%d") - timedelta(days=7)
).strftime("%Y-%m-%d")
metrics_prev = await prisma_client.db.litellm_dailyguardrailmetrics.find_many(
where={"date": {"gte": start_prev, "lt": start}}
)
metrics_prev = await DailyGuardrailMetricsRepository(
prisma_client
).table.find_many(where={"date": {"gte": start_prev, "lt": start}})
agg = _aggregate_daily_metrics(metrics, "guardrail_id")
prev_agg = _prev_fail_rates(metrics_prev, "guardrail_id")
@@ -335,7 +343,7 @@ async def guardrails_usage_detail(
end = end_date or now.strftime("%Y-%m-%d")
start = start_date or (now - timedelta(days=7)).strftime("%Y-%m-%d")
guardrail = await prisma_client.db.litellm_guardrailstable.find_unique(
guardrail = await GuardrailsRepository(prisma_client).table.find_unique(
where={"guardrail_id": guardrail_id}
)
if not guardrail:
@@ -349,13 +357,13 @@ async def guardrails_usage_detail(
)
metric_ids = [i for i in (logical_id, guardrail_id) if i]
metrics = await prisma_client.db.litellm_dailyguardrailmetrics.find_many(
metrics = await DailyGuardrailMetricsRepository(prisma_client).table.find_many(
where={
"guardrail_id": {"in": metric_ids},
"date": {"gte": start, "lte": end},
}
)
metrics_prev = await prisma_client.db.litellm_dailyguardrailmetrics.find_many(
metrics_prev = await DailyGuardrailMetricsRepository(prisma_client).table.find_many(
where={
"guardrail_id": {"in": metric_ids},
"date": {"lt": start},
@@ -574,7 +582,7 @@ async def guardrails_usage_logs(
# Query by both so we match regardless of which was written.
effective_guardrail_ids: List[str] = [guardrail_id] if guardrail_id else []
if guardrail_id:
guardrail = await prisma_client.db.litellm_guardrailstable.find_unique(
guardrail = await GuardrailsRepository(prisma_client).table.find_unique(
where={"guardrail_id": guardrail_id}
)
if guardrail:
@@ -585,19 +593,23 @@ async def guardrails_usage_logs(
where = _build_usage_logs_where(
effective_guardrail_ids or None, policy_id, start_date, end_date
)
index_rows = await prisma_client.db.litellm_spendlogguardrailindex.find_many(
index_rows = await SpendLogGuardrailIndexRepository(
prisma_client
).table.find_many(
where=where,
order={"start_time": "desc"},
skip=(page - 1) * page_size,
take=page_size + 1,
)
total = await prisma_client.db.litellm_spendlogguardrailindex.count(where=where)
total = await SpendLogGuardrailIndexRepository(prisma_client).table.count(
where=where
)
request_ids = [r.request_id for r in index_rows[:page_size]]
if not request_ids:
return UsageLogsResponse(
logs=[], total=total, page=page, page_size=page_size
)
spend_logs = await prisma_client.db.litellm_spendlogs.find_many(
spend_logs = await SpendLogsRepository(prisma_client).table.find_many(
where={"request_id": {"in": request_ids}}
)
log_by_id = {s.request_id: s for s in spend_logs}
@@ -645,11 +657,13 @@ async def policies_usage_overview(
start = start_date or (now - timedelta(days=7)).strftime("%Y-%m-%d")
try:
policies = await prisma_client.db.litellm_policytable.find_many()
metrics = await prisma_client.db.litellm_dailypolicymetrics.find_many(
policies = await PolicyRepository(prisma_client).table.find_many()
metrics = await DailyPolicyMetricsRepository(prisma_client).table.find_many(
where={"date": {"gte": start, "lte": end}}
)
metrics_prev = await prisma_client.db.litellm_dailypolicymetrics.find_many(
metrics_prev = await DailyPolicyMetricsRepository(
prisma_client
).table.find_many(
where={
"date": {
"gte": (
+6 -2
View File
@@ -10,6 +10,10 @@ from typing import Any, Dict, List, Optional
from litellm._logging import verbose_proxy_logger
from litellm.proxy.utils import PrismaClient
from litellm.repositories.table_repositories import (
DailyGuardrailMetricsRepository,
SpendLogGuardrailIndexRepository,
)
def _guardrail_status_to_action(status: Optional[str]) -> str:
@@ -132,7 +136,7 @@ async def process_spend_logs_guardrail_usage(
}
)
try:
await prisma_client.db.litellm_spendlogguardrailindex.create_many(
await SpendLogGuardrailIndexRepository(prisma_client).table.create_many(
data=index_data,
skip_duplicates=True,
)
@@ -146,7 +150,7 @@ async def process_spend_logs_guardrail_usage(
n = int(agg["requests_evaluated"])
if n == 0:
continue
await prisma_client.db.litellm_dailyguardrailmetrics.upsert(
await DailyGuardrailMetricsRepository(prisma_client).table.upsert(
where={
"guardrail_id_date": {
"guardrail_id": guardrail_id,
@@ -3,7 +3,6 @@ Hooks that are triggered when a litellm user event occurs
"""
import asyncio
from litellm._uuid import uuid
from datetime import datetime, timezone
from typing import Optional
@@ -11,6 +10,7 @@ from pydantic import BaseModel
import litellm
from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.proxy._types import (
AUDIT_ACTIONS,
CommonProxyErrors,
@@ -24,6 +24,7 @@ from litellm.proxy._types import (
WebhookEvent,
)
from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update
from litellm.repositories.user_repository import UserRepository
class UserManagementEventHooks:
@@ -57,7 +58,7 @@ class UserManagementEventHooks:
try:
if prisma_client is None:
raise Exception(CommonProxyErrors.db_not_connected_error.value)
user_row: BaseModel = await prisma_client.db.litellm_usertable.find_first(
user_row: BaseModel = await UserRepository(prisma_client).table.find_first(
where={"user_id": response.user_id}
)
@@ -19,6 +19,7 @@ from litellm.proxy.auth.auth_checks import (
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler
from litellm.proxy.utils import get_prisma_client_or_throw
from litellm.repositories.table_repositories import AccessGroupRepository
from litellm.types.access_group import (
AccessGroupCreateRequest,
AccessGroupResponse,
@@ -386,7 +387,7 @@ async def list_access_groups(
CommonProxyErrors.db_not_connected_error.value
)
records = await prisma_client.db.litellm_accessgrouptable.find_many(
records = await AccessGroupRepository(prisma_client).table.find_many(
order={"created_at": "desc"}
)
return [_record_to_response(r) for r in records]
@@ -405,7 +406,7 @@ async def get_access_group(
CommonProxyErrors.db_not_connected_error.value
)
record = await prisma_client.db.litellm_accessgrouptable.find_unique(
record = await AccessGroupRepository(prisma_client).table.find_unique(
where={"access_group_id": access_group_id}
)
if record is None:
@@ -16,11 +16,12 @@ import math
from fastapi import APIRouter, Depends, HTTPException
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
from litellm.proxy.utils import jsonify_object
from litellm.repositories.budget_repository import BudgetRepository
router = APIRouter()
@@ -98,7 +99,7 @@ async def new_budget(
budget_obj_json = budget_obj.model_dump(exclude_none=True)
budget_obj_jsonified = jsonify_object(budget_obj_json) # json dump any dictionaries
try:
response = await prisma_client.db.litellm_budgettable.create(
response = await BudgetRepository(prisma_client).table.create(
data={
**budget_obj_jsonified, # type: ignore
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
@@ -182,7 +183,7 @@ async def update_budget(
except ValueError as e:
raise HTTPException(status_code=400, detail={"error": str(e)})
response = await prisma_client.db.litellm_budgettable.update(
response = await BudgetRepository(prisma_client).table.update(
where={"budget_id": budget_obj.budget_id},
data={
**budget_obj.model_dump(exclude_unset=True), # type: ignore
@@ -217,7 +218,7 @@ async def info_budget(data: BudgetRequest):
"error": f"Specify list of budget id's to query. Passed in={data.budgets}"
},
)
response = await prisma_client.db.litellm_budgettable.find_many(
response = await BudgetRepository(prisma_client).table.find_many(
where={"budget_id": {"in": data.budgets}},
)
@@ -261,7 +262,7 @@ async def budget_settings(
)
## get budget item from db
db_budget_row = await prisma_client.db.litellm_budgettable.find_first(
db_budget_row = await BudgetRepository(prisma_client).table.find_first(
where={"budget_id": budget_id}
)
@@ -327,7 +328,7 @@ async def list_budget(
},
)
response = await prisma_client.db.litellm_budgettable.find_many()
response = await BudgetRepository(prisma_client).table.find_many()
return response
@@ -366,7 +367,7 @@ async def delete_budget(
},
)
response = await prisma_client.db.litellm_budgettable.delete(
response = await BudgetRepository(prisma_client).table.delete(
where={"budget_id": data.id}
)
@@ -27,6 +27,7 @@ from litellm.proxy._types import (
UserAPIKeyAuth,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.repositories.table_repositories import CacheConfigRepository
from litellm.types.management_endpoints import (
CACHE_SETTINGS_FIELDS,
REDIS_TYPE_DESCRIPTIONS,
@@ -159,7 +160,7 @@ class CacheSettingsManager:
import json
try:
cache_config = await prisma_client.db.litellm_cacheconfig.find_unique(
cache_config = await CacheConfigRepository(prisma_client).table.find_unique(
where={"id": "cache_config"}
)
if cache_config is not None and cache_config.cache_settings:
@@ -274,7 +275,7 @@ async def get_cache_settings(
# Try to get cache settings from database
current_values = {}
if prisma_client is not None:
cache_config = await prisma_client.db.litellm_cacheconfig.find_unique(
cache_config = await CacheConfigRepository(prisma_client).table.find_unique(
where={"id": "cache_config"}
)
if cache_config is not None and cache_config.cache_settings:
@@ -417,7 +418,7 @@ async def update_cache_settings(
# Snapshot the prior settings (key set only — values get redacted in
# the audit row) so the audit-log entry shows which fields changed.
existing_row = await prisma_client.db.litellm_cacheconfig.find_unique(
existing_row = await CacheConfigRepository(prisma_client).table.find_unique(
where={"id": "cache_config"}
)
before_settings: Optional[Dict[str, Any]] = None
@@ -434,7 +435,7 @@ async def update_cache_settings(
)
# Save to database
await prisma_client.db.litellm_cacheconfig.upsert(
await CacheConfigRepository(prisma_client).table.upsert(
where={"id": "cache_config"},
data={
"create": {
@@ -8,6 +8,10 @@ from fastapi import HTTPException, status
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import CommonProxyErrors
from litellm.proxy.utils import PrismaClient
from litellm.repositories.table_repositories import DeletedVerificationTokenRepository
from litellm.repositories.verification_token_repository import (
VerificationTokenRepository,
)
from litellm.types.proxy.management_endpoints.common_daily_activity import (
BreakdownMetrics,
DailySpendData,
@@ -346,7 +350,7 @@ async def get_api_key_metadata(
This ensures that key_alias and team_id are preserved in historical activity logs
even after a key is deleted or regenerated.
"""
key_records = await prisma_client.db.litellm_verificationtoken.find_many(
key_records = await VerificationTokenRepository(prisma_client).table.find_many(
where={"token": {"in": list(api_keys)}}
)
result = {
@@ -357,11 +361,11 @@ async def get_api_key_metadata(
missing_keys = api_keys - set(result.keys())
if missing_keys:
try:
deleted_key_records = (
await prisma_client.db.litellm_deletedverificationtoken.find_many(
where={"token": {"in": list(missing_keys)}},
order={"deleted_at": "desc"},
)
deleted_key_records = await DeletedVerificationTokenRepository(
prisma_client
).table.find_many(
where={"token": {"in": list(missing_keys)}},
order={"deleted_at": "desc"},
)
# Use the most recent deleted record for each token (ordered by deleted_at desc)
for k in deleted_key_records:
@@ -17,10 +17,13 @@ from litellm.proxy._types import (
NewProjectRequest,
UpdateProjectRequest,
UserAPIKeyAuth,
user_api_key_has_admin_view as _user_has_admin_view, # noqa: F401 re-exported
)
from litellm.proxy._types import ( # noqa: F401 re-exported
user_api_key_has_admin_view as _user_has_admin_view,
)
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
from litellm.proxy.utils import _premium_user_check
from litellm.repositories.team_repository import TeamRepository
if TYPE_CHECKING:
from litellm.proxy._types import NewProjectRequest, UpdateProjectRequest
@@ -205,7 +208,7 @@ async def _user_has_admin_privileges(
# Check if user is team admin for any team
if user_obj.teams is not None and len(user_obj.teams) > 0:
# Get all teams user is in
teams = await prisma_client.db.litellm_teamtable.find_many(
teams = await TeamRepository(prisma_client).table.find_many(
where={"team_id": {"in": user_obj.teams}}
)
@@ -282,7 +285,7 @@ async def _team_admin_can_invite_user(
if not target_user_obj.teams or len(target_user_obj.teams) == 0:
return False
teams = await prisma_client.db.litellm_teamtable.find_many(
teams = await TeamRepository(prisma_client).table.find_many(
where={"team_id": {"in": admin_user_obj.teams}}
)
admin_team_ids = [
@@ -30,6 +30,7 @@ from litellm.proxy._types import (
UserAPIKeyAuth,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.repositories.table_repositories import ConfigOverridesRepository
from litellm.types.llms.custom_http import httpxSpecialProvider
from litellm.types.proxy.management_endpoints.config_overrides import (
ConfigOverrideSettingsResponse,
@@ -254,7 +255,7 @@ async def update_hashicorp_vault_config(
# Merge ALL fields the user didn't send: try DB first, fall back to env vars.
# Omitted field = keep existing; empty string = clear/remove the field.
existing_record = await prisma_client.db.litellm_configoverrides.find_unique(
existing_record = await ConfigOverridesRepository(prisma_client).table.find_unique(
where={"config_type": "hashicorp_vault"}
)
existing_decrypted: Optional[Dict[str, Any]] = None
@@ -321,7 +322,7 @@ async def update_hashicorp_vault_config(
# Only persist to DB after successful init
encrypted_data = proxy_config._encrypt_env_variables(config_data)
config_value = safe_dumps(encrypted_data)
await prisma_client.db.litellm_configoverrides.upsert(
await ConfigOverridesRepository(prisma_client).table.upsert(
where={"config_type": "hashicorp_vault"},
data={
"create": {
@@ -391,7 +392,7 @@ async def get_hashicorp_vault_config(
field_schema = _build_field_schema(HashicorpVaultConfig)
# Try to load from DB
db_record = await prisma_client.db.litellm_configoverrides.find_unique(
db_record = await ConfigOverridesRepository(prisma_client).table.find_unique(
where={"config_type": "hashicorp_vault"}
)
@@ -448,7 +449,7 @@ async def delete_hashicorp_vault_config(
# Capture the prior config before delete so the audit-log row can
# show *what* was removed (keys only — values get redacted).
existing_record = await prisma_client.db.litellm_configoverrides.find_unique(
existing_record = await ConfigOverridesRepository(prisma_client).table.find_unique(
where={"config_type": "hashicorp_vault"}
)
before_config: Optional[Dict[str, Any]] = None
@@ -463,7 +464,7 @@ async def delete_hashicorp_vault_config(
# Delete DB record if it exists — ignore if not found
deleted = False
try:
await prisma_client.db.litellm_configoverrides.delete(
await ConfigOverridesRepository(prisma_client).table.delete(
where={"config_type": "hashicorp_vault"}
)
deleted = True
@@ -17,8 +17,8 @@ import fastapi
from fastapi import APIRouter, Depends, HTTPException, Request
import litellm
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.common_daily_activity import get_daily_activity
@@ -27,6 +27,8 @@ from litellm.proxy.management_helpers.object_permission_utils import (
handle_update_object_permission_common,
)
from litellm.proxy.utils import handle_exception_on_proxy
from litellm.repositories.budget_repository import BudgetRepository
from litellm.repositories.table_repositories import EndUserRepository
from litellm.types.proxy.management_endpoints.common_daily_activity import (
SpendAnalyticsPaginatedResponse,
)
@@ -68,7 +70,7 @@ async def block_user(data: BlockUsers):
records = []
if prisma_client is not None:
for id in data.user_ids:
record = await prisma_client.db.litellm_endusertable.upsert(
record = await EndUserRepository(prisma_client).table.upsert(
where={"user_id": id}, # type: ignore
data={
"create": {"user_id": id, "blocked": True}, # type: ignore
@@ -337,7 +339,7 @@ async def new_end_user(
_new_budget = new_budget_request(data)
if _new_budget is not None:
try:
budget_record = await prisma_client.db.litellm_budgettable.create(
budget_record = await BudgetRepository(prisma_client).table.create(
data={
**_new_budget.model_dump(exclude_unset=True),
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, # type: ignore
@@ -373,7 +375,7 @@ async def new_end_user(
new_end_user_obj.pop("object_permission", None)
## WRITE TO DB ##
end_user_record = await prisma_client.db.litellm_endusertable.create(
end_user_record = await EndUserRepository(prisma_client).table.create(
data=new_end_user_obj, # type: ignore
include={"litellm_budget_table": True, "object_permission": True},
)
@@ -446,7 +448,7 @@ async def end_user_info(
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
user_info = await prisma_client.db.litellm_endusertable.find_first(
user_info = await EndUserRepository(prisma_client).table.find_first(
where={"user_id": end_user_id},
include={"litellm_budget_table": True, "object_permission": True},
)
@@ -569,7 +571,7 @@ async def update_end_user(
non_default_values[k] = v
## Get end user table data ##
end_user_table_data = await prisma_client.db.litellm_endusertable.find_first(
end_user_table_data = await EndUserRepository(prisma_client).table.find_first(
where={"user_id": data.user_id}, include={"litellm_budget_table": True}
)
@@ -613,17 +615,17 @@ async def update_end_user(
if budget_table_data:
if end_user_budget_table is None:
## Create new budget ##
budget_table_data_record = (
await prisma_client.db.litellm_budgettable.create(
data={
**budget_table_data,
"created_by": user_api_key_dict.user_id
or litellm_proxy_admin_name,
"updated_by": user_api_key_dict.user_id
or litellm_proxy_admin_name,
},
include={"end_users": True},
)
budget_table_data_record = await BudgetRepository(
prisma_client
).table.create(
data={
**budget_table_data,
"created_by": user_api_key_dict.user_id
or litellm_proxy_admin_name,
"updated_by": user_api_key_dict.user_id
or litellm_proxy_admin_name,
},
include={"end_users": True},
)
update_end_user_table_data["budget_id"] = (
@@ -631,11 +633,11 @@ async def update_end_user(
)
else:
## Update existing budget ##
budget_table_data_record = (
await prisma_client.db.litellm_budgettable.update(
where={"budget_id": end_user_budget_table.budget_id},
data=budget_table_data,
)
budget_table_data_record = await BudgetRepository(
prisma_client
).table.update(
where={"budget_id": end_user_budget_table.budget_id},
data=budget_table_data,
)
## Update user table, with update params + new budget id (if set) ##
@@ -652,7 +654,7 @@ async def update_end_user(
if data.user_id is not None and len(data.user_id) > 0:
update_end_user_table_data["user_id"] = data.user_id # type: ignore
verbose_proxy_logger.debug("In update customer, user_id condition block.")
response = await prisma_client.db.litellm_endusertable.update(
response = await EndUserRepository(prisma_client).table.update(
where={"user_id": data.user_id}, data=update_end_user_table_data, include={"litellm_budget_table": True, "object_permission": True} # type: ignore
)
if response is None:
@@ -737,7 +739,7 @@ async def delete_end_user(
and len(data.user_ids) > 0
):
# First check if all users exist
existing_users = await prisma_client.db.litellm_endusertable.find_many(
existing_users = await EndUserRepository(prisma_client).table.find_many(
where={"user_id": {"in": data.user_ids}}
)
existing_user_ids = {user.user_id for user in existing_users}
@@ -756,7 +758,7 @@ async def delete_end_user(
)
# All users exist, proceed with deletion
response = await prisma_client.db.litellm_endusertable.delete_many(
response = await EndUserRepository(prisma_client).table.delete_many(
where={"user_id": {"in": data.user_ids}}
)
verbose_proxy_logger.debug(
@@ -828,7 +830,7 @@ async def list_end_user(
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
response = await prisma_client.db.litellm_endusertable.find_many(
response = await EndUserRepository(prisma_client).table.find_many(
include={"litellm_budget_table": True, "object_permission": True}
)
@@ -903,7 +905,7 @@ async def get_customer_daily_activity(
where_condition = {}
if end_user_ids_list:
where_condition["user_id"] = {"in": list(end_user_ids_list)}
end_user_aliases = await prisma_client.db.litellm_endusertable.find_many(
end_user_aliases = await EndUserRepository(prisma_client).table.find_many(
where=where_condition
)
end_user_alias_metadata = {e.user_id: {"alias": e.alias} for e in end_user_aliases}
@@ -27,6 +27,7 @@ else:
# fastapi is only required for proxy, not for SDK usage
pass
from litellm.repositories.config_repository import ConfigRepository
from litellm.types.management_endpoints.router_settings_endpoints import (
FallbackCreateRequest,
FallbackDeleteResponse,
@@ -157,7 +158,7 @@ async def create_fallback(
# Save to database - convert router_settings to JSON string
router_settings_json = json.dumps(router_settings)
await prisma_client.db.litellm_config.upsert(
await ConfigRepository(prisma_client).table.upsert(
where={"param_name": "router_settings"},
data={
"create": {
@@ -336,7 +337,7 @@ async def delete_fallback(
# Save to database - convert router_settings to JSON string
router_settings_json = json.dumps(router_settings)
await prisma_client.db.litellm_config.upsert(
await ConfigRepository(prisma_client).table.upsert(
where={"param_name": "router_settings"},
data={
"create": {
@@ -43,6 +43,17 @@ from litellm.proxy.management_endpoints.key_management_endpoints import (
)
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.utils import handle_exception_on_proxy, hash_password
from litellm.repositories.organization_repository import OrganizationRepository
from litellm.repositories.table_repositories import (
InvitationLinkRepository,
OrganizationMembershipRepository,
TeamMembershipRepository,
)
from litellm.repositories.team_repository import TeamRepository
from litellm.repositories.user_repository import UserRepository
from litellm.repositories.verification_token_repository import (
VerificationTokenRepository,
)
from litellm.types.proxy.management_endpoints.common_daily_activity import (
SpendAnalyticsPaginatedResponse,
)
@@ -154,7 +165,7 @@ async def _check_duplicate_user_field(
if case_insensitive:
where_clause[field_name]["mode"] = "insensitive"
existing_user = await prisma_client.db.litellm_usertable.find_first(
existing_user = await UserRepository(prisma_client).table.find_first(
where=where_clause
)
@@ -434,7 +445,7 @@ async def new_user(
await _check_duplicate_user_email(data.user_email, prisma_client)
# Check if license is over limit
total_users = await prisma_client.db.litellm_usertable.count()
total_users = await UserRepository(prisma_client).table.count()
if total_users and _license_check.is_over_limit(total_users=total_users):
raise HTTPException(
status_code=403,
@@ -851,7 +862,7 @@ async def _check_user_info_v2_access(
# Helper: fetch the target user row (reused across branches)
async def _fetch_target_user():
return await prisma_client.db.litellm_usertable.find_unique(
return await UserRepository(prisma_client).table.find_unique(
where={"user_id": target_user_id}
)
@@ -866,7 +877,7 @@ async def _check_user_info_v2_access(
# Rule 3: Team admins can look up users in their teams
if user_api_key_dict.user_id is not None:
# Get caller's teams
caller_user = await prisma_client.db.litellm_usertable.find_unique(
caller_user = await UserRepository(prisma_client).table.find_unique(
where={"user_id": user_api_key_dict.user_id}
)
if caller_user is not None and caller_user.teams:
@@ -876,7 +887,7 @@ async def _check_user_info_v2_access(
return None
# Get all teams the caller belongs to
teams = await prisma_client.db.litellm_teamtable.find_many(
teams = await TeamRepository(prisma_client).table.find_many(
where={"team_id": {"in": caller_user.teams}}
)
for team in teams:
@@ -1165,7 +1176,7 @@ async def _schedule_user_update_audit_log(
if prisma_client is None:
return
try:
updated_user_row = await prisma_client.db.litellm_usertable.find_first(
updated_user_row = await UserRepository(prisma_client).table.find_first(
where={"user_id": response["user_id"]}
)
if updated_user_row:
@@ -1255,11 +1266,11 @@ async def _update_single_user_helper(
existing_user_row: Optional[BaseModel] = None
if user_request.user_id:
existing_user_row = await prisma_client.db.litellm_usertable.find_first(
existing_user_row = await UserRepository(prisma_client).table.find_first(
where={"user_id": user_request.user_id}
)
elif user_request.user_email:
existing_user_row = await prisma_client.db.litellm_usertable.find_first(
existing_user_row = await UserRepository(prisma_client).table.find_first(
where={"user_email": user_request.user_email}
)
@@ -1640,7 +1651,7 @@ async def bulk_user_update(
detail="Only proxy admins can update all users at once.",
)
# Optimized path for updating all users directly in database
all_users_in_db = await prisma_client.db.litellm_usertable.find_many(
all_users_in_db = await UserRepository(prisma_client).table.find_many(
order={"created_at": "desc"}
)
@@ -1676,7 +1687,7 @@ async def bulk_user_update(
try:
# Perform bulk database update
await prisma_client.db.litellm_usertable.update_many(
await UserRepository(prisma_client).table.update_many(
where={}, data=non_default_values # Update all users
)
@@ -1783,7 +1794,7 @@ async def get_user_key_counts(
# Get count for each user_id individually
for user_id in user_ids:
count = await prisma_client.db.litellm_verificationtoken.count(
count = await VerificationTokenRepository(prisma_client).table.count(
where={
"user_id": user_id,
"OR": [
@@ -2056,7 +2067,7 @@ async def get_users(
else None
)
users = await prisma_client.db.litellm_usertable.find_many(
users = await UserRepository(prisma_client).table.find_many(
where=where_conditions,
skip=skip,
take=page_size,
@@ -2066,7 +2077,9 @@ async def get_users(
)
# Get total count of user rows
total_count = await prisma_client.db.litellm_usertable.count(where=where_conditions)
total_count = await UserRepository(prisma_client).table.count(
where=where_conditions
)
# Get key count for each user
if users is not None:
@@ -2137,14 +2150,14 @@ async def delete_user(
from litellm.proxy.management_endpoints.team_endpoints import (
_cleanup_members_with_roles,
)
from litellm.proxy.management_helpers.audit_logs import (
get_audit_log_changed_by,
)
from litellm.proxy.proxy_server import (
create_audit_log_for_update,
litellm_proxy_admin_name,
prisma_client,
)
from litellm.proxy.management_helpers.audit_logs import (
get_audit_log_changed_by,
)
if prisma_client is None:
raise HTTPException(status_code=500, detail={"error": "No db connected"})
@@ -2164,7 +2177,7 @@ async def delete_user(
caller_admin_org_ids: set = set()
if not caller_is_proxy_admin:
caller_memberships = (
await prisma_client.db.litellm_organizationmembership.find_many(
await OrganizationMembershipRepository(prisma_client).table.find_many(
where={
"user_id": user_api_key_dict.user_id,
"user_role": LitellmUserRoles.ORG_ADMIN.value,
@@ -2188,11 +2201,9 @@ async def delete_user(
# an N+1 DB call when delete_user is called with a large user_ids list.
target_org_ids_by_user: Dict[str, set] = {}
if not caller_is_proxy_admin:
all_target_memberships = (
await prisma_client.db.litellm_organizationmembership.find_many(
where={"user_id": {"in": data.user_ids}}
)
)
all_target_memberships = await OrganizationMembershipRepository(
prisma_client
).table.find_many(where={"user_id": {"in": data.user_ids}})
for m in all_target_memberships:
if not m.organization_id:
continue
@@ -2200,7 +2211,7 @@ async def delete_user(
# check that all teams passed exist
for user_id in data.user_ids:
user_row = await prisma_client.db.litellm_usertable.find_unique(
user_row = await UserRepository(prisma_client).table.find_unique(
where={"user_id": user_id}
)
@@ -2254,7 +2265,7 @@ async def delete_user(
)
## CLEANUP MEMBERS_WITH_ROLES
fetch_all_teams = await prisma_client.db.litellm_teamtable.find_many(
fetch_all_teams = await TeamRepository(prisma_client).table.find_many(
where={"team_id": {"in": user_row.teams}}
)
teams_to_update = []
@@ -2277,19 +2288,19 @@ async def delete_user(
## update teams
for team in teams_to_update:
await prisma_client.db.litellm_teamtable.update(
await TeamRepository(prisma_client).table.update(
where={"team_id": team.team_id},
data={"members_with_roles": team.members_with_roles},
)
# End of Audit logging
## DELETE ASSOCIATED KEYS
await prisma_client.db.litellm_verificationtoken.delete_many(
await VerificationTokenRepository(prisma_client).table.delete_many(
where={"user_id": {"in": data.user_ids}}
)
## DELETE ASSOCIATED INVITATION LINKS
await prisma_client.db.litellm_invitationlink.delete_many(
await InvitationLinkRepository(prisma_client).table.delete_many(
where={
"OR": [
{"user_id": {"in": data.user_ids}},
@@ -2300,17 +2311,17 @@ async def delete_user(
)
## DELETE ASSOCIATED ORGANIZATION MEMBERSHIPS
await prisma_client.db.litellm_organizationmembership.delete_many(
await OrganizationMembershipRepository(prisma_client).table.delete_many(
where={"user_id": {"in": data.user_ids}}
)
## DELETE ASSOCIATED TEAM MEMBERSHIPS
await prisma_client.db.litellm_teammembership.delete_many(
await TeamMembershipRepository(prisma_client).table.delete_many(
where={"user_id": {"in": data.user_ids}}
)
## DELETE USERS
deleted_users = await prisma_client.db.litellm_usertable.delete_many(
deleted_users = await UserRepository(prisma_client).table.delete_many(
where={"user_id": {"in": data.user_ids}}
)
@@ -2340,16 +2351,18 @@ async def add_internal_user_to_organization(
try:
# Check if organization_id exists
organization_row = await prisma_client.db.litellm_organizationtable.find_unique(
where={"organization_id": organization_id}
)
organization_row = await OrganizationRepository(
prisma_client
).table.find_unique(where={"organization_id": organization_id})
if organization_row is None:
raise Exception(
f"Organization not found, passed organization_id={organization_id}"
)
# Create a new organization membership entry
new_membership = await prisma_client.db.litellm_organizationmembership.create(
new_membership = await OrganizationMembershipRepository(
prisma_client
).table.create(
data={
"user_id": user_id,
"organization_id": organization_id,
@@ -2559,13 +2572,13 @@ async def ui_view_users(
}
# Query users with pagination and filters
users: Optional[List[BaseModel]] = (
await prisma_client.db.litellm_usertable.find_many(
where=where_conditions,
skip=skip,
take=page_size,
order={"created_at": "desc"},
)
users: Optional[List[BaseModel]] = await UserRepository(
prisma_client
).table.find_many(
where=where_conditions,
skip=skip,
take=page_size,
order={"created_at": "desc"},
)
if not users:
@@ -11,6 +11,7 @@ from litellm.proxy._types import (
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
from litellm.repositories.table_repositories import JWTKeyMappingRepository
router = APIRouter()
@@ -61,7 +62,7 @@ async def create_jwt_key_mapping(
if data.description is not None:
create_data["description"] = data.description
new_mapping = await prisma_client.db.litellm_jwtkeymapping.create(
new_mapping = await JWTKeyMappingRepository(prisma_client).table.create(
data=create_data
)
@@ -113,7 +114,7 @@ async def update_jwt_key_mapping(
try:
# Get old mapping for cache invalidation
old_mapping = await prisma_client.db.litellm_jwtkeymapping.find_unique(
old_mapping = await JWTKeyMappingRepository(prisma_client).table.find_unique(
where={"id": data.id}
)
@@ -123,7 +124,7 @@ async def update_jwt_key_mapping(
cache_key = f"jwt_key_mapping:{old_mapping.jwt_claim_name}:{old_mapping.jwt_claim_value}"
await user_api_key_cache.async_delete_cache(cache_key)
updated_mapping = await prisma_client.db.litellm_jwtkeymapping.update(
updated_mapping = await JWTKeyMappingRepository(prisma_client).table.update(
where={"id": data.id}, data=update_data
)
@@ -166,7 +167,7 @@ async def delete_jwt_key_mapping(
try:
# Get old mapping for cache invalidation
old_mapping = await prisma_client.db.litellm_jwtkeymapping.find_unique(
old_mapping = await JWTKeyMappingRepository(prisma_client).table.find_unique(
where={"id": data.id}
)
@@ -176,7 +177,7 @@ async def delete_jwt_key_mapping(
cache_key = f"jwt_key_mapping:{old_mapping.jwt_claim_name}:{old_mapping.jwt_claim_value}"
await user_api_key_cache.async_delete_cache(cache_key)
await prisma_client.db.litellm_jwtkeymapping.delete(where={"id": data.id})
await JWTKeyMappingRepository(prisma_client).table.delete(where={"id": data.id})
return {"status": "success"}
except HTTPException:
raise
@@ -206,12 +207,12 @@ async def list_jwt_key_mappings(
try:
skip = (page - 1) * size
mappings = await prisma_client.db.litellm_jwtkeymapping.find_many(
mappings = await JWTKeyMappingRepository(prisma_client).table.find_many(
skip=skip,
take=size,
order={"created_at": "desc"},
)
total_count = await prisma_client.db.litellm_jwtkeymapping.count()
total_count = await JWTKeyMappingRepository(prisma_client).table.count()
return {
"mappings": [_to_response(m) for m in mappings],
"total_count": total_count,
@@ -245,7 +246,7 @@ async def info_jwt_key_mapping(
raise HTTPException(status_code=500, detail="Database not connected")
try:
mapping = await prisma_client.db.litellm_jwtkeymapping.find_unique(
mapping = await JWTKeyMappingRepository(prisma_client).table.find_unique(
where={"id": id}
)
if mapping is None:
@@ -28,7 +28,6 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, s
import litellm
from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
from litellm.constants import (
LENGTH_OF_LITELLM_GENERATED_KEY,
LITELLM_PROXY_ADMIN_NAME,
@@ -58,6 +57,7 @@ from litellm.proxy.common_utils.callback_utils import (
)
from litellm.proxy.common_utils.rbac_utils import check_org_admin_can_generate_keys
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
from litellm.proxy.management_endpoints.common_utils import (
_check_passthrough_routes_caller_permission,
@@ -91,6 +91,19 @@ from litellm.proxy.utils import (
handle_exception_on_proxy,
is_valid_api_key,
)
from litellm.repositories.budget_repository import BudgetRepository
from litellm.repositories.config_repository import ConfigRepository
from litellm.repositories.credentials_repository import CredentialsRepository
from litellm.repositories.model_repository import ModelRepository
from litellm.repositories.table_repositories import (
DeletedVerificationTokenRepository,
DeprecatedVerificationTokenRepository,
)
from litellm.repositories.team_repository import TeamRepository
from litellm.repositories.user_repository import UserRepository
from litellm.repositories.verification_token_repository import (
VerificationTokenRepository,
)
from litellm.router import Router
from litellm.secret_managers.main import get_secret
from litellm.types.proxy.management_endpoints.key_management_endpoints import (
@@ -582,7 +595,7 @@ async def validate_team_id_used_in_service_account_request(
)
# check if team_id exists in the database
team = await prisma_client.db.litellm_teamtable.find_unique(
team = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": team_id},
)
if team is None:
@@ -774,7 +787,7 @@ async def _common_key_generation_helper( # noqa: PLR0915
)
new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True))
_budget = await prisma_client.db.litellm_budgettable.create(
_budget = await BudgetRepository(prisma_client).table.create(
data={
**new_budget, # type: ignore
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
@@ -1144,7 +1157,7 @@ async def _check_team_key_limits(
# calculate allocated tpm/rpm limit
# check if specified tpm/rpm limit is greater than allocated tpm/rpm limit
keys = await prisma_client.db.litellm_verificationtoken.find_many(
keys = await VerificationTokenRepository(prisma_client).table.find_many(
where={"team_id": team_table.team_id},
)
# Exclude the key being updated to avoid double-counting its limits.
@@ -1294,7 +1307,7 @@ async def _validate_caller_can_assign_key_org(
detail="Cannot assign a key to an organization without a user_id on the caller's token",
)
user_row = await prisma_client.db.litellm_usertable.find_unique(
user_row = await UserRepository(prisma_client).table.find_unique(
where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True},
)
@@ -1339,7 +1352,7 @@ async def _check_org_key_limits(
# get all organization keys
# calculate allocated tpm/rpm limit
# check if specified tpm/rpm limit is greater than allocated tpm/rpm limit
keys = await prisma_client.db.litellm_verificationtoken.find_many(
keys = await VerificationTokenRepository(prisma_client).table.find_many(
where={"organization_id": org_table.organization_id},
)
# Exclude the key being updated to avoid double-counting its limits.
@@ -1968,9 +1981,9 @@ async def _get_and_validate_existing_key(
hashed_token = _hash_token_if_needed(token=token)
existing_key_row = await prisma_client.db.litellm_verificationtoken.find_unique(
where={"token": hashed_token}
)
existing_key_row = await VerificationTokenRepository(
prisma_client
).table.find_unique(where={"token": hashed_token})
if existing_key_row is None:
raise ProxyException(
@@ -2869,7 +2882,9 @@ async def bulk_update_team_keys(
# `blocked` is Boolean? with no default; `/key/generate` writes NULL. Prisma's `NOT`
# excludes NULLs, so explicitly OR `false` with `null` to include them.
now = datetime.now(timezone.utc)
existing_keys = await prisma_client.db.litellm_verificationtoken.find_many(
existing_keys = await VerificationTokenRepository(
prisma_client
).table.find_many(
where={
"team_id": data.team_id,
"AND": [
@@ -2907,7 +2922,9 @@ async def bulk_update_team_keys(
seen_hashes.add(h)
requested_tokens.append(k)
hashed_key_ids.append(h)
existing_keys = await prisma_client.db.litellm_verificationtoken.find_many(
existing_keys = await VerificationTokenRepository(
prisma_client
).table.find_many(
where={"team_id": data.team_id, "token": {"in": hashed_key_ids}}
)
@@ -3232,7 +3249,9 @@ async def info_key_fn_v2(
# Resolve key_aliases to tokens so we never pass token=None (unbounded query)
tokens_to_query = list(data.keys) if data.keys else []
if data.key_aliases:
alias_rows = await prisma_client.db.litellm_verificationtoken.find_many(
alias_rows = await VerificationTokenRepository(
prisma_client
).table.find_many(
where={"key_alias": {"in": data.key_aliases}},
include={"litellm_budget_table": True},
)
@@ -3311,7 +3330,7 @@ async def info_key_fn(
hashed_key: Optional[str] = key
if key is not None:
hashed_key = _hash_token_if_needed(token=key)
key_info = await prisma_client.db.litellm_verificationtoken.find_unique(
key_info = await VerificationTokenRepository(prisma_client).table.find_unique(
where={"token": hashed_key}, # type: ignore
include={"litellm_budget_table": True},
)
@@ -3851,7 +3870,7 @@ async def delete_verification_tokens(
if prisma_client:
tokens = [_hash_token_if_needed(token=key) for key in tokens]
_keys_being_deleted: List[LiteLLM_VerificationToken] = (
await prisma_client.db.litellm_verificationtoken.find_many(
await VerificationTokenRepository(prisma_client).table.find_many(
where={"token": {"in": tokens}}
)
)
@@ -3989,7 +4008,9 @@ async def _save_deleted_verification_token_records(
"""Save deleted verification token records to the database."""
if not records:
return
await prisma_client.db.litellm_deletedverificationtoken.create_many(data=records)
await DeletedVerificationTokenRepository(prisma_client).table.create_many(
data=records
)
async def _persist_deleted_verification_tokens(
@@ -4017,9 +4038,9 @@ async def delete_key_aliases(
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
_keys_being_deleted = await prisma_client.db.litellm_verificationtoken.find_many(
where={"key_alias": {"in": key_aliases}}
)
_keys_being_deleted = await VerificationTokenRepository(
prisma_client
).table.find_many(where={"key_alias": {"in": key_aliases}})
tokens = [key.token for key in _keys_being_deleted]
return await delete_verification_tokens(
@@ -4054,9 +4075,7 @@ async def _rotate_master_key( # noqa: PLR0915
from litellm.proxy.proxy_server import proxy_config
try:
models: Optional[List] = (
await prisma_client.db.litellm_proxymodeltable.find_many()
)
models: Optional[List] = await ModelRepository(prisma_client).table.find_many()
except Exception:
models = None
# 2. process model table
@@ -4088,7 +4107,7 @@ async def _rotate_master_key( # noqa: PLR0915
)
# 3. process config table
try:
config = await prisma_client.db.litellm_config.find_many()
config = await ConfigRepository(prisma_client).table.find_many()
except Exception:
config = None
@@ -4109,7 +4128,7 @@ async def _rotate_master_key( # noqa: PLR0915
)
if encrypted_env_vars:
await prisma_client.db.litellm_config.update(
await ConfigRepository(prisma_client).table.update(
where={"param_name": "environment_variables"},
data={"param_value": prisma.Json(encrypted_env_vars)}, # type: ignore[attr-defined]
)
@@ -4148,7 +4167,7 @@ async def _rotate_master_key( # noqa: PLR0915
# 5. process credentials table
try:
credentials = await prisma_client.db.litellm_credentialstable.find_many()
credentials = await CredentialsRepository(prisma_client).table.find_many()
except Exception:
credentials = None
if credentials:
@@ -4171,7 +4190,7 @@ async def _rotate_master_key( # noqa: PLR0915
_cred_data["credential_info"] = prisma.Json( # type: ignore[attr-defined]
_cred_data["credential_info"]
)
await prisma_client.db.litellm_credentialstable.update(
await CredentialsRepository(prisma_client).table.update(
where={"credential_name": cred.credential_name},
data={
**_cred_data,
@@ -4243,7 +4262,7 @@ async def _insert_deprecated_key(
try:
revoke_at = datetime.now(timezone.utc) + timedelta(seconds=grace_seconds)
await prisma_client.db.litellm_deprecatedverificationtoken.upsert(
await DeprecatedVerificationTokenRepository(prisma_client).table.upsert(
where={"token": old_token_hash},
data={
"create": {
@@ -4335,7 +4354,7 @@ async def _execute_virtual_key_regeneration(
grace_period=data.grace_period if data else None,
)
updated_token = await prisma_client.db.litellm_verificationtoken.update(
updated_token = await VerificationTokenRepository(prisma_client).table.update(
where={"token": hashed_api_key},
data=update_data, # type: ignore
)
@@ -4530,7 +4549,7 @@ async def regenerate_key_fn( # noqa: PLR0915
else:
hashed_api_key = hash_token(key)
_key_in_db = await prisma_client.db.litellm_verificationtoken.find_unique(
_key_in_db = await VerificationTokenRepository(prisma_client).table.find_unique(
where={"token": hashed_api_key},
)
if _key_in_db is None:
@@ -4719,7 +4738,7 @@ async def reset_key_spend_fn(
else:
hashed_api_key = hash_token(key)
_key_in_db = await prisma_client.db.litellm_verificationtoken.find_unique(
_key_in_db = await VerificationTokenRepository(prisma_client).table.find_unique(
where={"token": hashed_api_key},
include={"litellm_budget_table": True},
)
@@ -4739,7 +4758,7 @@ async def reset_key_spend_fn(
user_api_key_cache=user_api_key_cache,
)
updated_key = await prisma_client.db.litellm_verificationtoken.update(
updated_key = await VerificationTokenRepository(prisma_client).table.update(
where={"token": hashed_api_key},
data={"spend": reset_to},
)
@@ -4792,11 +4811,11 @@ async def validate_key_list_check(
param="user_id",
code=status.HTTP_403_FORBIDDEN,
)
complete_user_info_db_obj: Optional[BaseModel] = (
await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True},
)
complete_user_info_db_obj: Optional[BaseModel] = await UserRepository(
prisma_client
).table.find_unique(
where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True},
)
if complete_user_info_db_obj is None:
@@ -4846,7 +4865,9 @@ async def validate_key_list_check(
if key_hash:
try:
key_info = await prisma_client.db.litellm_verificationtoken.find_unique(
key_info = await VerificationTokenRepository(
prisma_client
).table.find_unique(
where={"token": key_hash},
)
except Exception:
@@ -4879,11 +4900,9 @@ async def _fetch_user_team_objects(
if complete_user_info is None or not complete_user_info.teams:
return []
teams: Optional[List[BaseModel]] = (
await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": complete_user_info.teams}}
)
)
teams: Optional[List[BaseModel]] = await TeamRepository(
prisma_client
).table.find_many(where={"team_id": {"in": complete_user_info.teams}})
if teams is None:
return []
@@ -5160,7 +5179,7 @@ async def _apply_non_admin_alias_scope(
# Look up the user's teams from the user table
user_teams: List[str] = []
if user_api_key_dict.user_id:
user_row = await prisma_client.db.litellm_usertable.find_unique(
user_row = await UserRepository(prisma_client).table.find_unique(
where={"user_id": user_api_key_dict.user_id}
)
if user_row is not None:
@@ -5548,7 +5567,7 @@ async def _list_key_helper(
# Fetch keys with pagination
if use_deleted_table:
keys = await prisma_client.db.litellm_deletedverificationtoken.find_many(
keys = await DeletedVerificationTokenRepository(prisma_client).table.find_many(
where=where, # type: ignore
skip=skip, # type: ignore
take=size, # type: ignore
@@ -5562,7 +5581,7 @@ async def _list_key_helper(
),
)
else:
keys = await prisma_client.db.litellm_verificationtoken.find_many(
keys = await VerificationTokenRepository(prisma_client).table.find_many(
where=where, # type: ignore
skip=skip, # type: ignore
take=size, # type: ignore
@@ -5581,11 +5600,13 @@ async def _list_key_helper(
# Get total count of keys
if use_deleted_table:
total_count = await prisma_client.db.litellm_deletedverificationtoken.count(
total_count = await DeletedVerificationTokenRepository(
prisma_client
).table.count(
where=where # type: ignore
)
else:
total_count = await prisma_client.db.litellm_verificationtoken.count(
total_count = await VerificationTokenRepository(prisma_client).table.count(
where=where # type: ignore
)
@@ -5601,7 +5622,7 @@ async def _list_key_helper(
created_by_ids = [key.created_by for key in keys if key.created_by]
all_ids = list(set(user_ids + created_by_ids)) # Remove duplicates
if all_ids:
users = await prisma_client.db.litellm_usertable.find_many(
users = await UserRepository(prisma_client).table.find_many(
where={"user_id": {"in": all_ids}}
)
user_map = {user.user_id: user for user in users}
@@ -5688,7 +5709,7 @@ async def _check_key_admin_access(
return
# Look up the target key to find its team
target_key_row = await prisma_client.db.litellm_verificationtoken.find_unique(
target_key_row = await VerificationTokenRepository(prisma_client).table.find_unique(
where={"token": hashed_token}
)
if target_key_row is None:
@@ -5755,6 +5776,9 @@ async def block_key(
Note: This is an admin-only endpoint. Only proxy admins, team admins, or org admins can block keys.
"""
from litellm.proxy.management_helpers.audit_logs import (
get_audit_log_changed_by,
)
from litellm.proxy.proxy_server import (
create_audit_log_for_update,
hash_token,
@@ -5763,9 +5787,6 @@ async def block_key(
proxy_logging_obj,
user_api_key_cache,
)
from litellm.proxy.management_helpers.audit_logs import (
get_audit_log_changed_by,
)
if prisma_client is None:
raise Exception("{}".format(CommonProxyErrors.db_not_connected_error.value))
@@ -5792,9 +5813,9 @@ async def block_key(
)
# Check if the key exists before trying to block it
existing_record = await prisma_client.db.litellm_verificationtoken.find_unique(
where={"token": hashed_token}
)
existing_record = await VerificationTokenRepository(
prisma_client
).table.find_unique(where={"token": hashed_token})
if existing_record is None:
raise ProxyException(
message="Key not found.",
@@ -5824,7 +5845,7 @@ async def block_key(
)
)
record = await prisma_client.db.litellm_verificationtoken.update(
record = await VerificationTokenRepository(prisma_client).table.update(
where={"token": hashed_token}, data={"blocked": True} # type: ignore
)
@@ -5869,6 +5890,9 @@ async def unblock_key(
Note: This is an admin-only endpoint. Only proxy admins, team admins, or org admins can unblock keys.
"""
from litellm.proxy.management_helpers.audit_logs import (
get_audit_log_changed_by,
)
from litellm.proxy.proxy_server import (
create_audit_log_for_update,
hash_token,
@@ -5877,9 +5901,6 @@ async def unblock_key(
proxy_logging_obj,
user_api_key_cache,
)
from litellm.proxy.management_helpers.audit_logs import (
get_audit_log_changed_by,
)
if prisma_client is None:
raise Exception("{}".format(CommonProxyErrors.db_not_connected_error.value))
@@ -5906,9 +5927,9 @@ async def unblock_key(
)
# Check if the key exists before trying to unblock it
existing_record = await prisma_client.db.litellm_verificationtoken.find_unique(
where={"token": hashed_token}
)
existing_record = await VerificationTokenRepository(
prisma_client
).table.find_unique(where={"token": hashed_token})
if existing_record is None:
raise ProxyException(
message="Key not found.",
@@ -5938,7 +5959,7 @@ async def unblock_key(
)
)
record = await prisma_client.db.litellm_verificationtoken.update(
record = await VerificationTokenRepository(prisma_client).table.update(
where={"token": hashed_token}, data={"blocked": False} # type: ignore
)
@@ -6202,9 +6223,9 @@ async def _enforce_unique_key_alias(
# Exclude the current key from the uniqueness check
where_clause["NOT"] = {"token": existing_key_token}
existing_key = await prisma_client.db.litellm_verificationtoken.find_first(
where=where_clause
)
existing_key = await VerificationTokenRepository(
prisma_client
).table.find_first(where=where_clause)
if existing_key is not None:
raise ProxyException(
message=f"Key with alias '{key_alias}' already exists. Unique key aliases across all keys are required.",
@@ -60,6 +60,10 @@ from litellm.proxy.common_utils.encrypt_decrypt_utils import (
encrypt_value_helper,
)
from litellm.proxy.management_helpers.audit_logs import get_audit_log_changed_by
from litellm.repositories.table_repositories import (
MCPServerRepository,
MCPUserCredentialsRepository,
)
router = APIRouter(prefix="/v1/mcp", tags=["mcp"])
@@ -761,7 +765,7 @@ if MCP_AVAILABLE:
# Get from DB
if prisma_client is not None:
try:
mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many()
mcp_servers = await MCPServerRepository(prisma_client).table.find_many()
for server in mcp_servers:
if (
hasattr(server, "mcp_access_groups")
@@ -998,10 +1002,10 @@ if MCP_AVAILABLE:
if getattr(s, "is_byok", False)
]
if byok_server_ids:
cred_rows = (
await _byok_prisma_client.db.litellm_mcpusercredentials.find_many(
where={"user_id": user_id, "server_id": {"in": byok_server_ids}}
)
cred_rows = await MCPUserCredentialsRepository(
_byok_prisma_client
).table.find_many(
where={"user_id": user_id, "server_id": {"in": byok_server_ids}}
)
cred_set = {r.server_id for r in cred_rows}
for server in redacted_mcp_servers:
@@ -19,6 +19,7 @@ from litellm.proxy.management_endpoints.model_management_endpoints import (
clear_cache,
)
from litellm.proxy.utils import PrismaClient
from litellm.repositories.model_repository import ModelRepository
from litellm.types.proxy.management_endpoints.model_management_endpoints import (
AccessGroupInfo,
DeleteModelGroupResponse,
@@ -95,7 +96,7 @@ async def update_deployments_with_access_group(
verbose_proxy_logger.debug(f"Updating deployments for model_name: {model_name}")
# Get all deployments with this model_name
deployments = await prisma_client.db.litellm_proxymodeltable.find_many(
deployments = await ModelRepository(prisma_client).table.find_many(
where={"model_name": model_name}
)
@@ -124,7 +125,7 @@ async def update_deployments_with_access_group(
# Only update in DB if modified
if was_modified:
await prisma_client.db.litellm_proxymodeltable.update(
await ModelRepository(prisma_client).table.update(
where={"model_id": deployment.model_id},
data={"model_info": json.dumps(updated_model_info)},
)
@@ -152,7 +153,7 @@ async def update_specific_deployments_with_access_group(
models_updated = 0
for model_id in model_ids:
verbose_proxy_logger.debug(f"Updating specific deployment model_id: {model_id}")
deployment = await prisma_client.db.litellm_proxymodeltable.find_unique(
deployment = await ModelRepository(prisma_client).table.find_unique(
where={"model_id": model_id}
)
if deployment is None:
@@ -168,7 +169,7 @@ async def update_specific_deployments_with_access_group(
access_group=access_group,
)
if was_modified:
await prisma_client.db.litellm_proxymodeltable.update(
await ModelRepository(prisma_client).table.update(
where={"model_id": model_id},
data={"model_info": json.dumps(updated_model_info)},
)
@@ -215,7 +216,7 @@ async def get_all_access_groups_from_db(
Dict[str, AccessGroupInfo]: Dictionary mapping access_group name to info
"""
# Get all deployments
deployments = await prisma_client.db.litellm_proxymodeltable.find_many()
deployments = await ModelRepository(prisma_client).table.find_many()
# Build access group map
access_group_map: Dict[str, Dict[str, Any]] = {}
@@ -604,7 +605,7 @@ async def update_access_group(
try:
# Step 1: Remove access group from ALL DB deployments (skip config models)
all_deployments = await prisma_client.db.litellm_proxymodeltable.find_many()
all_deployments = await ModelRepository(prisma_client).table.find_many()
for deployment in all_deployments:
model_info = deployment.model_info or {}
@@ -615,7 +616,7 @@ async def update_access_group(
)
if was_modified:
await prisma_client.db.litellm_proxymodeltable.update(
await ModelRepository(prisma_client).table.update(
where={"model_id": deployment.model_id},
data={"model_info": json.dumps(updated_model_info)},
)
@@ -722,7 +723,7 @@ async def delete_access_group(
try:
# Remove access group from all DB deployments (skip config models)
all_deployments = await prisma_client.db.litellm_proxymodeltable.find_many()
all_deployments = await ModelRepository(prisma_client).table.find_many()
models_updated = 0
for deployment in all_deployments:
@@ -734,7 +735,7 @@ async def delete_access_group(
)
if was_modified:
await prisma_client.db.litellm_proxymodeltable.update(
await ModelRepository(prisma_client).table.update(
where={"model_id": deployment.model_id},
data={"model_info": json.dumps(updated_model_info)},
)
@@ -48,6 +48,9 @@ from litellm.proxy.management_endpoints.team_endpoints import (
)
from litellm.proxy.management_helpers.audit_logs import create_object_audit_log
from litellm.proxy.utils import PrismaClient
from litellm.repositories.model_repository import ModelRepository
from litellm.repositories.table_repositories import ModelTableRepository
from litellm.repositories.team_repository import TeamRepository
from litellm.types.proxy.management_endpoints.model_management_endpoints import (
UpdateUsefulLinksRequest,
)
@@ -86,7 +89,7 @@ async def get_db_model(
) -> Optional[Deployment]:
db_model = cast(
Optional[BaseModel],
await prisma_client.db.litellm_proxymodeltable.find_unique(
await ModelRepository(prisma_client).table.find_unique(
where={"model_id": model_id}
),
)
@@ -290,7 +293,7 @@ async def patch_model(
update_data["updated_at"] = cast(str, get_utc_datetime())
# Perform partial update
updated_model = await prisma_client.db.litellm_proxymodeltable.update(
updated_model = await ModelRepository(prisma_client).table.update(
where={"model_id": model_id},
data=update_data,
)
@@ -362,7 +365,7 @@ async def _add_model_to_db(
if model_params.model_info.id is not None:
_data["model_id"] = model_params.model_info.id
if should_create_model_in_db:
model_response = await prisma_client.db.litellm_proxymodeltable.create(
model_response = await ModelRepository(prisma_client).table.create(
data=_data # type: ignore
)
else:
@@ -571,7 +574,7 @@ async def _get_team_deployments(
team_id in model_info with Python-side filtering.
"""
prefix = f"model_name_{team_id}_"
response = await prisma_client.db.litellm_proxymodeltable.find_many(
response = await ModelRepository(prisma_client).table.find_many(
where={
"model_name": {"startswith": prefix},
}
@@ -828,7 +831,7 @@ class ModelManagementAuthChecks:
detail={"error": CommonProxyErrors.not_premium_user.value},
)
_existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
_existing_team_row = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": model_params.model_info.team_id}
)
@@ -863,7 +866,7 @@ class ModelManagementAuthChecks:
model_params.model_info is not None
and model_params.model_info.team_id is not None
):
team_obj_row = await prisma_client.db.litellm_teamtable.find_unique(
team_obj_row = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": model_params.model_info.team_id}
)
if team_obj_row is None:
@@ -937,7 +940,7 @@ async def delete_model(
},
)
model_in_db = await prisma_client.db.litellm_proxymodeltable.find_unique(
model_in_db = await ModelRepository(prisma_client).table.find_unique(
where={"model_id": model_info.id}
)
if model_in_db is None:
@@ -961,7 +964,7 @@ async def delete_model(
- store keys separately
"""
# encrypt litellm params #
result = await prisma_client.db.litellm_proxymodeltable.delete(
result = await ModelRepository(prisma_client).table.delete(
where={"model_id": model_info.id}
)
@@ -1039,7 +1042,7 @@ async def delete_team_model_alias(
Returns:
- List of team id + model alias pairs that were removed
"""
team_model_aliases = await prisma_client.db.litellm_modeltable.find_many(
team_model_aliases = await ModelTableRepository(prisma_client).table.find_many(
include={"team": True}
)
tasks = []
@@ -1056,7 +1059,7 @@ async def delete_team_model_alias(
removed_model_aliases.append((team_model_alias.team.team_id, key))
del model_aliases[key]
tasks.append(
prisma_client.db.litellm_modeltable.update(
ModelTableRepository(prisma_client).table.update(
where={"id": id},
data={"model_aliases": json.dumps(model_aliases)},
)
@@ -1275,11 +1278,9 @@ async def update_model(
if _model_id is None:
raise Exception("model_info.id not provided")
_existing_litellm_params = (
await prisma_client.db.litellm_proxymodeltable.find_unique(
where={"model_id": _model_id}
)
)
_existing_litellm_params = await ModelRepository(
prisma_client
).table.find_unique(where={"model_id": _model_id})
if _existing_litellm_params is None:
if (
@@ -1340,7 +1341,7 @@ async def update_model(
"litellm_params": json.dumps(merged_dictionary), # type: ignore
"updated_by": user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME,
}
model_response = await prisma_client.db.litellm_proxymodeltable.update(
model_response = await ModelRepository(prisma_client).table.update(
where={"model_id": _model_id},
data=_data, # type: ignore
)
@@ -40,6 +40,15 @@ from litellm.proxy.management_helpers.utils import (
management_endpoint_wrapper,
)
from litellm.proxy.utils import PrismaClient
from litellm.repositories.budget_repository import BudgetRepository
from litellm.repositories.object_permission_repository import ObjectPermissionRepository
from litellm.repositories.organization_repository import OrganizationRepository
from litellm.repositories.table_repositories import OrganizationMembershipRepository
from litellm.repositories.team_repository import TeamRepository
from litellm.repositories.user_repository import UserRepository
from litellm.repositories.verification_token_repository import (
VerificationTokenRepository,
)
from litellm.types.proxy.management_endpoints.common_daily_activity import (
SpendAnalyticsPaginatedResponse,
)
@@ -245,7 +254,7 @@ async def new_organization(
if user_api_key_dict.user_id is not None:
try:
user_object = await prisma_client.db.litellm_usertable.find_unique(
user_object = await UserRepository(prisma_client).table.find_unique(
where={"user_id": user_api_key_dict.user_id}
)
user_object_correct_type = LiteLLM_UserTable(**user_object.model_dump())
@@ -267,7 +276,7 @@ async def new_organization(
new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True))
_budget = await prisma_client.db.litellm_budgettable.create(
_budget = await BudgetRepository(prisma_client).table.create(
data={
**new_budget, # type: ignore
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
@@ -323,7 +332,7 @@ async def new_organization(
verbose_proxy_logger.info(
f"new_organization_row: {json.dumps(new_organization_row, indent=2)}"
)
response = await prisma_client.db.litellm_organizationtable.create(
response = await OrganizationRepository(prisma_client).table.create(
data={
**new_organization_row, # type: ignore
},
@@ -372,9 +381,9 @@ async def get_organization_daily_activity(
# Restrict non-proxy-admins to only organizations where they are org_admin
if not _user_has_admin_view(user_api_key_dict):
memberships = await prisma_client.db.litellm_organizationmembership.find_many(
where={"user_id": user_api_key_dict.user_id}
)
memberships = await OrganizationMembershipRepository(
prisma_client
).table.find_many(where={"user_id": user_api_key_dict.user_id})
admin_org_ids = [
m.organization_id
for m in memberships
@@ -400,7 +409,7 @@ async def get_organization_daily_activity(
where_condition = {}
if org_ids_list:
where_condition["organization_id"] = {"in": list(org_ids_list)}
org_aliases = await prisma_client.db.litellm_organizationtable.find_many(
org_aliases = await OrganizationRepository(prisma_client).table.find_many(
where=where_condition
)
org_alias_metadata = {
@@ -439,10 +448,10 @@ async def _set_object_permission(
return None
if data.object_permission is not None:
created_object_permission = (
await prisma_client.db.litellm_objectpermissiontable.create(
data=data.object_permission.model_dump(exclude_none=True),
)
created_object_permission = await ObjectPermissionRepository(
prisma_client
).table.create(
data=data.object_permission.model_dump(exclude_none=True),
)
del data.object_permission
return created_object_permission.object_permission_id
@@ -525,10 +534,10 @@ async def update_organization(
prisma_client=prisma_client,
)
existing_organization_row = (
await prisma_client.db.litellm_organizationtable.find_unique(
where={"organization_id": data.organization_id},
)
existing_organization_row = await OrganizationRepository(
prisma_client
).table.find_unique(
where={"organization_id": data.organization_id},
)
if existing_organization_row is None:
@@ -574,7 +583,7 @@ async def update_organization(
for field in LiteLLM_BudgetTable.model_fields.keys():
updated_organization_row.pop(field, None)
response = await prisma_client.db.litellm_organizationtable.update(
response = await OrganizationRepository(prisma_client).table.update(
where={"organization_id": data.organization_id},
data=updated_organization_row,
include={"members": True, "teams": True, "litellm_budget_table": True},
@@ -644,19 +653,19 @@ async def delete_organization(
deleted_orgs = []
for organization_id in data.organization_ids:
# delete all teams in the organization
await prisma_client.db.litellm_teamtable.delete_many(
await TeamRepository(prisma_client).table.delete_many(
where={"organization_id": organization_id}
)
# delete all members in the organization
await prisma_client.db.litellm_organizationmembership.delete_many(
await OrganizationMembershipRepository(prisma_client).table.delete_many(
where={"organization_id": organization_id}
)
# delete all keys in the organization
await prisma_client.db.litellm_verificationtoken.delete_many(
await VerificationTokenRepository(prisma_client).table.delete_many(
where={"organization_id": organization_id}
)
# delete the organization
deleted_org = await prisma_client.db.litellm_organizationtable.delete(
deleted_org = await OrganizationRepository(prisma_client).table.delete(
where={"organization_id": organization_id},
include={"members": True, "teams": True, "litellm_budget_table": True},
)
@@ -732,17 +741,15 @@ async def list_organization(
# if proxy admin or admin viewer - get all orgs (with optional filters)
if _user_has_admin_view(user_api_key_dict):
response = await prisma_client.db.litellm_organizationtable.find_many(
response = await OrganizationRepository(prisma_client).table.find_many(
where=where_conditions if where_conditions else None,
include={"litellm_budget_table": True, "members": True, "teams": True},
)
# if internal user - get orgs they are a member of (with optional filters)
else:
org_memberships = (
await prisma_client.db.litellm_organizationmembership.find_many(
where={"user_id": user_api_key_dict.user_id}
)
)
org_memberships = await OrganizationMembershipRepository(
prisma_client
).table.find_many(where={"user_id": user_api_key_dict.user_id})
membership_org_ids = [
membership.organization_id for membership in org_memberships
]
@@ -756,20 +763,20 @@ async def list_organization(
response = []
else:
where_conditions["organization_id"] = org_id
response = (
await prisma_client.db.litellm_organizationtable.find_many(
where=where_conditions,
include={
"litellm_budget_table": True,
"members": True,
"teams": True,
},
)
response = await OrganizationRepository(
prisma_client
).table.find_many(
where=where_conditions,
include={
"litellm_budget_table": True,
"members": True,
"teams": True,
},
)
else:
# Filter by membership and any additional filters
where_conditions["organization_id"] = {"in": membership_org_ids}
response = await prisma_client.db.litellm_organizationtable.find_many(
response = await OrganizationRepository(prisma_client).table.find_many(
where=where_conditions,
include={
"litellm_budget_table": True,
@@ -809,20 +816,20 @@ async def info_organization(
prisma_client=prisma_client,
)
response: Optional[LiteLLM_OrganizationTableWithMembers] = (
await prisma_client.db.litellm_organizationtable.find_unique(
where={"organization_id": organization_id},
include={
"litellm_budget_table": True,
"members": {
"include": {
"user": True,
}
},
"teams": True,
"object_permission": True,
response: Optional[
LiteLLM_OrganizationTableWithMembers
] = await OrganizationRepository(prisma_client).table.find_unique(
where={"organization_id": organization_id},
include={
"litellm_budget_table": True,
"members": {
"include": {
"user": True,
}
},
)
"teams": True,
"object_permission": True,
},
)
if response is None:
@@ -868,7 +875,7 @@ async def deprecated_info_organization(
prisma_client=prisma_client,
)
response = await prisma_client.db.litellm_organizationtable.find_many(
response = await OrganizationRepository(prisma_client).table.find_many(
where={"organization_id": {"in": data.organizations}},
include={"litellm_budget_table": True},
)
@@ -945,11 +952,9 @@ async def organization_member_add(
)
# Check if organization exists
existing_organization_row = (
await prisma_client.db.litellm_organizationtable.find_unique(
where={"organization_id": data.organization_id}
)
)
existing_organization_row = await OrganizationRepository(
prisma_client
).table.find_unique(where={"organization_id": data.organization_id})
if existing_organization_row is None:
raise HTTPException(
status_code=404,
@@ -1012,11 +1017,9 @@ async def find_member_if_email(
"""
try:
existing_user_email_row: BaseModel = (
await prisma_client.db.litellm_usertable.find_unique(
where={"user_email": user_email}
)
)
existing_user_email_row: BaseModel = await UserRepository(
prisma_client
).table.find_unique(where={"user_email": user_email})
except Exception:
raise HTTPException(
status_code=400,
@@ -1064,11 +1067,9 @@ async def organization_member_update(
)
# Check if organization exists
existing_organization_row = (
await prisma_client.db.litellm_organizationtable.find_unique(
where={"organization_id": data.organization_id}
)
)
existing_organization_row = await OrganizationRepository(
prisma_client
).table.find_unique(where={"organization_id": data.organization_id})
if existing_organization_row is None:
raise HTTPException(
status_code=400,
@@ -1085,15 +1086,15 @@ async def organization_member_update(
data.user_id = existing_user_email_row.user_id
try:
existing_organization_membership = (
await prisma_client.db.litellm_organizationmembership.find_unique(
where={
"user_id_organization_id": {
"user_id": data.user_id,
"organization_id": data.organization_id,
}
existing_organization_membership = await OrganizationMembershipRepository(
prisma_client
).table.find_unique(
where={
"user_id_organization_id": {
"user_id": data.user_id,
"organization_id": data.organization_id,
}
)
}
)
except Exception as e:
raise HTTPException(
@@ -1114,7 +1115,7 @@ async def organization_member_update(
# org-scoped operations. An org-admin of any org could otherwise
# alter a PROXY_ADMIN user's per-org role, which has downstream
# effects on admin UI filtering and scope derivation.
target_user_row = await prisma_client.db.litellm_usertable.find_unique(
target_user_row = await UserRepository(prisma_client).table.find_unique(
where={"user_id": data.user_id}
)
if target_user_row is not None and getattr(
@@ -1136,7 +1137,7 @@ async def organization_member_update(
# Update member role
if data.role is not None:
await prisma_client.db.litellm_organizationmembership.update(
await OrganizationMembershipRepository(prisma_client).table.update(
where={
"user_id_organization_id": {
"user_id": data.user_id,
@@ -1165,7 +1166,7 @@ async def organization_member_update(
)
# update organization membership with new budget_id
await prisma_client.db.litellm_organizationmembership.update(
await OrganizationMembershipRepository(prisma_client).table.update(
where={
"user_id_organization_id": {
"user_id": data.user_id,
@@ -1174,16 +1175,16 @@ async def organization_member_update(
},
data={"budget_id": budget_id},
)
final_organization_membership: Optional[BaseModel] = (
await prisma_client.db.litellm_organizationmembership.find_unique(
where={
"user_id_organization_id": {
"user_id": data.user_id,
"organization_id": data.organization_id,
}
},
include={"litellm_budget_table": True},
)
final_organization_membership: Optional[
BaseModel
] = await OrganizationMembershipRepository(prisma_client).table.find_unique(
where={
"user_id_organization_id": {
"user_id": data.user_id,
"organization_id": data.organization_id,
}
},
include={"litellm_budget_table": True},
)
if final_organization_membership is None:
@@ -1239,7 +1240,9 @@ async def organization_member_delete(
)
data.user_id = existing_user_email_row.user_id
member_to_delete = await prisma_client.db.litellm_organizationmembership.delete(
member_to_delete = await OrganizationMembershipRepository(
prisma_client
).table.delete(
where={
"user_id_organization_id": {
"user_id": data.user_id,
@@ -1273,17 +1276,15 @@ async def add_member_to_organization(
existing_user_email_row = None
## Check if user exists in LiteLLM_UserTable - user exists - either the user_id or user_email is in LiteLLM_UserTable
if member.user_id is not None:
existing_user_id_row = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": member.user_id}
)
existing_user_id_row = await UserRepository(
prisma_client
).table.find_unique(where={"user_id": member.user_id})
if existing_user_id_row is None and member.user_email is not None:
try:
existing_user_email_row = (
await prisma_client.db.litellm_usertable.find_unique(
where={"user_email": member.user_email}
)
)
existing_user_email_row = await UserRepository(
prisma_client
).table.find_unique(where={"user_email": member.user_email})
except Exception as e:
raise ValueError(
f"Potential NON-Existent or Duplicate user email in DB: Error finding a unique instance of user_email={member.user_email} in LiteLLM_UserTable.: {e}"
@@ -1326,14 +1327,14 @@ async def add_member_to_organization(
)
# Add user to organization
_organization_membership = (
await prisma_client.db.litellm_organizationmembership.create(
data={
"organization_id": organization_id,
"user_id": user_object.user_id,
"user_role": member.role,
}
)
_organization_membership = await OrganizationMembershipRepository(
prisma_client
).table.create(
data={
"organization_id": organization_id,
"user_id": user_object.user_id,
"user_role": member.role,
}
)
organization_membership = LiteLLM_OrganizationMembershipTable(
**_organization_membership.model_dump()
@@ -6,6 +6,7 @@ from litellm.proxy._types import (
Member,
NewUserResponse,
)
from litellm.repositories.team_repository import TeamRepository
from litellm.types.proxy.management_endpoints.scim_v2 import *
@@ -29,7 +30,7 @@ class ScimTransformations:
# Get user's teams/groups
groups = []
for team_id in user.teams or []:
team = await prisma_client.db.litellm_teamtable.find_unique(
team = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": team_id}
)
if team:
@@ -22,7 +22,6 @@ from typing_extensions import TypedDict
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.http_parsing_utils import _safe_get_request_headers
from litellm._uuid import uuid
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy._types import (
@@ -41,6 +40,7 @@ from litellm.proxy._types import (
)
from litellm.proxy.auth.auth_checks import _delete_cache_key_object
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.http_parsing_utils import _safe_get_request_headers
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
from litellm.proxy.management_endpoints.scim.scim_transformations import (
ScimTransformations,
@@ -51,6 +51,16 @@ from litellm.proxy.management_endpoints.team_endpoints import (
team_member_delete,
)
from litellm.proxy.utils import _premium_user_check, handle_exception_on_proxy
from litellm.repositories.table_repositories import (
InvitationLinkRepository,
OrganizationMembershipRepository,
TeamMembershipRepository,
)
from litellm.repositories.team_repository import TeamRepository
from litellm.repositories.user_repository import UserRepository
from litellm.repositories.verification_token_repository import (
VerificationTokenRepository,
)
from litellm.types.proxy.management_endpoints.scim_v2 import *
@@ -74,7 +84,7 @@ class UserProvisionerHelpers:
if not new_user_request.user_email:
return None
existing_user = await prisma_client.db.litellm_usertable.find_first(
existing_user = await UserRepository(prisma_client).table.find_first(
where={"user_email": new_user_request.user_email}
)
@@ -82,7 +92,7 @@ class UserProvisionerHelpers:
return None
# Update the user
updated_user = await prisma_client.db.litellm_usertable.update(
updated_user = await UserRepository(prisma_client).table.update(
where={"user_id": existing_user.user_id},
data={
"user_id": new_user_request.user_id,
@@ -139,7 +149,7 @@ async def _check_user_exists(user_id: str):
"""Check if user exists and return user, raise 404 if not found."""
prisma_client = await _get_prisma_client_or_raise_exception()
user = await prisma_client.db.litellm_usertable.find_unique(
user = await UserRepository(prisma_client).table.find_unique(
where={"user_id": user_id}
)
@@ -155,7 +165,7 @@ async def _check_team_exists(team_id: str):
"""Check if team exists and return team, raise 404 if not found."""
prisma_client = await _get_prisma_client_or_raise_exception()
team = await prisma_client.db.litellm_teamtable.find_unique(
team = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": team_id}
)
@@ -268,7 +278,7 @@ async def _extract_group_member_ids(group: SCIMGroup) -> GroupMemberExtractionRe
)
# Check if user exists
user = await prisma_client.db.litellm_usertable.find_unique(
user = await UserRepository(prisma_client).table.find_unique(
where={"user_id": user_id}
)
@@ -310,7 +320,7 @@ async def _get_team_members_display(member_ids: List[str]) -> List[SCIMMember]:
members: List[SCIMMember] = []
for member_id in member_ids:
user = await prisma_client.db.litellm_usertable.find_unique(
user = await UserRepository(prisma_client).table.find_unique(
where={"user_id": member_id}
)
if user:
@@ -367,7 +377,7 @@ async def _set_user_keys_blocked(user_id: str, blocked: bool) -> int:
# `blocked` is a nullable column with no default, so existing rows
# typically hold NULL; treat NULL as "not blocked" since SQL equality
# on NULL would otherwise silently skip them.
candidates = await prisma_client.db.litellm_verificationtoken.find_many(
candidates = await VerificationTokenRepository(prisma_client).table.find_many(
where={
"user_id": user_id,
"OR": [{"blocked": False}, {"blocked": None}],
@@ -375,7 +385,7 @@ async def _set_user_keys_blocked(user_id: str, blocked: bool) -> int:
)
affected_keys = candidates
else:
candidates = await prisma_client.db.litellm_verificationtoken.find_many(
candidates = await VerificationTokenRepository(prisma_client).table.find_many(
where={"user_id": user_id, "blocked": True},
)
affected_keys = [k for k in candidates if _key_was_scim_blocked(k.metadata)]
@@ -395,7 +405,7 @@ async def _set_user_keys_blocked(user_id: str, blocked: bool) -> int:
for k, v in current_metadata.items()
if k != SCIM_BLOCKED_METADATA_KEY
}
await prisma_client.db.litellm_verificationtoken.update(
await VerificationTokenRepository(prisma_client).table.update(
where={"token": key_row.token},
data={"blocked": blocked, "metadata": safe_dumps(new_metadata)},
)
@@ -423,7 +433,7 @@ async def _delete_rows_referencing_user(prisma_client: Any, *, user_id: str) ->
the user delete with an FK constraint violation (e.g.
``LiteLLM_InvitationLink_user_id_fkey``).
"""
await prisma_client.db.litellm_invitationlink.delete_many(
await InvitationLinkRepository(prisma_client).table.delete_many(
where={
"OR": [
{"user_id": user_id},
@@ -432,10 +442,10 @@ async def _delete_rows_referencing_user(prisma_client: Any, *, user_id: str) ->
]
}
)
await prisma_client.db.litellm_organizationmembership.delete_many(
await OrganizationMembershipRepository(prisma_client).table.delete_many(
where={"user_id": user_id}
)
await prisma_client.db.litellm_teammembership.delete_many(
await TeamMembershipRepository(prisma_client).table.delete_many(
where={"user_id": user_id}
)
@@ -897,17 +907,17 @@ async def get_users(
where_conditions["user_email"] = filter_value
# Get users from database
users: List[LiteLLM_UserTable] = (
await prisma_client.db.litellm_usertable.find_many(
where=where_conditions,
skip=(startIndex - 1),
take=count,
order={"created_at": "desc"},
)
users: List[LiteLLM_UserTable] = await UserRepository(
prisma_client
).table.find_many(
where=where_conditions,
skip=(startIndex - 1),
take=count,
order={"created_at": "desc"},
)
# Get total count for pagination
total_count = await prisma_client.db.litellm_usertable.count(
total_count = await UserRepository(prisma_client).table.count(
where=where_conditions
)
@@ -975,7 +985,7 @@ async def create_user(
# Check if user already exists
if user.userName:
existing_user = await prisma_client.db.litellm_usertable.find_unique(
existing_user = await UserRepository(prisma_client).table.find_unique(
where={"user_id": user.userName}
)
if existing_user:
@@ -1094,7 +1104,7 @@ async def update_user(
"metadata": safe_dumps(metadata),
}
updated_user = await prisma_client.db.litellm_usertable.update(
updated_user = await UserRepository(prisma_client).table.update(
where={"user_id": user_id},
data=update_data,
)
@@ -1137,7 +1147,7 @@ async def delete_user(
teams = []
if existing_user.teams:
for team_id in existing_user.teams:
team = await prisma_client.db.litellm_teamtable.find_unique(
team = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": team_id}
)
if team:
@@ -1148,7 +1158,7 @@ async def delete_user(
current_members = team.members or []
if user_id in current_members:
new_members = [m for m in current_members if m != user_id]
await prisma_client.db.litellm_teamtable.update(
await TeamRepository(prisma_client).table.update(
where={"team_id": team.team_id}, data={"members": new_members}
)
@@ -1157,7 +1167,7 @@ async def delete_user(
await _delete_rows_referencing_user(prisma_client, user_id=user_id)
# Delete user
await prisma_client.db.litellm_usertable.delete(where={"user_id": user_id})
await UserRepository(prisma_client).table.delete(where={"user_id": user_id})
return Response(status_code=204)
except Exception as e:
@@ -1413,7 +1423,7 @@ async def patch_user(
update_data["metadata"] = safe_dumps(update_data["metadata"])
updated_user = await prisma_client.db.litellm_usertable.update(
updated_user = await UserRepository(prisma_client).table.update(
where={"user_id": user_id},
data=update_data,
)
@@ -1465,7 +1475,7 @@ async def get_groups(
where_conditions["team_alias"] = team_alias
# Get teams from database
teams = await prisma_client.db.litellm_teamtable.find_many(
teams = await TeamRepository(prisma_client).table.find_many(
where=where_conditions,
skip=(startIndex - 1),
take=count,
@@ -1473,7 +1483,7 @@ async def get_groups(
)
# Get total count for pagination
total_count = await prisma_client.db.litellm_teamtable.count(
total_count = await TeamRepository(prisma_client).table.count(
where=where_conditions
)
@@ -1561,7 +1571,7 @@ async def create_group(
team_id = group.id or group.externalId or str(uuid.uuid4())
# Check if team already exists
existing_team = await prisma_client.db.litellm_teamtable.find_unique(
existing_team = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": team_id}
)
@@ -1638,7 +1648,7 @@ async def update_group(
}
# Update team in database
updated_team = await prisma_client.db.litellm_teamtable.update(
updated_team = await TeamRepository(prisma_client).table.update(
where={"team_id": group_id},
data=update_data,
)
@@ -1683,19 +1693,19 @@ async def delete_group(
# For each member, remove this team from their teams list
for member_id in existing_team.members or []:
user = await prisma_client.db.litellm_usertable.find_unique(
user = await UserRepository(prisma_client).table.find_unique(
where={"user_id": member_id}
)
if user:
current_teams = user.teams or []
if group_id in current_teams:
new_teams = [t for t in current_teams if t != group_id]
await prisma_client.db.litellm_usertable.update(
await UserRepository(prisma_client).table.update(
where={"user_id": member_id}, data={"teams": new_teams}
)
# Delete team
await prisma_client.db.litellm_teamtable.delete(where={"team_id": group_id})
await TeamRepository(prisma_client).table.delete(where={"team_id": group_id})
return Response(status_code=204)
@@ -1748,7 +1758,7 @@ async def _process_group_patch_operations(
detail={"error": "Invalid member: user ID cannot be empty."},
)
user = await prisma_client.db.litellm_usertable.find_unique(
user = await UserRepository(prisma_client).table.find_unique(
where={"user_id": member_id}
)
if user:
@@ -1805,7 +1815,7 @@ async def _apply_group_patch_updates(
update_data["members"] = list(final_members)
# Update team in database
updated_team = await prisma_client.db.litellm_teamtable.update(
updated_team = await TeamRepository(prisma_client).table.update(
where={"team_id": group_id},
data=update_data,
)
@@ -1877,7 +1887,7 @@ async def patch_group(
# Refresh team data from database to get the latest state after concurrent updates
# This prevents race conditions when multiple PATCH requests come in simultaneously
refreshed_team = await prisma_client.db.litellm_teamtable.find_unique(
refreshed_team = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": group_id}
)
if refreshed_team:
@@ -1894,7 +1904,7 @@ async def patch_group(
await _handle_group_membership_changes(group_id, current_members, final_members)
# Refresh team one more time to get final state after membership changes
final_team = await prisma_client.db.litellm_teamtable.find_unique(
final_team = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": group_id}
)
if final_team:
@@ -25,6 +25,14 @@ from litellm.proxy.management_endpoints.common_daily_activity import (
get_daily_activity,
)
from litellm.proxy.management_helpers.utils import handle_budget_for_entity
from litellm.repositories.model_repository import ModelRepository
from litellm.repositories.table_repositories import (
DailyTagSpendRepository,
TagRepository,
)
from litellm.repositories.verification_token_repository import (
VerificationTokenRepository,
)
from litellm.types.tag_management import (
TagConfig,
TagDeleteRequest,
@@ -56,7 +64,7 @@ async def _get_internal_user_api_keys(
if user_id is None:
return sorted(user_api_keys)
key_records = await prisma_client.db.litellm_verificationtoken.find_many(
key_records = await VerificationTokenRepository(prisma_client).table.find_many(
where={"user_id": user_id},
select={"token": True},
)
@@ -109,7 +117,7 @@ async def _get_tag_daily_activity_api_key_filter(
async def _get_model_names(prisma_client, model_ids: list) -> Dict[str, str]:
"""Helper function to get model names from model IDs"""
try:
models = await prisma_client.db.litellm_proxymodeltable.find_many(
models = await ModelRepository(prisma_client).table.find_many(
where={"model_id": {"in": model_ids}}
)
return {model.model_id: model.model_name for model in models}
@@ -189,7 +197,7 @@ async def new_tag(
)
try:
# Check if tag already exists
existing_tag = await prisma_client.db.litellm_tagtable.find_unique(
existing_tag = await TagRepository(prisma_client).table.find_unique(
where={"tag_name": tag.name}
)
if existing_tag is not None:
@@ -210,7 +218,7 @@ async def new_tag(
model_info = await _get_model_names(prisma_client, tag.models or [])
# Create new tag in database
new_tag_record = await prisma_client.db.litellm_tagtable.create(
new_tag_record = await TagRepository(prisma_client).table.create(
data={
"tag_name": tag.name,
"description": tag.description,
@@ -267,7 +275,7 @@ async def _add_tag_to_deployment(deployment: "Deployment", tag: str):
try:
# Get current model from database to preserve encrypted fields
db_model = await prisma_client.db.litellm_proxymodeltable.find_unique(
db_model = await ModelRepository(prisma_client).table.find_unique(
where={"model_id": deployment.model_info.id}
)
@@ -292,7 +300,7 @@ async def _add_tag_to_deployment(deployment: "Deployment", tag: str):
existing_params["tags"].append(tag)
# Update database with modified params (keeps encrypted fields encrypted)
await prisma_client.db.litellm_proxymodeltable.update(
await ModelRepository(prisma_client).table.update(
where={"model_id": deployment.model_info.id},
data={"litellm_params": json.dumps(existing_params)},
)
@@ -335,7 +343,7 @@ async def update_tag(
try:
# Check if tag exists
existing_tag = await prisma_client.db.litellm_tagtable.find_unique(
existing_tag = await TagRepository(prisma_client).table.find_unique(
where={"tag_name": tag.name}
)
if existing_tag is None:
@@ -367,7 +375,7 @@ async def update_tag(
update_data["budget_id"] = budget_id
# Update tag in database
updated_tag_record = await prisma_client.db.litellm_tagtable.update(
updated_tag_record = await TagRepository(prisma_client).table.update(
where={"tag_name": tag.name},
data=update_data,
)
@@ -414,7 +422,7 @@ async def info_tag(
try:
# Query tags from database with budget info
tag_records = await prisma_client.db.litellm_tagtable.find_many(
tag_records = await TagRepository(prisma_client).table.find_many(
where={"tag_name": {"in": data.names}},
include={"litellm_budget_table": True},
)
@@ -535,7 +543,7 @@ async def list_tags(
if start_date is not None and end_date is not None:
dynamic_tag_where["date"] = {"gte": start_date, "lte": end_date}
dynamic_tag_rows = await prisma_client.db.litellm_dailytagspend.group_by(
dynamic_tag_rows = await DailyTagSpendRepository(prisma_client).table.group_by(
by=["tag"],
where=dynamic_tag_where,
min={"created_at": True},
@@ -551,7 +559,7 @@ async def list_tags(
)
## QUERY STORED TAGS ##
tag_records = await prisma_client.db.litellm_tagtable.find_many(
tag_records = await TagRepository(prisma_client).table.find_many(
where=stored_tag_where,
include={"litellm_budget_table": True},
)
@@ -626,14 +634,14 @@ async def delete_tag(
try:
# Check if tag exists
existing_tag = await prisma_client.db.litellm_tagtable.find_unique(
existing_tag = await TagRepository(prisma_client).table.find_unique(
where={"tag_name": data.name}
)
if existing_tag is None:
raise HTTPException(status_code=404, detail=f"Tag {data.name} not found")
# Delete tag from database
await prisma_client.db.litellm_tagtable.delete(where={"tag_name": data.name})
await TagRepository(prisma_client).table.delete(where={"tag_name": data.name})
return {"message": f"Tag {data.name} deleted successfully"}
except Exception as e:
@@ -30,6 +30,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.callback_utils import encrypt_callback_vars
from litellm.proxy.management_endpoints.team_endpoints import _verify_team_access
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.repositories.team_repository import TeamRepository
router = APIRouter()
@@ -249,7 +250,7 @@ async def add_team_callbacks(
team_metadata = encrypt_callback_vars(team_metadata)
team_metadata_json = json.dumps(team_metadata) # update team_metadata
new_team_row = await prisma_client.db.litellm_teamtable.update(
new_team_row = await TeamRepository(prisma_client).table.update(
where={"team_id": team_id}, data={"metadata": team_metadata_json} # type: ignore
)
@@ -353,7 +354,7 @@ async def disable_team_logging(
team_metadata_json = json.dumps(team_metadata)
# Update team in database
updated_team = await prisma_client.db.litellm_teamtable.update(
updated_team = await TeamRepository(prisma_client).table.update(
where={"team_id": team_id}, data={"metadata": team_metadata_json} # type: ignore
)
@@ -10,8 +10,8 @@ All /team management endpoints
"""
import asyncio
import math
import json
import math
import traceback
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple, Union, cast
@@ -102,6 +102,20 @@ from litellm.proxy.management_helpers.utils import (
management_endpoint_wrapper,
)
from litellm.proxy.utils import PrismaClient, handle_exception_on_proxy
from litellm.repositories.budget_repository import BudgetRepository
from litellm.repositories.organization_repository import OrganizationRepository
from litellm.repositories.table_repositories import (
AccessGroupRepository,
DeletedTeamRepository,
ModelTableRepository,
OrganizationMembershipRepository,
TeamMembershipRepository,
)
from litellm.repositories.team_repository import TeamRepository
from litellm.repositories.user_repository import UserRepository
from litellm.repositories.verification_token_repository import (
VerificationTokenRepository,
)
from litellm.router import Router
from litellm.types.proxy.management_endpoints.common_daily_activity import (
SpendAnalyticsPaginatedResponse,
@@ -365,9 +379,9 @@ class TeamMemberBudgetHandler:
return
# Batch-fetch existing memberships for this team (avoids N+1 queries)
existing_memberships = await prisma_client.db.litellm_teammembership.find_many(
where={"team_id": team_id}
)
existing_memberships = await TeamMembershipRepository(
prisma_client
).table.find_many(where={"team_id": team_id})
existing_user_ids = {m.user_id for m in existing_memberships}
# Identify members with no existing membership row.
@@ -386,7 +400,7 @@ class TeamMemberBudgetHandler:
)
if missing:
await prisma_client.db.litellm_teammembership.create_many(
await TeamMembershipRepository(prisma_client).table.create_many(
data=missing,
skip_duplicates=True, # safety net against concurrent races
)
@@ -400,7 +414,7 @@ class TeamMemberBudgetHandler:
# Heal existing membership rows that predate the team_member_budget
# configuration: populate budget_id where it is currently NULL.
# Rows with an explicit budget_id (per-member override) are left alone.
updated = await prisma_client.db.litellm_teammembership.update_many(
updated = await TeamMembershipRepository(prisma_client).table.update_many(
where={"team_id": team_id, "budget_id": None},
data={"budget_id": team_member_budget_id},
)
@@ -456,7 +470,7 @@ async def get_all_team_memberships(
# else:
# where_obj = {"user_id": str(user_id), "team_id": {"in": team_id}}
team_memberships = await prisma_client.db.litellm_teammembership.find_many(
team_memberships = await TeamMembershipRepository(prisma_client).table.find_many(
where=where_obj,
include={"litellm_budget_table": True},
)
@@ -739,7 +753,7 @@ async def _check_org_team_limits(
# calculate allocated tpm/rpm limit
# check if specified tpm/rpm limit is greater than allocated tpm/rpm limit
teams = await prisma_client.db.litellm_teamtable.find_many(
teams = await TeamRepository(prisma_client).table.find_many(
where={"organization_id": org_table.organization_id},
)
@@ -931,6 +945,9 @@ async def new_team( # noqa: PLR0915
```
"""
try:
from litellm.proxy.management_helpers.audit_logs import (
get_audit_log_changed_by,
)
from litellm.proxy.proxy_server import (
_license_check,
create_audit_log_for_update,
@@ -938,9 +955,6 @@ async def new_team( # noqa: PLR0915
prisma_client,
user_api_key_cache,
)
from litellm.proxy.management_helpers.audit_logs import (
get_audit_log_changed_by,
)
if prisma_client is None:
raise HTTPException(status_code=500, detail={"error": "No db connected"})
@@ -986,7 +1000,7 @@ async def new_team( # noqa: PLR0915
)
# Check if license is over limit
total_teams = await prisma_client.db.litellm_teamtable.count()
total_teams = await TeamRepository(prisma_client).table.count()
if total_teams and _license_check.is_team_count_over_limit(
team_count=total_teams
):
@@ -1092,7 +1106,7 @@ async def new_team( # noqa: PLR0915
created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
)
model_dict = await prisma_client.db.litellm_modeltable.create(
model_dict = await ModelTableRepository(prisma_client).table.create(
{**litellm_modeltable.json(exclude_none=True)} # type: ignore
) # type: ignore
@@ -1195,7 +1209,7 @@ async def new_team( # noqa: PLR0915
db_data=complete_team_data_dict
)
team_row: LiteLLM_TeamTable = await prisma_client.db.litellm_teamtable.create(
team_row: LiteLLM_TeamTable = await TeamRepository(prisma_client).table.create(
data=complete_team_data_dict,
include={"litellm_model_table": True}, # type: ignore
)
@@ -1315,11 +1329,11 @@ async def _update_model_table(
updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
)
if model_id is None:
model_dict = await prisma_client.db.litellm_modeltable.create(
model_dict = await ModelTableRepository(prisma_client).table.create(
data={**litellm_modeltable.json(exclude_none=True)} # type: ignore
)
else:
model_dict = await prisma_client.db.litellm_modeltable.upsert(
model_dict = await ModelTableRepository(prisma_client).table.upsert(
where={"id": model_id},
data={
"update": {**litellm_modeltable.json(exclude_none=True)}, # type: ignore
@@ -1400,7 +1414,7 @@ async def fetch_and_validate_organization(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
organization_row = await prisma_client.db.litellm_organizationtable.find_unique(
organization_row = await OrganizationRepository(prisma_client).table.find_unique(
where={"organization_id": organization_id},
include={"litellm_budget_table": True, "members": True, "teams": True},
)
@@ -1669,7 +1683,7 @@ async def update_team( # noqa: PLR0915
},
)
existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
existing_team_row = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": data.team_id}
)
@@ -1738,7 +1752,9 @@ async def update_team( # noqa: PLR0915
):
# Is the caller org_admin of the destination org?
caller_memberships = (
await prisma_client.db.litellm_organizationmembership.find_many(
await OrganizationMembershipRepository(
prisma_client
).table.find_many(
where={
"user_id": user_api_key_dict.user_id,
"organization_id": data.organization_id,
@@ -1885,18 +1901,18 @@ async def update_team( # noqa: PLR0915
updated_kv["router_settings"] = safe_dumps(updated_kv["router_settings"])
updated_kv = prisma_client.jsonify_team_object(db_data=updated_kv)
team_row: Optional[LiteLLM_TeamTable] = (
await prisma_client.db.litellm_teamtable.update(
where={"team_id": data.team_id},
data=updated_kv,
# `object_permission` is included so `_refresh_cached_team`
# doesn't write a cached team with the relation nulled out —
# see team_model_add for the full rationale.
include={
"litellm_model_table": True,
"object_permission": True,
}, # type: ignore
)
team_row: Optional[LiteLLM_TeamTable] = await TeamRepository(
prisma_client
).table.update(
where={"team_id": data.team_id},
data=updated_kv,
# `object_permission` is included so `_refresh_cached_team`
# doesn't write a cached team with the relation nulled out —
# see team_model_add for the full rationale.
include={
"litellm_model_table": True,
"object_permission": True,
}, # type: ignore
)
if team_row is None or team_row.team_id is None:
@@ -2306,7 +2322,7 @@ async def _add_team_members_to_team(
# ADD MEMBER TO TEAM
_db_team_members = [m.model_dump() for m in complete_team_data.members_with_roles]
updated_team = await prisma_client.db.litellm_teamtable.update(
updated_team = await TeamRepository(prisma_client).table.update(
where={"team_id": data.team_id},
data={"members_with_roles": json.dumps(_db_team_members)}, # type: ignore
)
@@ -2377,7 +2393,7 @@ async def _validate_and_populate_member_user_info(
# Case 2: Only user_email provided - populate user_id from DB
if member.user_email is not None and member.user_id is None:
user_by_email = await prisma_client.db.litellm_usertable.find_first(
user_by_email = await UserRepository(prisma_client).table.find_first(
where={"user_email": {"equals": member.user_email, "mode": "insensitive"}}
)
@@ -2410,7 +2426,7 @@ async def _validate_and_populate_member_user_info(
# Case 3: Only user_id provided - populate user_email from DB if user exists
if member.user_id is not None and member.user_email is None:
user_by_id = await prisma_client.db.litellm_usertable.find_unique(
user_by_id = await UserRepository(prisma_client).table.find_unique(
where={"user_id": member.user_id}
)
@@ -2608,7 +2624,7 @@ async def team_member_delete(
detail={"error": "Either user_id or user_email needs to be passed in"},
)
_existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
_existing_team_row = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": data.team_id}
)
@@ -2652,7 +2668,7 @@ async def team_member_delete(
_db_new_team_members: List[dict] = [m.model_dump() for m in new_team_members]
_ = await prisma_client.db.litellm_teamtable.update(
_ = await TeamRepository(prisma_client).table.update(
where={
"team_id": data.team_id,
},
@@ -2666,7 +2682,7 @@ async def team_member_delete(
key_val["user_id"] = data.user_id
elif data.user_email is not None:
key_val["user_email"] = data.user_email
existing_user_rows = await prisma_client.db.litellm_usertable.find_many(
existing_user_rows = await UserRepository(prisma_client).table.find_many(
where=key_val # type: ignore
)
@@ -2678,7 +2694,7 @@ async def team_member_delete(
if data.team_id in existing_user.teams:
team_list = existing_user.teams
team_list.remove(data.team_id)
await prisma_client.db.litellm_usertable.update(
await UserRepository(prisma_client).table.update(
where={
"user_id": existing_user.user_id,
},
@@ -2695,7 +2711,7 @@ async def team_member_delete(
user_ids_to_delete.add(existing_user.user_id)
for _uid in user_ids_to_delete:
await prisma_client.db.litellm_teammembership.delete_many(
await TeamMembershipRepository(prisma_client).table.delete_many(
where={"team_id": data.team_id, "user_id": _uid}
)
@@ -2706,13 +2722,13 @@ async def team_member_delete(
)
# Fetch keys before deletion to persist them
keys_to_delete: List[LiteLLM_VerificationToken] = (
await prisma_client.db.litellm_verificationtoken.find_many(
where={
"user_id": {"in": list(user_ids_to_delete)},
"team_id": data.team_id,
}
)
keys_to_delete: List[
LiteLLM_VerificationToken
] = await VerificationTokenRepository(prisma_client).table.find_many(
where={
"user_id": {"in": list(user_ids_to_delete)},
"team_id": data.team_id,
}
)
if keys_to_delete:
@@ -2723,7 +2739,7 @@ async def team_member_delete(
litellm_changed_by=None,
)
await prisma_client.db.litellm_verificationtoken.delete_many(
await VerificationTokenRepository(prisma_client).table.delete_many(
where={
"user_id": {"in": list(user_ids_to_delete)},
"team_id": data.team_id,
@@ -2818,7 +2834,7 @@ async def team_member_update(
_validate_budget_duration(data.budget_duration)
_existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
_existing_team_row = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": data.team_id}
)
@@ -2921,7 +2937,7 @@ async def team_member_update(
team_table.members_with_roles = team_members
_db_team_members: List[dict] = [m.model_dump() for m in team_members]
await prisma_client.db.litellm_teamtable.update(
await TeamRepository(prisma_client).table.update(
where={"team_id": data.team_id},
data={"members_with_roles": json.dumps(_db_team_members)}, # type: ignore
)
@@ -3052,7 +3068,7 @@ async def bulk_team_member_add(
},
)
# get all users from the database
all_users_in_db = await prisma_client.db.litellm_usertable.find_many(
all_users_in_db = await UserRepository(prisma_client).table.find_many(
order={"created_at": "desc"}
)
data.members = [
@@ -3153,14 +3169,14 @@ async def delete_team(
}'
```
"""
from litellm.proxy.management_helpers.audit_logs import (
get_audit_log_changed_by,
)
from litellm.proxy.proxy_server import (
create_audit_log_for_update,
litellm_proxy_admin_name,
prisma_client,
)
from litellm.proxy.management_helpers.audit_logs import (
get_audit_log_changed_by,
)
if prisma_client is None:
raise HTTPException(status_code=500, detail={"error": "No db connected"})
@@ -3172,11 +3188,9 @@ async def delete_team(
team_rows: List[LiteLLM_TeamTable] = []
for team_id in data.team_ids:
try:
team_row_base: Optional[BaseModel] = (
await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id}
)
)
team_row_base: Optional[BaseModel] = await TeamRepository(
prisma_client
).table.find_unique(where={"team_id": team_id})
if team_row_base is None:
raise Exception
except Exception:
@@ -3243,11 +3257,9 @@ async def delete_team(
_persist_deleted_verification_tokens,
)
keys_to_delete: List[LiteLLM_VerificationToken] = (
await prisma_client.db.litellm_verificationtoken.find_many(
where={"team_id": {"in": data.team_ids}}
)
)
keys_to_delete: List[LiteLLM_VerificationToken] = await VerificationTokenRepository(
prisma_client
).table.find_many(where={"team_id": {"in": data.team_ids}})
if keys_to_delete:
await _persist_deleted_verification_tokens(
@@ -3338,7 +3350,7 @@ async def _save_deleted_team_records(
"""Save deleted team records to the database."""
if not records:
return
await prisma_client.db.litellm_deletedteamtable.create_many(data=records)
await DeletedTeamRepository(prisma_client).table.create_many(data=records)
async def _persist_deleted_team_records(
@@ -3420,7 +3432,7 @@ async def _add_team_member_budget_table(
team_info_response_object: TeamInfoResponseObjectTeamTable,
) -> TeamInfoResponseObjectTeamTable:
try:
team_budget = await prisma_client.db.litellm_budgettable.find_unique(
team_budget = await BudgetRepository(prisma_client).table.find_unique(
where={"budget_id": team_member_budget_id}
)
team_info_response_object.team_member_budget_table = team_budget
@@ -3489,11 +3501,11 @@ async def team_info(
)
try:
team_info: Optional[BaseModel] = (
await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id},
include={"object_permission": True},
)
team_info: Optional[BaseModel] = await TeamRepository(
prisma_client
).table.find_unique(
where={"team_id": team_id},
include={"object_permission": True},
)
if team_info is None:
raise Exception
@@ -3749,7 +3761,7 @@ async def block_team(
if prisma_client is None:
raise Exception("No DB Connected.")
existing_team = await prisma_client.db.litellm_teamtable.find_unique(
existing_team = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": data.team_id}
)
if existing_team is None:
@@ -3764,7 +3776,7 @@ async def block_team(
user_api_key_dict=user_api_key_dict,
)
record = await prisma_client.db.litellm_teamtable.update(
record = await TeamRepository(prisma_client).table.update(
where={"team_id": data.team_id}, data={"blocked": True} # type: ignore
)
@@ -3801,7 +3813,7 @@ async def unblock_team(
if prisma_client is None:
raise Exception("No DB Connected.")
existing_team = await prisma_client.db.litellm_teamtable.find_unique(
existing_team = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": data.team_id}
)
if existing_team is None:
@@ -3816,7 +3828,7 @@ async def unblock_team(
user_api_key_dict=user_api_key_dict,
)
record = await prisma_client.db.litellm_teamtable.update(
record = await TeamRepository(prisma_client).table.update(
where={"team_id": data.team_id}, data={"blocked": False} # type: ignore
)
@@ -3849,7 +3861,7 @@ async def list_available_teams(
return []
# filter out teams that the user is already a member of
user_info = await prisma_client.db.litellm_usertable.find_unique(
user_info = await UserRepository(prisma_client).table.find_unique(
where={"user_id": user_api_key_dict.user_id}
)
if user_info is None:
@@ -3863,7 +3875,7 @@ async def list_available_teams(
team for team in available_teams if team not in user_info_correct_type.teams
]
available_teams_db = await prisma_client.db.litellm_teamtable.find_many(
available_teams_db = await TeamRepository(prisma_client).table.find_many(
where={"team_id": {"in": available_teams}}
)
@@ -4009,7 +4021,7 @@ async def _batch_resolve_access_group_resources(
return {}
unique_ids = list(set(all_access_group_ids))
rows = await _prisma_client.db.litellm_accessgrouptable.find_many(
rows = await AccessGroupRepository(_prisma_client).table.find_many(
where={"access_group_id": {"in": unique_ids}},
)
@@ -4074,7 +4086,7 @@ async def _get_keys_count_by_team(
if not page_team_ids:
return {}
grouped = await prisma_client.db.litellm_verificationtoken.group_by(
grouped = await VerificationTokenRepository(prisma_client).table.group_by(
by=["team_id"],
where={"team_id": {"in": page_team_ids}},
count={"team_id": True},
@@ -4288,25 +4300,25 @@ async def list_team_v2(
# Get teams with pagination
if use_deleted_table:
teams = await prisma_client.db.litellm_deletedteamtable.find_many(
teams = await DeletedTeamRepository(prisma_client).table.find_many(
where=where_conditions,
skip=skip,
take=page_size,
order=order_by if order_by else {"created_at": "desc"}, # Default sort
)
# Get total count for pagination
total_count = await prisma_client.db.litellm_deletedteamtable.count(
total_count = await DeletedTeamRepository(prisma_client).table.count(
where=where_conditions
)
else:
teams = await prisma_client.db.litellm_teamtable.find_many(
teams = await TeamRepository(prisma_client).table.find_many(
where=where_conditions,
skip=skip,
take=page_size,
order=order_by if order_by else {"created_at": "desc"}, # Default sort
)
# Get total count for pagination
total_count = await prisma_client.db.litellm_teamtable.count(
total_count = await TeamRepository(prisma_client).table.count(
where=where_conditions
)
@@ -4412,7 +4424,7 @@ async def _authorize_and_filter_teams(
if allowed_org_ids is not None:
# Org admin: query DB for teams in their orgs
org_teams = await prisma_client.db.litellm_teamtable.find_many(
org_teams = await TeamRepository(prisma_client).table.find_many(
where={"organization_id": {"in": allowed_org_ids}},
include={"litellm_model_table": True},
)
@@ -4427,7 +4439,7 @@ async def _authorize_and_filter_teams(
]
elif user_id:
# Regular user: fetch all and filter by membership (Prisma can't filter JSON arrays)
response = await prisma_client.db.litellm_teamtable.find_many(
response = await TeamRepository(prisma_client).table.find_many(
include={"litellm_model_table": True}
)
return [
@@ -4439,7 +4451,7 @@ async def _authorize_and_filter_teams(
else:
# Proxy admin: all teams
return list(
await prisma_client.db.litellm_teamtable.find_many(
await TeamRepository(prisma_client).table.find_many(
include={"litellm_model_table": True}
)
)
@@ -4500,7 +4512,7 @@ async def list_team(
_team_memberships.append(tm)
# add all keys that belong to the team
keys = await prisma_client.db.litellm_verificationtoken.find_many(
keys = await VerificationTokenRepository(prisma_client).table.find_many(
where={"team_id": team.team_id}
)
@@ -4556,10 +4568,10 @@ async def get_paginated_teams(
# Calculate skip for pagination
skip = (page - 1) * page_size
# Get total count
total_count = await prisma_client.db.litellm_teamtable.count()
total_count = await TeamRepository(prisma_client).table.count()
# Get paginated teams
teams = await prisma_client.db.litellm_teamtable.find_many(
teams = await TeamRepository(prisma_client).table.find_many(
skip=skip, take=page_size, order={"team_alias": "asc"} # Sort by team_alias
)
return teams, total_count
@@ -4632,7 +4644,7 @@ async def ui_view_teams(
}
# Query users with pagination and filters
teams = await prisma_client.db.litellm_teamtable.find_many(
teams = await TeamRepository(prisma_client).table.find_many(
where=where_conditions,
skip=skip,
take=page_size,
@@ -4704,7 +4716,7 @@ async def team_model_add(
raise HTTPException(status_code=500, detail={"error": "No db connected"})
# Get existing team
team_row = await prisma_client.db.litellm_teamtable.find_unique(
team_row = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": data.team_id}
)
@@ -4737,7 +4749,7 @@ async def team_model_add(
# null them out — see object_permission_utils.validate_key_search_tools_against_team
# and the MCP/agent authz paths, which treat a missing object_permission
# as "no team-level restriction".
updated_team = await prisma_client.db.litellm_teamtable.update(
updated_team = await TeamRepository(prisma_client).table.update(
where={"team_id": data.team_id},
data={"models": updated_models},
include={"object_permission": True}, # type: ignore
@@ -4791,7 +4803,7 @@ async def team_model_delete(
raise HTTPException(status_code=500, detail={"error": "No db connected"})
# Get existing team
team_row = await prisma_client.db.litellm_teamtable.find_unique(
team_row = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": data.team_id}
)
@@ -4825,7 +4837,7 @@ async def team_model_delete(
updated_models = [m for m in current_models if m not in data.models]
# Update team. See team_model_add for the rationale on `include`.
updated_team = await prisma_client.db.litellm_teamtable.update(
updated_team = await TeamRepository(prisma_client).table.update(
where={"team_id": data.team_id},
data={"models": updated_models},
include={"object_permission": True}, # type: ignore
@@ -4972,7 +4984,7 @@ async def update_team_member_permissions(
},
)
# Update the team member permissions
updated_team = await prisma_client.db.litellm_teamtable.update(
updated_team = await TeamRepository(prisma_client).table.update(
where={"team_id": data.team_id},
data={"team_member_permissions": data.team_member_permissions},
)
@@ -5076,7 +5088,7 @@ async def _append_permissions_to_specific_teams(
prisma_client, team_ids: List[str], permissions_to_add: set
) -> int:
"""Fetch specific teams by ID and append permissions."""
teams = await prisma_client.db.litellm_teamtable.find_many(
teams = await TeamRepository(prisma_client).table.find_many(
where={"team_id": {"in": team_ids}},
)
@@ -5108,7 +5120,7 @@ async def _append_permissions_to_all_teams(
find_args["cursor"] = {"team_id": cursor}
find_args["skip"] = 1
teams = await prisma_client.db.litellm_teamtable.find_many(**find_args)
teams = await TeamRepository(prisma_client).table.find_many(**find_args)
if not teams:
break
@@ -5214,7 +5226,7 @@ async def get_team_daily_activity(
where_condition = {}
if team_ids_list:
where_condition["team_id"] = {"in": list(team_ids_list)}
team_aliases = await prisma_client.db.litellm_teamtable.find_many(
team_aliases = await TeamRepository(prisma_client).table.find_many(
where=where_condition
)
team_alias_metadata = {
@@ -5251,9 +5263,9 @@ async def get_team_daily_activity(
# If user does not have full team view, filter by their API keys
if not has_full_team_view:
# Get all API keys for this user
user_keys = await prisma_client.db.litellm_verificationtoken.find_many(
where={"user_id": user_api_key_dict.user_id}
)
user_keys = await VerificationTokenRepository(
prisma_client
).table.find_many(where={"user_id": user_api_key_dict.user_id})
user_api_keys = [key.token for key in user_keys if key.token]
# If user has no API keys, return empty result
if not user_api_keys:
@@ -21,6 +21,15 @@ if TYPE_CHECKING:
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.repositories.object_permission_repository import ObjectPermissionRepository
from litellm.repositories.table_repositories import (
SpendLogsRepository,
SpendLogToolIndexRepository,
)
from litellm.repositories.team_repository import TeamRepository
from litellm.repositories.verification_token_repository import (
VerificationTokenRepository,
)
from litellm.types.tool_management import (
LiteLLM_ToolTableRow,
ToolDetailResponse,
@@ -256,8 +265,10 @@ async def get_tool_usage_logs(
if end_time_filter is not None:
where["start_time"]["lte"] = end_time_filter
total = await prisma_client.db.litellm_spendlogtoolindex.count(where=where)
index_rows = await prisma_client.db.litellm_spendlogtoolindex.find_many(
total = await SpendLogToolIndexRepository(prisma_client).table.count(
where=where
)
index_rows = await SpendLogToolIndexRepository(prisma_client).table.find_many(
where=where,
order={"start_time": "desc"},
skip=(page - 1) * page_size,
@@ -269,7 +280,7 @@ async def get_tool_usage_logs(
logs=[], total=total, page=page, page_size=page_size
)
spend_logs = await prisma_client.db.litellm_spendlogs.find_many(
spend_logs = await SpendLogsRepository(prisma_client).table.find_many(
where={"request_id": {"in": request_ids}}
)
log_by_id = {s.request_id: s for s in spend_logs}
@@ -348,7 +359,7 @@ async def _resolve_key_hash_to_object_permission_id(
hashed = key_hash if "sk-" not in (key_hash or "") else hash_token(key_hash)
if not hashed:
return None
row = await prisma_client.db.litellm_verificationtoken.find_unique(
row = await VerificationTokenRepository(prisma_client).table.find_unique(
where={"token": hashed}
)
if row is None:
@@ -357,18 +368,18 @@ async def _resolve_key_hash_to_object_permission_id(
if op_id:
return op_id
new_id = str(uuid.uuid4())
await prisma_client.db.litellm_objectpermissiontable.create(
await ObjectPermissionRepository(prisma_client).table.create(
data={"object_permission_id": new_id, "blocked_tools": []}
)
updated_count = await prisma_client.db.litellm_verificationtoken.update_many(
updated_count = await VerificationTokenRepository(prisma_client).table.update_many(
where={"token": hashed, "object_permission_id": None},
data={"object_permission_id": new_id},
)
if updated_count == 0:
await prisma_client.db.litellm_objectpermissiontable.delete(
await ObjectPermissionRepository(prisma_client).table.delete(
where={"object_permission_id": new_id}
)
row = await prisma_client.db.litellm_verificationtoken.find_unique(
row = await VerificationTokenRepository(prisma_client).table.find_unique(
where={"token": hashed}
)
return getattr(row, "object_permission_id", None) if row else None
@@ -383,7 +394,7 @@ async def _resolve_team_id_to_object_permission_id(
if not team_id or not team_id.strip():
return None
team_id_clean = team_id.strip()
row = await prisma_client.db.litellm_teamtable.find_unique(
row = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": team_id_clean},
select={"object_permission_id": True},
)
@@ -393,18 +404,18 @@ async def _resolve_team_id_to_object_permission_id(
if op_id:
return op_id
new_id = str(uuid.uuid4())
await prisma_client.db.litellm_objectpermissiontable.create(
await ObjectPermissionRepository(prisma_client).table.create(
data={"object_permission_id": new_id, "blocked_tools": []}
)
updated_count = await prisma_client.db.litellm_teamtable.update_many(
updated_count = await TeamRepository(prisma_client).table.update_many(
where={"team_id": team_id_clean, "object_permission_id": None},
data={"object_permission_id": new_id},
)
if updated_count == 0:
await prisma_client.db.litellm_objectpermissiontable.delete(
await ObjectPermissionRepository(prisma_client).table.delete(
where={"object_permission_id": new_id}
)
row = await prisma_client.db.litellm_teamtable.find_unique(
row = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": team_id_clean},
select={"object_permission_id": True},
)
+16 -13
View File
@@ -15,8 +15,8 @@ import inspect
import os
import re
import secrets
from html import escape
from copy import deepcopy
from html import escape
from typing import (
TYPE_CHECKING,
Any,
@@ -39,9 +39,9 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
from fastapi.responses import RedirectResponse
import litellm
from litellm.caching.dual_cache import DualCache
from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.caching.dual_cache import DualCache
from litellm.constants import (
CLI_SSO_CLAIM_MAP,
CLI_SSO_CLAIM_MAX_SCALAR_LENGTH,
@@ -77,7 +77,6 @@ from litellm.proxy._types import (
UserAPIKeyAuth,
)
from litellm.proxy.auth.auth_checks import ExperimentalUIJWTToken, get_user_object
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
from litellm.proxy.auth.auth_utils import (
_get_request_ip_address,
_has_user_setup_sso,
@@ -92,6 +91,7 @@ from litellm.proxy.common_utils.html_forms.jwt_display_template import (
jwt_display_template,
)
from litellm.proxy.common_utils.html_forms.ui_login import html_form
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
from litellm.proxy.management_endpoints.sso import CustomMicrosoftSSO
from litellm.proxy.management_endpoints.sso_helper_utils import (
@@ -110,6 +110,9 @@ from litellm.proxy.utils import (
get_custom_url,
get_server_root_path,
)
from litellm.repositories.table_repositories import SSOConfigRepository
from litellm.repositories.team_repository import TeamRepository
from litellm.repositories.user_repository import UserRepository
from litellm.secret_managers.main import get_secret_bool, str_to_bool
from litellm.types.proxy.management_endpoints.ui_sso import * # noqa: F403, F401
from litellm.types.proxy.management_endpoints.ui_sso import (
@@ -438,7 +441,7 @@ async def _persist_cli_sso_user_metadata(
return
try:
user_row = await prisma_client.db.litellm_usertable.find_unique(
user_row = await UserRepository(prisma_client).table.find_unique(
where={"user_id": user_id}
)
existing_metadata: Dict[str, Any] = {}
@@ -451,7 +454,7 @@ async def _persist_cli_sso_user_metadata(
existing_metadata=existing_metadata,
attribution_metadata=attribution_metadata,
)
await prisma_client.db.litellm_usertable.update_many(
await UserRepository(prisma_client).table.update_many(
where={"user_id": user_id},
data={"metadata": merged_metadata},
)
@@ -859,7 +862,7 @@ async def google_login(
if premium_user is not True:
# Check if under 'free SSO user' limit
if prisma_client is not None:
total_users = await prisma_client.db.litellm_usertable.count()
total_users = await UserRepository(prisma_client).table.count()
if total_users and total_users > 5:
raise ProxyException(
message="You must be a LiteLLM Enterprise user to use SSO for more than 5 users. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://enterprise.litellm.ai/demo You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this",
@@ -1150,7 +1153,7 @@ async def _setup_team_mappings() -> Optional["TeamMappings"]:
"Prisma client is None, connect a database to your proxy"
)
sso_db_record = await prisma_client.db.litellm_ssoconfig.find_unique(
sso_db_record = await SSOConfigRepository(prisma_client).table.find_unique(
where={"id": "sso_config"}
)
@@ -1188,7 +1191,7 @@ async def _setup_role_mappings() -> Optional["RoleMappings"]:
"Prisma client is None, connect a database to your proxy"
)
sso_db_record = await prisma_client.db.litellm_ssoconfig.find_unique(
sso_db_record = await SSOConfigRepository(prisma_client).table.find_unique(
where={"id": "sso_config"}
)
@@ -1755,7 +1758,7 @@ async def _sync_user_role_from_jwt_role_map(
# Update existing DB record if role differs
if user_info is not None and user_info.user_role != mapped_role.value:
await prisma_client.db.litellm_usertable.update(
await UserRepository(prisma_client).table.update(
where={"user_id": user_info.user_id},
data={"user_role": mapped_role.value},
)
@@ -1819,7 +1822,7 @@ async def check_and_update_if_proxy_admin_id(
return user_role
if prisma_client:
await prisma_client.db.litellm_usertable.update(
await UserRepository(prisma_client).table.update(
where={"user_id": user_id},
data={"user_role": LitellmUserRoles.PROXY_ADMIN.value},
)
@@ -1976,7 +1979,7 @@ async def _fetch_cli_sso_team_details(
team_details: List[Dict[str, Any]] = []
try:
if teams:
prisma_teams = await prisma_client.db.litellm_teamtable.find_many(
prisma_teams = await TeamRepository(prisma_client).table.find_many(
where={"team_id": {"in": teams}}
)
for team_row in prisma_teams:
@@ -2884,7 +2887,7 @@ class SSOAuthenticationHandler:
user_id=user_id,
)
await prisma_client.db.litellm_usertable.update_many(
await UserRepository(prisma_client).table.update_many(
where={"user_id": user_id}, data=update_data
)
else:
@@ -2986,7 +2989,7 @@ class SSOAuthenticationHandler:
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
try:
team_obj = await prisma_client.db.litellm_teamtable.find_first(
team_obj = await TeamRepository(prisma_client).table.find_first(
where={"team_id": litellm_team_id}
)
verbose_proxy_logger.debug(f"Team object: {team_obj}")
@@ -19,6 +19,11 @@ from pydantic import BaseModel
from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.repositories.table_repositories import DailyTagSpendRepository
from litellm.repositories.user_repository import UserRepository
from litellm.repositories.verification_token_repository import (
VerificationTokenRepository,
)
# Constants for analytics periods
MAX_DAYS = 7 # Number of days to show in DAU analytics
@@ -676,7 +681,7 @@ async def get_per_user_analytics(
where_clause["tag"] = {"contains": tag_filter}
# Get all tag records in the date range with optional tag filtering
tag_records = await prisma_client.db.litellm_dailytagspend.find_many(
tag_records = await DailyTagSpendRepository(prisma_client).table.find_many(
where=where_clause
)
@@ -693,9 +698,9 @@ async def get_per_user_analytics(
)
# Lookup user_id for each api_key
api_key_records = await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": list(api_keys)}}
)
api_key_records = await VerificationTokenRepository(
prisma_client
).table.find_many(where={"token": {"in": list(api_keys)}})
# Create mapping from api_key to user_id
api_key_to_user_id = {
@@ -704,7 +709,7 @@ async def get_per_user_analytics(
# Get user emails for the user_ids
user_ids = list(set(api_key_to_user_id.values()))
user_records = await prisma_client.db.litellm_usertable.find_many(
user_records = await UserRepository(prisma_client).table.find_many(
where={"user_id": {"in": user_ids}}
)
@@ -27,6 +27,11 @@ from pydantic import BaseModel
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.repositories.table_repositories import (
WorkflowEventRepository,
WorkflowMessageRepository,
WorkflowRunRepository,
)
router = APIRouter()
@@ -96,13 +101,13 @@ class WorkflowMessageCreateRequest(BaseModel):
async def _get_next_sequence_number(prisma_client: Any, run_id: str, table: str) -> int:
"""Return MAX(sequence_number) + 1 for the given run, for either events or messages."""
if table == "events":
rows = await prisma_client.db.litellm_workflowevent.find_many(
rows = await WorkflowEventRepository(prisma_client).table.find_many(
where={"run_id": run_id},
order={"sequence_number": "desc"},
take=1,
)
else:
rows = await prisma_client.db.litellm_workflowmessage.find_many(
rows = await WorkflowMessageRepository(prisma_client).table.find_many(
where={"run_id": run_id},
order={"sequence_number": "desc"},
take=1,
@@ -116,7 +121,7 @@ async def _require_run(
user_api_key_dict: Optional[UserAPIKeyAuth] = None,
) -> Any:
"""Return the run or raise 404. For non-admin callers, also enforce key ownership."""
run = await prisma_client.db.litellm_workflowrun.find_unique(
run = await WorkflowRunRepository(prisma_client).table.find_unique(
where={"run_id": run_id}
)
if run is None:
@@ -163,7 +168,7 @@ async def create_workflow_run(
create_data["input"] = _json(data.input)
if data.metadata is not None:
create_data["metadata"] = _json(data.metadata)
run = await prisma_client.db.litellm_workflowrun.create(data=create_data)
run = await WorkflowRunRepository(prisma_client).table.create(data=create_data)
return run
except Exception as e:
verbose_proxy_logger.exception("Error creating workflow run: %s", e)
@@ -206,7 +211,7 @@ async def list_workflow_runs(
where["created_by"] = caller
try:
runs = await prisma_client.db.litellm_workflowrun.find_many(
runs = await WorkflowRunRepository(prisma_client).table.find_many(
where=where,
order={"created_at": "desc"},
take=limit,
@@ -235,7 +240,7 @@ async def get_workflow_run(
)
try:
run = await prisma_client.db.litellm_workflowrun.find_unique(
run = await WorkflowRunRepository(prisma_client).table.find_unique(
where={"run_id": run_id},
include={"events": {"order_by": {"sequence_number": "desc"}, "take": 1}},
)
@@ -286,7 +291,7 @@ async def update_workflow_run(
await _require_run(prisma_client, run_id, user_api_key_dict)
try:
run = await prisma_client.db.litellm_workflowrun.update(
run = await WorkflowRunRepository(prisma_client).table.update(
where={"run_id": run_id},
data=update,
)
@@ -391,7 +396,7 @@ async def list_workflow_events(
await _require_run(prisma_client, run_id, user_api_key_dict)
try:
events = await prisma_client.db.litellm_workflowevent.find_many(
events = await WorkflowEventRepository(prisma_client).table.find_many(
where={"run_id": run_id},
order={"sequence_number": "asc"},
take=limit,
@@ -436,7 +441,9 @@ async def append_workflow_message(
}
if data.session_id is not None:
msg_data["session_id"] = data.session_id
msg = await prisma_client.db.litellm_workflowmessage.create(data=msg_data)
msg = await WorkflowMessageRepository(prisma_client).table.create(
data=msg_data
)
return msg
except Exception as e:
@@ -481,7 +488,7 @@ async def list_workflow_messages(
await _require_run(prisma_client, run_id, user_api_key_dict)
try:
messages = await prisma_client.db.litellm_workflowmessage.find_many(
messages = await WorkflowMessageRepository(prisma_client).table.find_many(
where={"run_id": run_id},
order={"sequence_number": "asc"},
take=limit,
@@ -18,6 +18,7 @@ from litellm.proxy._types import (
Optional,
UserAPIKeyAuth,
)
from litellm.repositories.table_repositories import AuditLogRepository
from litellm.types.utils import StandardAuditLogPayload
_audit_log_callback_cache: Dict[str, CustomLogger] = {}
@@ -244,7 +245,7 @@ async def create_audit_log_for_update(request_data: LiteLLM_AuditLogs):
_request_data = request_data.model_dump(exclude_none=True)
try:
await prisma_client.db.litellm_auditlog.create(
await AuditLogRepository(prisma_client).table.create(
data={
**_request_data, # type: ignore
}
@@ -12,6 +12,8 @@ from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy.utils import PrismaClient
from litellm.repositories.object_permission_repository import ObjectPermissionRepository
from litellm.repositories.table_repositories import MCPServerRepository
if TYPE_CHECKING:
from litellm.proxy._types import (
@@ -48,10 +50,10 @@ async def attach_object_permission_to_dict(
object_permission_id = data_dict.get("object_permission_id")
if object_permission_id:
object_permission = (
await prisma_client.db.litellm_objectpermissiontable.find_unique(
where={"object_permission_id": object_permission_id},
)
object_permission = await ObjectPermissionRepository(
prisma_client
).table.find_unique(
where={"object_permission_id": object_permission_id},
)
if object_permission:
# Convert to dict if needed
@@ -106,10 +108,10 @@ async def handle_update_object_permission_common(
)
existing_object_permissions_dict: Dict = {}
existing_object_permission = (
await prisma_client.db.litellm_objectpermissiontable.find_unique(
where={"object_permission_id": object_permission_id_to_use},
)
existing_object_permission = await ObjectPermissionRepository(
prisma_client
).table.find_unique(
where={"object_permission_id": object_permission_id_to_use},
)
# Update the object permission
@@ -137,14 +139,14 @@ async def handle_update_object_permission_common(
#########################################################
# Commit the update to the LiteLLM_ObjectPermissionTable
#########################################################
created_object_permission_row = (
await prisma_client.db.litellm_objectpermissiontable.upsert(
where={"object_permission_id": object_permission_id_to_use},
data={
"create": existing_object_permissions_dict,
"update": existing_object_permissions_dict,
},
)
created_object_permission_row = await ObjectPermissionRepository(
prisma_client
).table.upsert(
where={"object_permission_id": object_permission_id_to_use},
data={
"create": existing_object_permissions_dict,
"update": existing_object_permissions_dict,
},
)
verbose_proxy_logger.debug(
@@ -183,7 +185,7 @@ async def _set_object_permission(
clean_data["mcp_tool_permissions"]
)
created_permission = await prisma_client.db.litellm_objectpermissiontable.create(
created_permission = await ObjectPermissionRepository(prisma_client).table.create(
data=clean_data
)
@@ -220,7 +222,7 @@ async def _get_db_mcp_servers_by_identifiers(
return []
identifier_list = list(identifiers)
return await prisma_client.db.litellm_mcpservertable.find_many(
return await MCPServerRepository(prisma_client).table.find_many(
where={
"OR": [
{"server_id": {"in": identifier_list}},
@@ -4,6 +4,7 @@ from fastapi import HTTPException
import litellm
from litellm.proxy._types import CommonProxyErrors, InvitationNew, UserAPIKeyAuth
from litellm.repositories.table_repositories import InvitationLinkRepository
async def create_invitation_for_user(
@@ -25,7 +26,7 @@ async def create_invitation_for_user(
expires_at = current_time + timedelta(days=7)
try:
response = await prisma_client.db.litellm_invitationlink.create(
response = await InvitationLinkRepository(prisma_client).table.create(
data={
"user_id": data.user_id,
"created_at": current_time,
+20 -17
View File
@@ -9,9 +9,8 @@ from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.otel.model.config import is_otel_v2_enabled
from litellm._uuid import uuid
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
from litellm.integrations.otel.model.config import is_otel_v2_enabled
from litellm.proxy._types import ( # key request types; user request types; team request types; customer request types
BudgetNewRequest,
DeleteCustomerRequest,
@@ -32,7 +31,11 @@ from litellm.proxy._types import ( # key request types; user request types; tea
VirtualKeyEvent,
)
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
from litellm.proxy.utils import PrismaClient
from litellm.repositories.budget_repository import BudgetRepository
from litellm.repositories.table_repositories import TeamMembershipRepository
from litellm.repositories.user_repository import UserRepository
def get_new_internal_user_defaults(
@@ -111,7 +114,7 @@ async def handle_budget_for_entity(
budget_row.model_dump(exclude_none=True)
)
_budget = await prisma_client.db.litellm_budgettable.create(
_budget = await BudgetRepository(prisma_client).table.create(
data={
**new_budget_data, # type: ignore
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
@@ -174,7 +177,7 @@ async def _clone_team_default_budget_for_member(
so the member starts with the team default's values but gets their own
private budget row (which can be edited independently).
"""
default_budget = await prisma_client.db.litellm_budgettable.find_unique(
default_budget = await BudgetRepository(prisma_client).table.find_unique(
where={"budget_id": default_team_budget_id}
)
if default_budget is None:
@@ -202,7 +205,7 @@ async def _clone_team_default_budget_for_member(
cloned_data["budget_duration"]
)
new_budget = await prisma_client.db.litellm_budgettable.create(data=cloned_data)
new_budget = await BudgetRepository(prisma_client).table.create(data=cloned_data)
return new_budget.budget_id
@@ -229,7 +232,7 @@ async def add_new_member(
## ADD TEAM ID, to USER TABLE IF NEW ##
if new_member.user_id is not None:
new_user_defaults = get_new_internal_user_defaults(user_id=new_member.user_id)
_returned_user = await prisma_client.db.litellm_usertable.upsert(
_returned_user = await UserRepository(prisma_client).table.upsert(
where={"user_id": new_member.user_id},
data={
"update": {"teams": {"push": [team_id]}},
@@ -259,7 +262,7 @@ async def add_new_member(
returned_user = LiteLLM_UserTable(**_returned_user.model_dump())
elif len(existing_user_row) == 1:
user_info = existing_user_row[0]
_returned_user = await prisma_client.db.litellm_usertable.update(
_returned_user = await UserRepository(prisma_client).table.update(
where={"user_id": user_info.user_id}, # type: ignore
data={"teams": {"push": [team_id]}},
)
@@ -284,7 +287,7 @@ async def add_new_member(
budget_data["max_budget"] = max_budget_in_team
if allowed_models is not None:
budget_data["allowed_models"] = allowed_models
response = await prisma_client.db.litellm_budgettable.create(data=budget_data)
response = await BudgetRepository(prisma_client).table.create(data=budget_data)
_budget_id = response.budget_id
elif default_team_budget_id is not None:
@@ -303,15 +306,15 @@ async def add_new_member(
_budget_id = None
if _budget_id and returned_user is not None and returned_user.user_id is not None:
_returned_team_membership = (
await prisma_client.db.litellm_teammembership.create(
data={
"team_id": team_id,
"user_id": returned_user.user_id,
"budget_id": _budget_id,
},
include={"litellm_budget_table": True},
)
_returned_team_membership = await TeamMembershipRepository(
prisma_client
).table.create(
data={
"team_id": team_id,
"user_id": returned_user.user_id,
"budget_id": _budget_id,
},
include={"litellm_budget_table": True},
)
returned_team_membership = LiteLLM_TeamMembership(
+11 -9
View File
@@ -29,6 +29,8 @@ from litellm.proxy._types import (
UserAPIKeyAuth,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.repositories.table_repositories import MemoryRepository
from litellm.repositories.team_repository import TeamRepository
from litellm.types.memory_management import (
LiteLLM_MemoryRow,
MemoryCreateRequest,
@@ -173,7 +175,7 @@ async def _is_team_admin_for(
)
try:
team_obj = await prisma_client.db.litellm_teamtable.find_unique(
team_obj = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": team_id}
)
except Exception as e:
@@ -304,7 +306,7 @@ async def create_memory(
create_data["metadata"] = _serialize_metadata_for_prisma(body.metadata)
try:
row = await prisma_client.db.litellm_memorytable.create(data=create_data)
row = await MemoryRepository(prisma_client).table.create(data=create_data)
except Exception as e:
# Key is globally unique. Any duplicate → 409.
if _is_unique_violation(e):
@@ -364,8 +366,8 @@ async def list_memory(
where = {"AND": [key_filter, vis]}
try:
total = await prisma_client.db.litellm_memorytable.count(where=where)
rows = await prisma_client.db.litellm_memorytable.find_many(
total = await MemoryRepository(prisma_client).table.count(where=where)
rows = await MemoryRepository(prisma_client).table.find_many(
where=where,
order={"updated_at": "desc"},
skip=(page - 1) * page_size,
@@ -386,7 +388,7 @@ async def _find_memory_for_caller(
key_filter: dict = {"key": key}
vis = _visibility_filter(user_api_key_dict)
where: dict = key_filter if vis is None else {"AND": [key_filter, vis]}
rows = await prisma_client.db.litellm_memorytable.find_many(
rows = await MemoryRepository(prisma_client).table.find_many(
where=where, take=1, order={"updated_at": "desc"}
)
if not rows:
@@ -475,7 +477,7 @@ async def upsert_memory(
# their team) — otherwise a teammate could overwrite a personal
# entry through the OR-based visibility filter.
await _assert_write_access(prisma_client, existing, user_api_key_dict)
row = await prisma_client.db.litellm_memorytable.update(
row = await MemoryRepository(prisma_client).table.update(
where={"memory_id": existing.memory_id},
data=data,
)
@@ -503,7 +505,7 @@ async def upsert_memory(
if body.metadata is not None:
create_data["metadata"] = _serialize_metadata_for_prisma(body.metadata)
try:
row = await prisma_client.db.litellm_memorytable.create(
row = await MemoryRepository(prisma_client).table.create(
data=create_data
)
except Exception as e:
@@ -524,7 +526,7 @@ async def upsert_memory(
await _assert_write_access(
prisma_client, existing_after_race, user_api_key_dict
)
row = await prisma_client.db.litellm_memorytable.update(
row = await MemoryRepository(prisma_client).table.update(
where={"memory_id": existing_after_race.memory_id},
data=data,
)
@@ -554,7 +556,7 @@ async def delete_memory(
# Visibility != write authority — see the upsert handler for the rationale.
await _assert_write_access(prisma_client, row, user_api_key_dict)
try:
await prisma_client.db.litellm_memorytable.delete(
await MemoryRepository(prisma_client).table.delete(
where={"memory_id": row.memory_id}
)
except Exception as e:
@@ -5,6 +5,10 @@ from dataclasses import dataclass, field
from types import MappingProxyType
from typing import TYPE_CHECKING, List, Literal, Optional, Union
from litellm.repositories.table_repositories import (
ManagedFileRepository,
ManagedObjectRepository,
)
from litellm.types.utils import SpecialEnums
if TYPE_CHECKING:
@@ -697,7 +701,7 @@ async def resolve_input_file_id_to_unified(response, prisma_client) -> None:
and prisma_client
):
try:
managed_file = await prisma_client.db.litellm_managedfiletable.find_first(
managed_file = await ManagedFileRepository(prisma_client).table.find_first(
where={"flat_model_file_ids": {"has": response.input_file_id}}
)
if managed_file:
@@ -719,7 +723,7 @@ async def resolve_output_file_ids_to_unified(response, prisma_client) -> None:
if not raw_id or _is_base64_encoded_unified_file_id(raw_id):
continue
try:
managed_file = await prisma_client.db.litellm_managedfiletable.find_first(
managed_file = await ManagedFileRepository(prisma_client).table.find_first(
where={"flat_model_file_ids": {"has": raw_id}}
)
if managed_file:
@@ -821,6 +825,7 @@ async def get_batch_from_database(
- response_batch: Parsed LiteLLMBatch object (or None)
"""
import json
from litellm.types.utils import LiteLLMBatch
if managed_files_obj is None or not unified_batch_id:
@@ -830,7 +835,7 @@ async def get_batch_from_database(
if not prisma_client:
return None, None
db_batch_object = await prisma_client.db.litellm_managedobjecttable.find_first(
db_batch_object = await ManagedObjectRepository(prisma_client).table.find_first(
where={"unified_object_id": batch_id}
)
@@ -942,7 +947,7 @@ async def update_batch_in_database(
update_data["batch_processed"] = True
try:
await prisma_client.db.litellm_managedobjecttable.update(
await ManagedObjectRepository(prisma_client).table.update(
where={"unified_object_id": batch_id},
data=update_data,
)
@@ -958,7 +963,7 @@ async def update_batch_in_database(
f"batch_processed column not found, retrying update without it: {col_err}"
)
update_data.pop("batch_processed", None)
await prisma_client.db.litellm_managedobjecttable.update(
await ManagedObjectRepository(prisma_client).table.update(
where={"unified_object_id": batch_id},
data=update_data,
)
@@ -21,6 +21,7 @@ from fastapi import (
UploadFile,
status,
)
import litellm
from litellm import CreateFileRequest, get_secret_str
from litellm._logging import verbose_proxy_logger
@@ -37,15 +38,6 @@ from litellm.proxy.common_utils.openai_endpoint_utils import (
get_custom_llm_provider_from_request_headers,
get_custom_llm_provider_from_request_query,
)
from litellm.proxy.utils import ProxyLogging, is_known_model
from litellm.router import Router
from litellm.types.llms.openai import (
CREATE_FILE_REQUESTS_PURPOSE,
FileExpiresAfter,
OpenAIFileObject,
OpenAIFilesPurpose,
)
from litellm.proxy.openai_files_endpoints.common_utils import (
_is_base64_encoded_unified_file_id,
encode_file_id_with_model,
@@ -54,6 +46,15 @@ from litellm.proxy.openai_files_endpoints.common_utils import (
handle_model_based_routing,
prepare_data_with_credentials,
)
from litellm.proxy.utils import ProxyLogging, is_known_model
from litellm.repositories.table_repositories import ManagedFileRepository
from litellm.router import Router
from litellm.types.llms.openai import (
CREATE_FILE_REQUESTS_PURPOSE,
FileExpiresAfter,
OpenAIFileObject,
OpenAIFilesPurpose,
)
router = APIRouter()
@@ -666,7 +667,7 @@ async def get_file_content( # noqa: PLR0915
managed_files_obj, "prisma_client", None
):
prisma_client = getattr(managed_files_obj, "prisma_client")
db_file = await prisma_client.db.litellm_managedfiletable.find_first(
db_file = await ManagedFileRepository(prisma_client).table.find_first(
where={"unified_file_id": file_id}
)
if db_file and db_file.storage_backend and db_file.storage_url:
@@ -43,6 +43,10 @@ from litellm.llms.base_llm.managed_resources.isolation import (
can_access_resource,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.repositories.table_repositories import (
ManagedFileRepository,
ManagedObjectRepository,
)
from litellm.types.llms.openai import OpenAIFileObject
from .managed_id_codec import ManagedIdPayload, decode, is_managed, new_managed_id
@@ -323,7 +327,7 @@ async def _resolve_one(
)
if not found and prisma_client is not None:
try:
db_row = await prisma_client.db.litellm_managedfiletable.find_first(
db_row = await ManagedFileRepository(prisma_client).table.find_first(
where={"unified_file_id": managed_id}
)
if db_row is not None:
@@ -339,7 +343,7 @@ async def _resolve_one(
# Object table (batches, responses)
if prisma_client is not None:
try:
obj_row = await prisma_client.db.litellm_managedobjecttable.find_first(
obj_row = await ManagedObjectRepository(prisma_client).table.find_first(
where={"unified_object_id": managed_id}
)
if obj_row is not None:
@@ -399,7 +403,7 @@ async def _guard_raw_provider_id(
# id and scope to the current provider in the application layer (same as
# _mint_or_reuse_file's dedup).
try:
candidates = await prisma_client.db.litellm_managedfiletable.find_many(
candidates = await ManagedFileRepository(prisma_client).table.find_many(
where={"flat_model_file_ids": {"has": raw_id}},
)
except Exception:
@@ -425,7 +429,7 @@ async def _guard_raw_provider_id(
# Object rows store model_object_id as "passthrough:{provider}:{raw}", so
# the lookup is exact and already provider-scoped.
try:
existing = await prisma_client.db.litellm_managedobjecttable.find_first(
existing = await ManagedObjectRepository(prisma_client).table.find_first(
where={"model_object_id": f"passthrough:{provider}:{raw_id}"}
)
except Exception:
@@ -492,7 +496,7 @@ async def _mint_or_reuse_file(
# reuse a stable row instead of minting duplicate rows on every call.
if prisma_client is not None:
try:
candidates = await prisma_client.db.litellm_managedfiletable.find_many(
candidates = await ManagedFileRepository(prisma_client).table.find_many(
where={"flat_model_file_ids": {"has": raw_id}},
order={"created_at": "asc"},
)
@@ -627,7 +631,7 @@ async def _mint_or_reuse_object(
# the batch's latest state (e.g. output_file_id / error_file_id that
# were null at creation but populated once the batch completed).
try:
await prisma_client.db.litellm_managedobjecttable.update(
await ManagedObjectRepository(prisma_client).table.update(
where={"unified_object_id": existing.unified_object_id},
data={
"file_object": json.dumps(body_snapshot),
@@ -647,7 +651,7 @@ async def _mint_or_reuse_object(
# Dedup: look up by the namespaced key — guaranteed unique per provider.
try:
existing = await prisma_client.db.litellm_managedobjecttable.find_first(
existing = await ManagedObjectRepository(prisma_client).table.find_first(
where={"model_object_id": namespaced_model_object_id}
)
except Exception:
@@ -666,7 +670,7 @@ async def _mint_or_reuse_object(
raw_id.split("_", 1)[0],
)
try:
await prisma_client.db.litellm_managedobjecttable.upsert(
await ManagedObjectRepository(prisma_client).table.upsert(
where={"unified_object_id": managed_id},
data={
"create": {
@@ -690,7 +694,7 @@ async def _mint_or_reuse_object(
# the winner's managed ID so both callers converge on one ID instead of
# the loser silently keeping the raw id.
try:
raced = await prisma_client.db.litellm_managedobjecttable.find_first(
raced = await ManagedObjectRepository(prisma_client).table.find_first(
where={"model_object_id": namespaced_model_object_id}
)
except Exception:
@@ -883,9 +887,9 @@ async def _build_list_where_with_cursor(
return where, fetch_order
cursor_table = (
prisma_client.db.litellm_managedfiletable
ManagedFileRepository(prisma_client).table
if resource_kind == "files"
else prisma_client.db.litellm_managedobjecttable
else ManagedObjectRepository(prisma_client).table
)
cursor_field = (
"unified_file_id" if resource_kind == "files" else "unified_object_id"
@@ -932,12 +936,12 @@ async def _fetch_list_rows(
# across rows that share a created_at timestamp.
try:
if resource_kind == "files":
return await prisma_client.db.litellm_managedfiletable.find_many(
return await ManagedFileRepository(prisma_client).table.find_many(
where=where,
order=[{"created_at": fetch_order}, {"unified_file_id": fetch_order}],
take=fetch_limit,
)
return await prisma_client.db.litellm_managedobjecttable.find_many(
return await ManagedObjectRepository(prisma_client).table.find_many(
where={**where, "file_purpose": "batch"},
order=[{"created_at": fetch_order}, {"unified_object_id": fetch_order}],
take=fetch_limit,
@@ -60,12 +60,13 @@ from litellm.proxy.common_utils.http_parsing_utils import (
)
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm.proxy.utils import normalize_route_for_root_path
from litellm.repositories.team_repository import TeamRepository
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.custom_http import httpxSpecialProvider
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
EndpointType,
LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY,
LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY,
EndpointType,
PassthroughStandardLoggingPayload,
)
@@ -1002,9 +1003,7 @@ async def pass_through_request( # noqa: PLR0915
is_passthrough_list_route,
list_passthrough_ids_from_db,
)
from litellm.proxy.proxy_server import (
prisma_client as _list_prisma,
)
from litellm.proxy.proxy_server import prisma_client as _list_prisma
if (
is_passthrough_list_route(
@@ -2875,7 +2874,7 @@ async def _filter_endpoints_by_team_allowed_routes(
HTTPException: If team is not found
"""
# retrieve team from db
team = await prisma_client.db.litellm_teamtable.find_unique(
team = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": team_id},
)
if team is None:
@@ -9,6 +9,7 @@ from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from litellm._logging import verbose_proxy_logger
from litellm.repositories.table_repositories import PolicyAttachmentRepository
from litellm.types.proxy.policy_engine import (
PolicyAttachment,
PolicyAttachmentCreateRequest,
@@ -278,21 +279,21 @@ class AttachmentRegistry:
PolicyAttachmentDBResponse with the created attachment
"""
try:
created_attachment = (
await prisma_client.db.litellm_policyattachmenttable.create(
data={
"policy_name": attachment_request.policy_name,
"scope": attachment_request.scope,
"teams": attachment_request.teams or [],
"keys": attachment_request.keys or [],
"models": attachment_request.models or [],
"tags": attachment_request.tags or [],
"created_at": datetime.now(timezone.utc),
"updated_at": datetime.now(timezone.utc),
"created_by": created_by,
"updated_by": created_by,
}
)
created_attachment = await PolicyAttachmentRepository(
prisma_client
).table.create(
data={
"policy_name": attachment_request.policy_name,
"scope": attachment_request.scope,
"teams": attachment_request.teams or [],
"keys": attachment_request.keys or [],
"models": attachment_request.models or [],
"tags": attachment_request.tags or [],
"created_at": datetime.now(timezone.utc),
"updated_at": datetime.now(timezone.utc),
"created_by": created_by,
"updated_by": created_by,
}
)
# Also add to in-memory registry
@@ -340,17 +341,15 @@ class AttachmentRegistry:
"""
try:
# Get attachment before deleting
attachment = (
await prisma_client.db.litellm_policyattachmenttable.find_unique(
where={"attachment_id": attachment_id}
)
)
attachment = await PolicyAttachmentRepository(
prisma_client
).table.find_unique(where={"attachment_id": attachment_id})
if attachment is None:
raise Exception(f"Attachment with ID {attachment_id} not found")
# Delete from DB
await prisma_client.db.litellm_policyattachmenttable.delete(
await PolicyAttachmentRepository(prisma_client).table.delete(
where={"attachment_id": attachment_id}
)
@@ -379,11 +378,9 @@ class AttachmentRegistry:
PolicyAttachmentDBResponse if found, None otherwise
"""
try:
attachment = (
await prisma_client.db.litellm_policyattachmenttable.find_unique(
where={"attachment_id": attachment_id}
)
)
attachment = await PolicyAttachmentRepository(
prisma_client
).table.find_unique(where={"attachment_id": attachment_id})
if attachment is None:
return None
@@ -419,10 +416,10 @@ class AttachmentRegistry:
List of PolicyAttachmentDBResponse objects
"""
try:
attachments = (
await prisma_client.db.litellm_policyattachmenttable.find_many(
order={"created_at": "desc"},
)
attachments = await PolicyAttachmentRepository(
prisma_client
).table.find_many(
order={"created_at": "desc"},
)
return [
+22 -21
View File
@@ -12,6 +12,7 @@ from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from litellm._logging import verbose_proxy_logger
from litellm.repositories.table_repositories import PolicyRepository
from litellm.types.proxy.policy_engine import (
GuardrailPipeline,
PipelineStep,
@@ -295,7 +296,7 @@ class PolicyRegistry:
validated_pipeline = GuardrailPipeline(**policy_request.pipeline)
data["pipeline"] = json.dumps(validated_pipeline.model_dump())
created_policy = await prisma_client.db.litellm_policytable.create(
created_policy = await PolicyRepository(prisma_client).table.create(
data=data
)
@@ -347,7 +348,7 @@ class PolicyRegistry:
Exception: If policy is not in draft status (only drafts are editable).
"""
try:
existing = await prisma_client.db.litellm_policytable.find_unique(
existing = await PolicyRepository(prisma_client).table.find_unique(
where={"policy_id": policy_id}
)
if existing is None:
@@ -382,7 +383,7 @@ class PolicyRegistry:
validated_pipeline = GuardrailPipeline(**policy_request.pipeline)
update_data["pipeline"] = json.dumps(validated_pipeline.model_dump())
updated_policy = await prisma_client.db.litellm_policytable.update(
updated_policy = await PolicyRepository(prisma_client).table.update(
where={"policy_id": policy_id},
data=update_data,
)
@@ -413,7 +414,7 @@ class PolicyRegistry:
Dict with "message" and optional "warning" if production was deleted.
"""
try:
policy = await prisma_client.db.litellm_policytable.find_unique(
policy = await PolicyRepository(prisma_client).table.find_unique(
where={"policy_id": policy_id}
)
@@ -424,7 +425,7 @@ class PolicyRegistry:
policy_name = policy.policy_name
# Delete from DB
await prisma_client.db.litellm_policytable.delete(
await PolicyRepository(prisma_client).table.delete(
where={"policy_id": policy_id}
)
@@ -461,7 +462,7 @@ class PolicyRegistry:
PolicyDBResponse if found, None otherwise
"""
try:
policy = await prisma_client.db.litellm_policytable.find_unique(
policy = await PolicyRepository(prisma_client).table.find_unique(
where={"policy_id": policy_id}
)
@@ -512,7 +513,7 @@ class PolicyRegistry:
if version_status is not None:
where["version_status"] = version_status
policies = await prisma_client.db.litellm_policytable.find_many(
policies = await PolicyRepository(prisma_client).table.find_many(
where=where if where else None,
order={"created_at": "desc"},
)
@@ -554,7 +555,7 @@ class PolicyRegistry:
self.add_policy(policy_response.policy_name, policy)
self._policies_by_id = {}
non_production = await prisma_client.db.litellm_policytable.find_many(
non_production = await PolicyRepository(prisma_client).table.find_many(
where={"version_status": {"in": ["draft", "published"]}},
order={"created_at": "desc"},
)
@@ -654,7 +655,7 @@ class PolicyRegistry:
PolicyVersionListResponse with policy_name and list of versions
"""
try:
rows = await prisma_client.db.litellm_policytable.find_many(
rows = await PolicyRepository(prisma_client).table.find_many(
where={"policy_name": policy_name},
order={"version_number": "desc"},
)
@@ -690,7 +691,7 @@ class PolicyRegistry:
"""
try:
if source_policy_id is not None:
source = await prisma_client.db.litellm_policytable.find_unique(
source = await PolicyRepository(prisma_client).table.find_unique(
where={"policy_id": source_policy_id}
)
if source is None:
@@ -701,7 +702,7 @@ class PolicyRegistry:
)
else:
# Find current production version for this policy_name
prod = await prisma_client.db.litellm_policytable.find_first(
prod = await PolicyRepository(prisma_client).table.find_first(
where={
"policy_name": policy_name,
"version_status": "production",
@@ -714,7 +715,7 @@ class PolicyRegistry:
source = prod
# Next version number
latest = await prisma_client.db.litellm_policytable.find_first(
latest = await PolicyRepository(prisma_client).table.find_first(
where={"policy_name": policy_name},
order={"version_number": "desc"},
)
@@ -722,7 +723,7 @@ class PolicyRegistry:
now = datetime.now(timezone.utc)
# Set is_latest=False on all existing versions for this policy_name
await prisma_client.db.litellm_policytable.update_many(
await PolicyRepository(prisma_client).table.update_many(
where={"policy_name": policy_name},
data={"is_latest": False},
)
@@ -758,7 +759,7 @@ class PolicyRegistry:
else source.pipeline
)
created = await prisma_client.db.litellm_policytable.create(data=data)
created = await PolicyRepository(prisma_client).table.create(data=data)
return _row_to_policy_db_response(created)
except Exception as e:
verbose_proxy_logger.exception(f"Error creating new version: {e}")
@@ -794,7 +795,7 @@ class PolicyRegistry:
f"Invalid status '{new_status}'. Use 'published' or 'production'."
)
row = await prisma_client.db.litellm_policytable.find_unique(
row = await PolicyRepository(prisma_client).table.find_unique(
where={"policy_id": policy_id}
)
if row is None:
@@ -809,7 +810,7 @@ class PolicyRegistry:
raise Exception(
f"Only draft versions can be published. Current status: '{current}'."
)
updated = await prisma_client.db.litellm_policytable.update(
updated = await PolicyRepository(prisma_client).table.update(
where={"policy_id": policy_id},
data={
"version_status": "published",
@@ -832,7 +833,7 @@ class PolicyRegistry:
)
# Demote current production to published
await prisma_client.db.litellm_policytable.update_many(
await PolicyRepository(prisma_client).table.update_many(
where={
"policy_name": policy_name,
"version_status": "production",
@@ -845,7 +846,7 @@ class PolicyRegistry:
)
# Promote this version to production
updated = await prisma_client.db.litellm_policytable.update(
updated = await PolicyRepository(prisma_client).table.update(
where={"policy_id": policy_id},
data={
"version_status": "production",
@@ -895,10 +896,10 @@ class PolicyRegistry:
PolicyVersionCompareResponse with both versions and field_diffs
"""
try:
a = await prisma_client.db.litellm_policytable.find_unique(
a = await PolicyRepository(prisma_client).table.find_unique(
where={"policy_id": policy_id_a}
)
b = await prisma_client.db.litellm_policytable.find_unique(
b = await PolicyRepository(prisma_client).table.find_unique(
where={"policy_id": policy_id_b}
)
if a is None:
@@ -950,7 +951,7 @@ class PolicyRegistry:
Dict with success message
"""
try:
await prisma_client.db.litellm_policytable.delete_many(
await PolicyRepository(prisma_client).table.delete_many(
where={"policy_name": policy_name}
)
self.remove_policy(policy_name)
@@ -16,6 +16,10 @@ from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
from litellm.repositories.team_repository import TeamRepository
from litellm.repositories.verification_token_repository import (
VerificationTokenRepository,
)
from litellm.types.proxy.policy_engine import (
AttachmentImpactResponse,
PolicyAttachmentCreateRequest,
@@ -76,7 +80,7 @@ def _get_tags_from_metadata(metadata: object, json_metadata: object = None) -> l
async def _fetch_all_teams(prisma_client: object) -> list:
"""Fetch teams from DB once. Reuse the result across tag and alias lookups."""
return await prisma_client.db.litellm_teamtable.find_many( # type: ignore
return await TeamRepository(prisma_client).table.find_many( # type: ignore
where={},
order={"created_at": "desc"},
take=MAX_POLICY_ESTIMATE_IMPACT_ROWS,
@@ -159,7 +163,7 @@ async def _find_affected_by_team_patterns(
new_keys: list = []
unnamed_keys_count = 0
if matched_team_ids:
keys = await prisma_client.db.litellm_verificationtoken.find_many( # type: ignore
keys = await VerificationTokenRepository(prisma_client).table.find_many( # type: ignore
where={"team_id": {"in": matched_team_ids}},
order={"created_at": "desc"},
take=MAX_POLICY_ESTIMATE_IMPACT_ROWS,
@@ -182,7 +186,7 @@ async def _find_affected_keys_by_alias(
affected: list = []
keys = await prisma_client.db.litellm_verificationtoken.find_many( # type: ignore
keys = await VerificationTokenRepository(prisma_client).table.find_many( # type: ignore
where=_build_alias_where("key_alias", key_patterns),
order={"created_at": "desc"},
take=MAX_POLICY_ESTIMATE_IMPACT_ROWS,
@@ -367,7 +371,7 @@ async def estimate_attachment_impact(
# Tag-based impact
if tag_patterns:
keys = await prisma_client.db.litellm_verificationtoken.find_many( # type: ignore
keys = await VerificationTokenRepository(prisma_client).table.find_many( # type: ignore
where={},
order={"created_at": "desc"},
take=MAX_POLICY_ESTIMATE_IMPACT_ROWS,
@@ -12,6 +12,10 @@ Validates:
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
from litellm._logging import verbose_proxy_logger
from litellm.repositories.team_repository import TeamRepository
from litellm.repositories.verification_token_repository import (
VerificationTokenRepository,
)
from litellm.types.proxy.policy_engine import (
Policy,
PolicyValidationError,
@@ -95,7 +99,7 @@ class PolicyValidator:
return True # Can't validate without DB, assume valid
try:
team = await self.prisma_client.db.litellm_teamtable.find_first(
team = await TeamRepository(self.prisma_client).table.find_first(
where={"team_alias": team_alias},
)
return team is not None
@@ -119,7 +123,9 @@ class PolicyValidator:
return True # Can't validate without DB, assume valid
try:
key = await self.prisma_client.db.litellm_verificationtoken.find_first(
key = await VerificationTokenRepository(
self.prisma_client
).table.find_first(
where={"key_alias": key_alias},
)
return key is not None
+11 -10
View File
@@ -22,6 +22,7 @@ from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKey
from litellm.proxy.auth.auth_utils import is_request_body_safe
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.path_utils import safe_filename
from litellm.repositories.table_repositories import PromptRepository
from litellm.types.prompts.init_prompts import (
ListPromptsResponse,
PromptInfo,
@@ -208,7 +209,7 @@ async def get_next_version_for_prompt(
Returns:
Next version number (1 if no versions exist, max_version + 1 otherwise)
"""
existing_prompts = await prisma_client.db.litellm_prompttable.find_many(
existing_prompts = await PromptRepository(prisma_client).table.find_many(
where={"prompt_id": prompt_id, "environment": environment}
)
@@ -441,7 +442,7 @@ async def get_prompt_versions(
where_clause: Dict[str, Any] = {"prompt_id": base_prompt_id}
if environment:
where_clause["environment"] = environment
db_prompts = await prisma_client.db.litellm_prompttable.find_many(
db_prompts = await PromptRepository(prisma_client).table.find_many(
where=where_clause,
order={"version": "desc"},
)
@@ -612,7 +613,7 @@ async def get_prompt_info(
# Query all environments this prompt exists in (lightweight: distinct on environment)
all_environments: List[str] = []
if prisma_client is not None:
all_prompt_rows = await prisma_client.db.litellm_prompttable.find_many(
all_prompt_rows = await PromptRepository(prisma_client).table.find_many(
where={"prompt_id": base_prompt_id},
distinct=["environment"],
)
@@ -634,7 +635,7 @@ async def get_prompt_info(
}
if requested_version is not None:
where_clause["version"] = requested_version
env_prompts = await prisma_client.db.litellm_prompttable.find_many(
env_prompts = await PromptRepository(prisma_client).table.find_many(
where=where_clause,
order={"version": "desc"},
take=1,
@@ -752,7 +753,7 @@ async def create_prompt(
)
# Store prompt in db with version
prompt_db_entry = await prisma_client.db.litellm_prompttable.create(
prompt_db_entry = await PromptRepository(prisma_client).table.create(
data={
"prompt_id": request.prompt_id,
"version": new_version,
@@ -848,7 +849,7 @@ async def update_prompt(
)
# Check if any version of this prompt exists (in any environment)
existing_prompts = await prisma_client.db.litellm_prompttable.find_many(
existing_prompts = await PromptRepository(prisma_client).table.find_many(
where={"prompt_id": base_prompt_id}
)
@@ -877,7 +878,7 @@ async def update_prompt(
)
# Store new version in db
prompt_db_entry = await prisma_client.db.litellm_prompttable.create(
prompt_db_entry = await PromptRepository(prisma_client).table.create(
data={
"prompt_id": base_prompt_id,
"version": new_version,
@@ -993,7 +994,7 @@ async def delete_prompt(
delete_where["environment"] = environment
# Delete versions from the database (scoped to environment if provided)
await prisma_client.db.litellm_prompttable.delete_many(where=delete_where)
await PromptRepository(prisma_client).table.delete_many(where=delete_where)
# Remove matching prompts from memory — scope to environment if provided
if environment:
@@ -1105,7 +1106,7 @@ async def patch_prompt(
if requested_version is not None:
find_where["version"] = requested_version
db_rows = await prisma_client.db.litellm_prompttable.find_many(
db_rows = await PromptRepository(prisma_client).table.find_many(
where=find_where,
order={"version": "desc"},
take=1,
@@ -1163,7 +1164,7 @@ async def patch_prompt(
update_data["created_by"] = user_api_key_dict.user_id
# Update by primary key (id) to target exactly one row
updated_prompt_db_entry = await prisma_client.db.litellm_prompttable.update(
updated_prompt_db_entry = await PromptRepository(prisma_client).table.update(
where={"id": target_row.id},
data=update_data,
)
+92 -73
View File
@@ -48,6 +48,7 @@ from litellm.constants import (
AIOHTTP_TTL_DNS_CACHE,
AUDIO_SPEECH_CHUNK_SIZE,
BASE_MCP_ROUTE,
DAILY_TAG_SPEND_BATCH_MULTIPLIER,
DEFAULT_MAX_RECURSE_DEPTH,
DEFAULT_SHARED_HEALTH_CHECK_LOCK_TTL,
DEFAULT_SHARED_HEALTH_CHECK_TTL,
@@ -56,13 +57,13 @@ from litellm.constants import (
LITELLM_SETTINGS_SAFE_DB_OVERRIDES,
LITELLM_UI_ALLOW_HEADERS,
LITELLM_UI_SESSION_DURATION,
DAILY_TAG_SPEND_BATCH_MULTIPLIER,
)
from litellm.litellm_core_utils.litellm_logging import (
_init_custom_logger_compatible_class,
)
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy._types import (
UI_TEAM_ID,
CallbackDelete,
CallInfo,
CommonProxyErrors,
@@ -79,8 +80,8 @@ from litellm.proxy._types import (
InvitationModel,
InvitationNew,
InvitationUpdate,
Litellm_EntityType,
LiteLLM_EndUserTable,
Litellm_EntityType,
LiteLLM_JWTAuth,
LiteLLM_TagTable,
LiteLLM_TeamTable,
@@ -96,7 +97,6 @@ from litellm.proxy._types import (
TeamDefaultSettings,
TokenCountRequest,
TransformRequestBody,
UI_TEAM_ID,
UserAPIKeyAuth,
)
from litellm.proxy.common_utils.cache_pydantic_utils import CacheCodec
@@ -212,8 +212,6 @@ from litellm import Router
from litellm._logging import verbose_proxy_logger, verbose_router_logger
from litellm.caching.caching import DualCache, RedisCache
from litellm.caching.redis_cluster_cache import RedisClusterCache
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
from litellm.constants import (
_REALTIME_BODY_CACHE_SIZE,
APSCHEDULER_COALESCE,
@@ -247,8 +245,8 @@ from litellm.litellm_core_utils.sensitive_data_masker import (
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.proxy._types import *
from litellm.proxy._lazy_features import attach_lazy_features
from litellm.proxy._types import *
from litellm.proxy.analytics_endpoints.analytics_endpoints import (
router as analytics_router,
)
@@ -308,6 +306,8 @@ from litellm.proxy.common_utils.openai_endpoint_utils import (
from litellm.proxy.common_utils.proxy_state import ProxyState
from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
from litellm.proxy.common_utils.swagger_utils import ERROR_RESPONSES
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
from litellm.proxy.container_endpoints.endpoints import router as container_router
from litellm.proxy.credential_endpoints.endpoints import router as credential_router
from litellm.proxy.db.db_transaction_queue.spend_log_cleanup import SpendLogCleanup
@@ -361,7 +361,9 @@ from litellm.proxy.management_endpoints.fallback_management_endpoints import (
from litellm.proxy.management_endpoints.internal_user_endpoints import (
router as internal_user_router,
)
from litellm.proxy.management_endpoints.internal_user_endpoints import user_update
from litellm.proxy.management_endpoints.internal_user_endpoints import (
user_update,
)
from litellm.proxy.management_endpoints.key_management_endpoints import (
delete_verification_tokens,
duration_in_seconds,
@@ -398,10 +400,6 @@ from litellm.proxy.management_endpoints.team_endpoints import (
update_team,
validate_membership,
)
from litellm.proxy.management_endpoints.workflow_management_endpoints import (
router as workflow_management_router,
)
from litellm.proxy.memory.memory_endpoints import router as memory_router
from litellm.proxy.management_endpoints.ui_sso import (
get_disabled_non_admin_personal_key_creation,
)
@@ -409,7 +407,11 @@ from litellm.proxy.management_endpoints.ui_sso import router as ui_sso_router
from litellm.proxy.management_endpoints.user_agent_analytics_endpoints import (
router as user_agent_analytics_router,
)
from litellm.proxy.management_endpoints.workflow_management_endpoints import (
router as workflow_management_router,
)
from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update
from litellm.proxy.memory.memory_endpoints import router as memory_router
from litellm.proxy.middleware.in_flight_requests_middleware import (
InFlightRequestsMiddleware,
)
@@ -417,12 +419,13 @@ from litellm.proxy.middleware.prometheus_auth_middleware import PrometheusAuthMi
from litellm.proxy.middleware.request_size_limit_middleware import (
RequestSizeLimitMiddleware,
)
from litellm.proxy.shutdown.graceful_shutdown_manager import GracefulShutdownManager
from litellm.proxy.ocr_endpoints.endpoints import router as ocr_router
from litellm.proxy.openai_files_endpoints.files_endpoints import (
router as openai_files_router,
)
from litellm.proxy.openai_files_endpoints.files_endpoints import set_files_config
from litellm.proxy.openai_files_endpoints.files_endpoints import (
set_files_config,
)
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
passthrough_endpoint_router,
)
@@ -444,6 +447,7 @@ from litellm.proxy.rerank_endpoints.endpoints import router as rerank_router
from litellm.proxy.response_api_endpoints.endpoints import router as response_router
from litellm.proxy.route_llm_request import route_request
from litellm.proxy.search_endpoints.endpoints import router as search_router
from litellm.proxy.shutdown.graceful_shutdown_manager import GracefulShutdownManager
from litellm.proxy.spend_tracking.spend_management_endpoints import (
router as spend_management_router,
)
@@ -478,6 +482,7 @@ from litellm.proxy.utils import (
update_spend,
)
from litellm.proxy.video_endpoints.endpoints import router as video_router
from litellm.repositories.credentials_repository import CredentialsRepository
from litellm.router import (
AssistantsTypedDict,
Deployment,
@@ -511,7 +516,9 @@ from litellm.types.proxy.management_endpoints.ui_sso import (
LiteLLM_UpperboundKeyGenerateParams,
)
from litellm.types.realtime import RealtimeQueryParams
from litellm.types.router import DeploymentTypedDict
from litellm.types.router import (
DeploymentTypedDict,
)
from litellm.types.router import ModelInfo as RouterModelInfo
from litellm.types.router import (
RouterGeneralSettings,
@@ -4248,13 +4255,13 @@ class ProxyConfig:
)
setattr(litellm, key, value)
if key in {"s3_audit_callback_params", "s3_callback_params"}:
from litellm.proxy.management_helpers.audit_logs import (
reset_audit_log_callback_cache,
)
from litellm.integrations.s3_v2 import S3Logger as S3V2Logger
from litellm.litellm_core_utils.litellm_logging import (
_in_memory_loggers,
)
from litellm.integrations.s3_v2 import S3Logger as S3V2Logger
from litellm.proxy.management_helpers.audit_logs import (
reset_audit_log_callback_cache,
)
reset_audit_log_callback_cache()
_in_memory_loggers[:] = [
@@ -5335,7 +5342,7 @@ class ProxyConfig:
4. Update router settings
"""
if llm_router is not None and prisma_client is not None:
db_router_settings = await prisma_client.db.litellm_config.find_first(
db_router_settings = await ConfigRepository(prisma_client).table.find_first(
where={"param_name": "router_settings"}
)
@@ -5761,7 +5768,7 @@ class ProxyConfig:
async def _get_models_from_db(self, prisma_client: PrismaClient) -> list:
try:
new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
new_models = await ModelRepository(prisma_client).table.find_many()
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy_server.py::add_deployment() - Error getting new models from DB - {}".format(
@@ -5975,7 +5982,7 @@ class ProxyConfig:
"""
try:
sso_settings = await prisma_client.db.litellm_ssoconfig.find_unique(
sso_settings = await SSOConfigRepository(prisma_client).table.find_unique(
where={"id": "sso_config"}
)
if sso_settings is not None:
@@ -6011,9 +6018,9 @@ class ProxyConfig:
)
try:
db_record = await prisma_client.db.litellm_configoverrides.find_unique(
where={"config_type": "hashicorp_vault"}
)
db_record = await ConfigOverridesRepository(
prisma_client
).table.find_unique(where={"config_type": "hashicorp_vault"})
if db_record is None or db_record.config_value is None:
if self._last_hashicorp_vault_config is not None:
@@ -6130,7 +6137,7 @@ class ProxyConfig:
last_model_cost_map_reload = current_time.isoformat()
# Clear force reload flag in database
await prisma_client.db.litellm_config.upsert(
await ConfigRepository(prisma_client).table.upsert(
where={"param_name": "model_cost_map_reload_config"},
data={
"create": {
@@ -6239,7 +6246,7 @@ class ProxyConfig:
last_anthropic_beta_headers_reload = current_time.isoformat()
# Clear force reload flag in database
await prisma_client.db.litellm_config.upsert(
await ConfigRepository(prisma_client).table.upsert(
where={"param_name": "anthropic_beta_headers_reload_config"},
data={
"create": {
@@ -6299,7 +6306,7 @@ class ProxyConfig:
from litellm.types.prompts.init_prompts import PromptSpec
try:
prompts_in_db = await prisma_client.db.litellm_prompttable.find_many()
prompts_in_db = await PromptRepository(prisma_client).table.find_many()
for prompt in prompts_in_db:
# Convert DB object to dict and create versioned prompt_id
prompt_spec = self._get_prompt_spec_for_db_prompt(db_prompt=prompt)
@@ -6589,7 +6596,7 @@ class ProxyConfig:
async def get_credentials(self, prisma_client: PrismaClient):
try:
credentials = await prisma_client.db.litellm_credentialstable.find_many()
credentials = await CredentialsRepository(prisma_client).find_all()
credentials = [self.decrypt_credentials(cred) for cred in credentials]
await self.delete_credentials(
credentials
@@ -7352,7 +7359,7 @@ class ProxyStartupEvent:
# spend cap blocks forever once it's hit.
if prisma_client is not None and litellm.budget_duration is not None:
try:
await prisma_client.db.litellm_usertable.update_many(
await UserRepository(prisma_client).table.update_many(
where={
"user_id": litellm_proxy_budget_name,
"budget_reset_at": None,
@@ -7420,7 +7427,7 @@ class ProxyStartupEvent:
if prisma_client is None:
return
db_record = await prisma_client.db.litellm_uisettings.find_unique(
db_record = await UISettingsRepository(prisma_client).table.find_unique(
where={"id": "ui_settings"}
)
if db_record and db_record.ui_settings:
@@ -7569,7 +7576,7 @@ class ProxyStartupEvent:
# but YAML config has False.
if store_model_in_db is not True and prisma_client is not None:
try:
_db_gs_record = await prisma_client.db.litellm_config.find_first(
_db_gs_record = await ConfigRepository(prisma_client).table.find_first(
where={"param_name": "general_settings"}
)
if _db_gs_record is not None and isinstance(
@@ -10361,6 +10368,18 @@ async def run_thread(
# )
# async def get_available_routes(user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)):
from litellm.llms.base_llm.base_utils import BaseTokenCounter
from litellm.repositories.config_repository import ConfigRepository
from litellm.repositories.model_repository import ModelRepository
from litellm.repositories.table_repositories import (
AccessGroupRepository,
ConfigOverridesRepository,
InvitationLinkRepository,
PromptRepository,
SSOConfigRepository,
UISettingsRepository,
)
from litellm.repositories.team_repository import TeamRepository
from litellm.repositories.user_repository import UserRepository
def _get_provider_token_counter(
@@ -10666,7 +10685,7 @@ async def _check_if_model_is_user_added(
id = model.get("model_info", {}).get("id", None)
if id is None:
continue
db_model = await prisma_client.db.litellm_proxymodeltable.find_unique(
db_model = await ModelRepository(prisma_client).table.find_unique(
where={"model_id": id}
)
if db_model is not None:
@@ -10723,7 +10742,7 @@ async def non_admin_all_models(
if user_api_key_dict.user_id:
try:
user_row = await prisma_client.db.litellm_usertable.find_unique(
user_row = await UserRepository(prisma_client).table.find_unique(
where={"user_id": user_api_key_dict.user_id}
)
except Exception:
@@ -10823,7 +10842,7 @@ async def _add_access_group_models_to_team_models(
return team_models
# Single batch fetch for all access groups
access_group_rows = await prisma_client.db.litellm_accessgrouptable.find_many(
access_group_rows = await AccessGroupRepository(prisma_client).table.find_many(
where={"access_group_id": {"in": list(all_access_group_ids)}}
)
ag_model_map: Dict[str, List[str]] = {
@@ -10865,13 +10884,13 @@ async def get_all_team_models(
team_db_objects_typed: List[LiteLLM_TeamTable] = []
if user_teams == "*":
team_db_objects = await prisma_client.db.litellm_teamtable.find_many()
team_db_objects = await TeamRepository(prisma_client).table.find_many()
team_db_objects_typed = [
LiteLLM_TeamTable(**team_db_object.model_dump())
for team_db_object in team_db_objects
]
else:
team_db_objects = await prisma_client.db.litellm_teamtable.find_many(
team_db_objects = await TeamRepository(prisma_client).table.find_many(
where={"team_id": {"in": user_teams}}
)
@@ -10938,7 +10957,7 @@ async def get_all_team_and_direct_access_models(
exclude_team_models=True
) # has access to all models
elif user_api_key_dict.user_id is not None:
user_db_object = await prisma_client.db.litellm_usertable.find_unique(
user_db_object = await UserRepository(prisma_client).table.find_unique(
where={"user_id": user_api_key_dict.user_id}
)
if user_db_object is not None:
@@ -11082,7 +11101,7 @@ async def _get_caller_byok_team_scope(
if user_id is None:
return set()
try:
user_row = await prisma_client.db.litellm_usertable.find_unique(
user_row = await UserRepository(prisma_client).table.find_unique(
where={"user_id": user_id}
)
except Exception:
@@ -11143,13 +11162,13 @@ async def _fetch_db_models_for_search(
else:
take_limit = max(0, page * size - router_models_count)
db_models_total_count = await prisma_client.db.litellm_proxymodeltable.count(
db_models_total_count = await ModelRepository(prisma_client).table.count(
where=db_where_condition
)
db_models_raw: list = []
if take_limit > 0:
db_models_raw = await prisma_client.db.litellm_proxymodeltable.find_many(
db_models_raw = await ModelRepository(prisma_client).table.find_many(
where=db_where_condition,
take=take_limit,
)
@@ -11484,7 +11503,7 @@ async def _load_team_object_for_model_filter(
) -> Optional[LiteLLM_TeamTable]:
"""Load team row from DB; returns None if missing or on error."""
try:
team_db_object = await prisma_client.db.litellm_teamtable.find_unique(
team_db_object = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": team_id}
)
if team_db_object is None:
@@ -11547,7 +11566,7 @@ async def _gather_team_accessible_model_ids(
_resolved_names = _team_models_resolve_to_names(
team_object.models, access_groups
)
db_models = await prisma_client.db.litellm_proxymodeltable.find_many(
db_models = await ModelRepository(prisma_client).table.find_many(
where={"model_name": {"in": _resolved_names}}
)
for db_model in db_models:
@@ -11586,7 +11605,7 @@ async def _authorize_team_id_query(
detail={"error": "Not authorized to view this team's models"},
)
try:
user_row = await prisma_client.db.litellm_usertable.find_unique(
user_row = await UserRepository(prisma_client).table.find_unique(
where={"user_id": user_id}
)
except Exception:
@@ -11693,7 +11712,7 @@ async def _find_model_by_id(
# If not found in config, search in database
if found_model is None:
try:
db_model = await prisma_client.db.litellm_proxymodeltable.find_unique(
db_model = await ModelRepository(prisma_client).table.find_unique(
where={"model_id": model_id}
)
if db_model:
@@ -12845,7 +12864,7 @@ async def alerting_settings(
)
## get general settings from db
db_general_settings = await prisma_client.db.litellm_config.find_first(
db_general_settings = await ConfigRepository(prisma_client).table.find_first(
where={"param_name": "general_settings"}
)
@@ -13384,7 +13403,7 @@ async def onboarding(invite_link: str, request: Request):
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
invite_obj = await prisma_client.db.litellm_invitationlink.find_unique(
invite_obj = await InvitationLinkRepository(prisma_client).table.find_unique(
where={"id": invite_link}
)
if invite_obj is None:
@@ -13408,7 +13427,7 @@ async def onboarding(invite_link: str, request: Request):
)
### GET USER OBJECT ###
user_obj = await prisma_client.db.litellm_usertable.find_unique(
user_obj = await UserRepository(prisma_client).table.find_unique(
where={"user_id": invite_obj.user_id}
)
@@ -13513,7 +13532,7 @@ async def _rollback_onboarding_invite_claim(
return
try:
await prisma_client.db.litellm_invitationlink.update_many(
await InvitationLinkRepository(prisma_client).table.update_many(
where={"id": invitation_link, "is_accepted": True},
data={
"accepted_at": None,
@@ -13547,10 +13566,10 @@ async def _generate_onboarding_ui_session_token(user_obj: Any) -> str:
)
key = response["token"] # type: ignore
from litellm.types.proxy.ui_sso import ReturnedUITokenObject
import jwt
from litellm.types.proxy.ui_sso import ReturnedUITokenObject
disabled_non_admin_personal_key_creation = (
get_disabled_non_admin_personal_key_creation()
)
@@ -13596,7 +13615,7 @@ async def claim_onboarding_link(data: InvitationClaim, request: Request):
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
invite_obj = await prisma_client.db.litellm_invitationlink.find_unique(
invite_obj = await InvitationLinkRepository(prisma_client).table.find_unique(
where={"id": data.invitation_link}
)
if invite_obj is None:
@@ -13956,7 +13975,7 @@ async def invitation_info(
},
)
response = await prisma_client.db.litellm_invitationlink.find_unique(
response = await InvitationLinkRepository(prisma_client).table.find_unique(
where={"id": invitation_id}
)
@@ -14010,7 +14029,7 @@ async def invitation_update(
)
current_time = litellm.utils.get_utc_datetime()
response = await prisma_client.db.litellm_invitationlink.update(
response = await InvitationLinkRepository(prisma_client).table.update(
where={"id": data.invitation_id},
data={
"id": data.invitation_id,
@@ -14081,7 +14100,7 @@ async def invitation_delete(
# Org admins can only delete invitations they created
if is_other_admin and not is_proxy_admin:
invitation = await prisma_client.db.litellm_invitationlink.find_unique(
invitation = await InvitationLinkRepository(prisma_client).table.find_unique(
where={"id": data.invitation_id}
)
if invitation is None:
@@ -14097,7 +14116,7 @@ async def invitation_delete(
},
)
response = await prisma_client.db.litellm_invitationlink.delete(
response = await InvitationLinkRepository(prisma_client).table.delete(
where={"id": data.invitation_id}
)
@@ -14139,7 +14158,7 @@ async def update_config( # noqa: PLR0915
raise Exception("No DB Connected")
async def _read_section(param_name: str) -> dict:
row = await prisma_client.db.litellm_config.find_first(
row = await ConfigRepository(prisma_client).table.find_first(
where={"param_name": param_name}
)
if row is None or row.param_value is None:
@@ -14148,7 +14167,7 @@ async def update_config( # noqa: PLR0915
async def _upsert_section(param_name: str, value: dict) -> None:
serialized = json.dumps(value)
await prisma_client.db.litellm_config.upsert(
await ConfigRepository(prisma_client).table.upsert(
where={"param_name": param_name},
data={
"create": {"param_name": param_name, "param_value": serialized},
@@ -14323,7 +14342,7 @@ async def update_config_general_settings(
)
## get general settings from db
db_general_settings = await prisma_client.db.litellm_config.find_first(
db_general_settings = await ConfigRepository(prisma_client).table.find_first(
where={"param_name": "general_settings"}
)
### update value
@@ -14337,7 +14356,7 @@ async def update_config_general_settings(
general_settings[data.field_name] = data.field_value
response = await prisma_client.db.litellm_config.upsert(
response = await ConfigRepository(prisma_client).table.upsert(
where={"param_name": "general_settings"},
data={
"create": {"param_name": "general_settings", "param_value": json.dumps(general_settings)}, # type: ignore
@@ -14387,7 +14406,7 @@ async def get_config_general_settings(
)
## get general settings from db
db_general_settings = await prisma_client.db.litellm_config.find_first(
db_general_settings = await ConfigRepository(prisma_client).table.find_first(
where={"param_name": "general_settings"}
)
### pop the value
@@ -14450,7 +14469,7 @@ async def get_config_list(
)
## get general settings from db
db_general_settings = await prisma_client.db.litellm_config.find_first(
db_general_settings = await ConfigRepository(prisma_client).table.find_first(
where={"param_name": "general_settings"}
)
@@ -14604,7 +14623,7 @@ async def delete_config_general_settings(
)
## get general settings from db
db_general_settings = await prisma_client.db.litellm_config.find_first(
db_general_settings = await ConfigRepository(prisma_client).table.find_first(
where={"param_name": "general_settings"}
)
### pop the value
@@ -14621,7 +14640,7 @@ async def delete_config_general_settings(
general_settings.pop(data.field_name, None)
response = await prisma_client.db.litellm_config.upsert(
response = await ConfigRepository(prisma_client).table.upsert(
where={"param_name": "general_settings"},
data={
"create": {"param_name": "general_settings", "param_value": json.dumps(general_settings)}, # type: ignore
@@ -14975,14 +14994,14 @@ async def reload_model_cost_map(
last_model_cost_map_reload = current_time.isoformat()
# Set force reload flag in database for other pods, preserving existing interval_hours
existing_config = await prisma_client.db.litellm_config.find_unique(
existing_config = await ConfigRepository(prisma_client).table.find_unique(
where={"param_name": "model_cost_map_reload_config"}
)
existing_interval = None
if existing_config and existing_config.param_value:
existing_interval = existing_config.param_value.get("interval_hours")
await prisma_client.db.litellm_config.upsert(
await ConfigRepository(prisma_client).table.upsert(
where={"param_name": "model_cost_map_reload_config"},
data={
"create": {
@@ -15052,7 +15071,7 @@ async def schedule_model_cost_map_reload(
)
# Update database with new reload configuration
await prisma_client.db.litellm_config.upsert(
await ConfigRepository(prisma_client).table.upsert(
where={"param_name": "model_cost_map_reload_config"},
data={
"create": {
@@ -15119,7 +15138,7 @@ async def cancel_model_cost_map_reload(
)
# Remove reload configuration from database
await prisma_client.db.litellm_config.delete(
await ConfigRepository(prisma_client).table.delete(
where={"param_name": "model_cost_map_reload_config"}
)
await invalidate_config_param("model_cost_map_reload_config")
@@ -15178,7 +15197,7 @@ async def get_model_cost_map_reload_status(
}
# Get reload configuration from database
config_record = await prisma_client.db.litellm_config.find_unique(
config_record = await ConfigRepository(prisma_client).table.find_unique(
where={"param_name": "model_cost_map_reload_config"}
)
@@ -15329,7 +15348,7 @@ async def reload_anthropic_beta_headers(
last_anthropic_beta_headers_reload = current_time.isoformat()
# Set force reload flag in database for other pods, preserving existing interval_hours
existing_beta_config = await prisma_client.db.litellm_config.find_unique(
existing_beta_config = await ConfigRepository(prisma_client).table.find_unique(
where={"param_name": "anthropic_beta_headers_reload_config"}
)
existing_beta_interval = None
@@ -15338,7 +15357,7 @@ async def reload_anthropic_beta_headers(
"interval_hours"
)
await prisma_client.db.litellm_config.upsert(
await ConfigRepository(prisma_client).table.upsert(
where={"param_name": "anthropic_beta_headers_reload_config"},
data={
"create": {
@@ -15412,7 +15431,7 @@ async def schedule_anthropic_beta_headers_reload(
)
# Update database with new reload configuration
await prisma_client.db.litellm_config.upsert(
await ConfigRepository(prisma_client).table.upsert(
where={"param_name": "anthropic_beta_headers_reload_config"},
data={
"create": {
@@ -15479,7 +15498,7 @@ async def cancel_anthropic_beta_headers_reload(
)
# Remove reload configuration from database
await prisma_client.db.litellm_config.delete(
await ConfigRepository(prisma_client).table.delete(
where={"param_name": "anthropic_beta_headers_reload_config"}
)
await invalidate_config_param("anthropic_beta_headers_reload_config")
@@ -15539,7 +15558,7 @@ async def get_anthropic_beta_headers_reload_status(
}
# Get reload configuration from database
config_record = await prisma_client.db.litellm_config.find_unique(
config_record = await ConfigRepository(prisma_client).table.find_unique(
where={"param_name": "anthropic_beta_headers_reload_config"}
)
@@ -4,9 +4,9 @@ import re
from importlib.resources import files
from typing import Any, Dict, List, Optional
import litellm
from fastapi import APIRouter, HTTPException, Request
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.get_blog_posts import (
BlogPost,
@@ -17,6 +17,7 @@ from litellm.litellm_core_utils.get_blog_posts import (
from litellm.proxy._types import (
CommonProxyErrors,
)
from litellm.repositories.table_repositories import ClaudeCodePluginRepository
from litellm.types.agents import AgentCard
from litellm.types.mcp import MCPPublicServer
from litellm.types.proxy.management_endpoints.model_management_endpoints import (
@@ -159,14 +160,14 @@ def _load_endpoints() -> List[Dict[str, Any]]:
)
async def public_model_hub():
import litellm
from litellm.proxy.health_endpoints._health_endpoints import (
_convert_health_check_to_dict,
)
from litellm.proxy.proxy_server import (
_get_model_group_info,
llm_router,
prisma_client,
)
from litellm.proxy.health_endpoints._health_endpoints import (
_convert_health_check_to_dict,
)
if llm_router is None:
raise HTTPException(
@@ -266,7 +267,7 @@ async def public_skill_hub():
try:
prisma_client = await _get_prisma_client()
plugins = await prisma_client.db.litellm_claudecodeplugintable.find_many(
plugins = await ClaudeCodePluginRepository(prisma_client).table.find_many(
where={"enabled": True}
)
items = []
+6 -7
View File
@@ -17,16 +17,17 @@ import litellm
from litellm._logging import verbose_proxy_logger
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
from litellm.proxy._types import *
from litellm.proxy.auth.auth_utils import is_request_body_safe
from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth, user_api_key_auth
from litellm.proxy.common_utils.http_parsing_utils import (
_read_request_body,
_safe_get_request_headers,
get_form_data,
)
from litellm.proxy.auth.auth_utils import is_request_body_safe
from litellm.proxy.vector_store_endpoints.utils import (
assert_user_can_access_vector_store_id,
)
from litellm.repositories.table_repositories import ManagedVectorStoresRepository
router = APIRouter()
@@ -230,11 +231,9 @@ async def _save_vector_store_to_db_from_rag_ingest(
try:
# Check if vector store already exists in database
existing_vector_store = (
await prisma_client.db.litellm_managedvectorstorestable.find_unique(
where={"vector_store_id": vector_store_id}
)
)
existing_vector_store = await ManagedVectorStoresRepository(
prisma_client
).table.find_unique(where={"vector_store_id": vector_store_id})
# Only create if it doesn't exist
if existing_vector_store is None:
@@ -289,7 +288,7 @@ async def _save_vector_store_to_db_from_rag_ingest(
# Update the vector store
from litellm.proxy.utils import safe_dumps
await prisma_client.db.litellm_managedvectorstorestable.update(
await ManagedVectorStoresRepository(prisma_client).table.update(
where={"vector_store_id": vector_store_id},
data={"vector_store_metadata": safe_dumps(existing_metadata)},
)
@@ -8,6 +8,7 @@ from typing import List, Optional
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy.utils import PrismaClient
from litellm.repositories.table_repositories import SearchToolsRepository
from litellm.types.search import SearchTool
@@ -63,16 +64,16 @@ class SearchToolRegistry:
search_tool_info: str = safe_dumps(search_tool.get("search_tool_info", {}))
# Create search tool in DB
created_search_tool = (
await prisma_client.db.litellm_searchtoolstable.create(
data={
"search_tool_name": search_tool_name,
"litellm_params": litellm_params,
"search_tool_info": search_tool_info,
"created_at": datetime.now(timezone.utc),
"updated_at": datetime.now(timezone.utc),
}
)
created_search_tool = await SearchToolsRepository(
prisma_client
).table.create(
data={
"search_tool_name": search_tool_name,
"litellm_params": litellm_params,
"search_tool_info": search_tool_info,
"created_at": datetime.now(timezone.utc),
"updated_at": datetime.now(timezone.utc),
}
)
# Add search_tool_id to the returned search tool object
@@ -101,15 +102,15 @@ class SearchToolRegistry:
"""
try:
# Get search tool before deletion for response
existing_tool = await prisma_client.db.litellm_searchtoolstable.find_unique(
where={"search_tool_id": search_tool_id}
)
existing_tool = await SearchToolsRepository(
prisma_client
).table.find_unique(where={"search_tool_id": search_tool_id})
if not existing_tool:
raise Exception(f"Search tool with ID {search_tool_id} not found")
# Delete from DB
await prisma_client.db.litellm_searchtoolstable.delete(
await SearchToolsRepository(prisma_client).table.delete(
where={"search_tool_id": search_tool_id}
)
@@ -145,16 +146,16 @@ class SearchToolRegistry:
search_tool_info: str = safe_dumps(search_tool.get("search_tool_info", {}))
# Update in DB
updated_search_tool = (
await prisma_client.db.litellm_searchtoolstable.update(
where={"search_tool_id": search_tool_id},
data={
"search_tool_name": search_tool_name,
"litellm_params": litellm_params,
"search_tool_info": search_tool_info,
"updated_at": datetime.now(timezone.utc),
},
)
updated_search_tool = await SearchToolsRepository(
prisma_client
).table.update(
where={"search_tool_id": search_tool_id},
data={
"search_tool_name": search_tool_name,
"litellm_params": litellm_params,
"search_tool_info": search_tool_info,
"updated_at": datetime.now(timezone.utc),
},
)
# Convert to dict with ISO formatted datetimes
@@ -179,10 +180,10 @@ class SearchToolRegistry:
List of search tool configurations
"""
try:
search_tools_from_db = (
await prisma_client.db.litellm_searchtoolstable.find_many(
order={"created_at": "desc"},
)
search_tools_from_db = await SearchToolsRepository(
prisma_client
).table.find_many(
order={"created_at": "desc"},
)
search_tools: List[SearchTool] = []
@@ -214,7 +215,7 @@ class SearchToolRegistry:
Search tool configuration or None if not found
"""
try:
search_tool = await prisma_client.db.litellm_searchtoolstable.find_unique(
search_tool = await SearchToolsRepository(prisma_client).table.find_unique(
where={"search_tool_id": search_tool_id}
)
@@ -244,7 +245,7 @@ class SearchToolRegistry:
Search tool configuration or None if not found
"""
try:
search_tool = await prisma_client.db.litellm_searchtoolstable.find_unique(
search_tool = await SearchToolsRepository(prisma_client).table.find_unique(
where={"search_tool_name": search_tool_name}
)
@@ -6,11 +6,12 @@ from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
decrypt_value_helper,
encrypt_value_helper,
)
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
from litellm.repositories.config_repository import ConfigRepository
from litellm.types.proxy.cloudzero_endpoints import (
CloudZeroExportRequest,
CloudZeroExportResponse,
@@ -53,7 +54,7 @@ async def _set_cloudzero_settings(api_key: str, connection_id: str, timezone: st
"timezone": timezone,
}
await prisma_client.db.litellm_config.upsert(
await ConfigRepository(prisma_client).table.upsert(
where={"param_name": "cloudzero_settings"},
data={
"create": {
@@ -80,7 +81,7 @@ async def _get_cloudzero_settings():
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
cloudzero_config = await prisma_client.db.litellm_config.find_first(
cloudzero_config = await ConfigRepository(prisma_client).table.find_first(
where={"param_name": "cloudzero_settings"}
)
if cloudzero_config is None or cloudzero_config.param_value is None:
@@ -282,7 +283,7 @@ async def is_cloudzero_setup_in_db() -> bool:
return False
# Check for CloudZero settings in database
cloudzero_config = await prisma_client.db.litellm_config.find_first(
cloudzero_config = await ConfigRepository(prisma_client).table.find_first(
where={"param_name": "cloudzero_settings"}
)
@@ -548,7 +549,7 @@ async def delete_cloudzero_settings(
)
# Check if CloudZero settings exist
cloudzero_config = await prisma_client.db.litellm_config.find_first(
cloudzero_config = await ConfigRepository(prisma_client).table.find_first(
where={"param_name": "cloudzero_settings"}
)
@@ -560,7 +561,7 @@ async def delete_cloudzero_settings(
# Delete only the CloudZero settings entry
# This uses a specific where clause to target only the cloudzero_settings row
await prisma_client.db.litellm_config.delete(
await ConfigRepository(prisma_client).table.delete(
where={"param_name": "cloudzero_settings"}
)
@@ -21,6 +21,11 @@ from litellm.proxy.spend_tracking.spend_tracking_utils import (
get_spend_by_team_and_customer,
)
from litellm.proxy.utils import handle_exception_on_proxy
from litellm.repositories.table_repositories import SpendLogsRepository
from litellm.repositories.team_repository import TeamRepository
from litellm.repositories.verification_token_repository import (
VerificationTokenRepository,
)
if TYPE_CHECKING:
from litellm.proxy.proxy_server import PrismaClient
@@ -2010,7 +2015,7 @@ async def ui_view_spend_logs( # noqa: PLR0915
order_direction = (sort_order or "desc").lower()
# Get total count of records
total_records = await prisma_client.db.litellm_spendlogs.count(
total_records = await SpendLogsRepository(prisma_client).table.count(
where=where_conditions,
)
@@ -2374,7 +2379,7 @@ async def view_spend_logs( # noqa: PLR0915
# Check if user wants unsummarized data
if not summarize:
# Return filtered individual log entries (similar to UI endpoint)
data = await prisma_client.db.litellm_spendlogs.find_many(
data = await SpendLogsRepository(prisma_client).table.find_many(
where=filter_query, # type: ignore
order={
"startTime": "desc",
@@ -2384,7 +2389,7 @@ async def view_spend_logs( # noqa: PLR0915
# Legacy behavior: return summarized data (when summarize=true)
# SQL query
response = await prisma_client.db.litellm_spendlogs.group_by(
response = await SpendLogsRepository(prisma_client).table.group_by(
by=["api_key", "user", "model", "startTime"],
where=filter_query, # type: ignore
sum={
@@ -2462,7 +2467,7 @@ async def view_spend_logs( # noqa: PLR0915
)
return spend_logs
data = await prisma_client.db.litellm_spendlogs.find_many(
data = await SpendLogsRepository(prisma_client).table.find_many(
where=scoped_filter, # type: ignore
order={"startTime": "desc"},
)
@@ -2514,10 +2519,10 @@ async def global_spend_reset():
code=status.HTTP_401_UNAUTHORIZED,
)
await prisma_client.db.litellm_verificationtoken.update_many(
await VerificationTokenRepository(prisma_client).table.update_many(
data={"spend": 0.0}, where={}
)
await prisma_client.db.litellm_teamtable.update_many(data={"spend": 0.0}, where={})
await TeamRepository(prisma_client).table.update_many(data={"spend": 0.0}, where={})
return {
"message": "Spend for all API Keys and Teams reset successfully",
@@ -3384,7 +3389,7 @@ async def ui_view_session_spend_logs(
skip = (page - 1) * page_size
# Get total count for pagination metadata
total_records = await prisma_client.db.litellm_spendlogs.count(
total_records = await SpendLogsRepository(prisma_client).table.count(
where=where_conditions
)
@@ -3485,7 +3490,7 @@ async def _build_ui_spend_logs_response(
# is bounded by page_size (typically 25-50 distinct session IDs).
# If performance degrades at scale, consider short-lived caching or
# folding the count into the main query via a window function.
counts = await prisma_client.db.litellm_spendlogs.group_by(
counts = await SpendLogsRepository(prisma_client).table.group_by(
by=["session_id"],
where={"session_id": {"in": session_ids}},
count={"session_id": True},
@@ -3572,7 +3577,7 @@ async def _can_team_member_view_log(
if team_id is None:
return False
team_row = await prisma_client.db.litellm_teamtable.find_unique(
team_row = await TeamRepository(prisma_client).table.find_unique(
where={"team_id": team_id}
)
if team_row is None:
@@ -3614,7 +3619,7 @@ async def _assert_user_can_view_request_id(
permitted teams (admin or ``/spend/logs`` permission).
Raises HTTP 403 if not.
"""
row = await prisma_client.db.litellm_spendlogs.find_unique(
row = await SpendLogsRepository(prisma_client).table.find_unique(
where={"request_id": request_id},
include=None,
)
@@ -3669,7 +3674,7 @@ async def _get_permitted_team_ids_for_spend_logs(
if user_obj is None or not user_obj.teams:
return []
team_rows = await prisma_client.db.litellm_teamtable.find_many(
team_rows = await TeamRepository(prisma_client).table.find_many(
where={"team_id": {"in": user_obj.teams}}
)
@@ -1,17 +1,18 @@
import json
import litellm
from fastapi import APIRouter, Depends, HTTPException
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
decrypt_value_helper,
encrypt_value_helper,
)
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
from litellm.repositories.config_repository import ConfigRepository
from litellm.types.proxy.vantage_endpoints import (
VantageDryRunRequest,
VantageExportRequest,
@@ -60,7 +61,7 @@ async def _set_vantage_settings(api_key: str, integration_token: str, base_url:
"base_url": base_url,
}
await prisma_client.db.litellm_config.upsert(
await ConfigRepository(prisma_client).table.upsert(
where={"param_name": VANTAGE_SETTINGS_PARAM_NAME},
data={
"create": {
@@ -82,7 +83,7 @@ async def _get_vantage_settings():
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
vantage_config = await prisma_client.db.litellm_config.find_first(
vantage_config = await ConfigRepository(prisma_client).table.find_first(
where={"param_name": VANTAGE_SETTINGS_PARAM_NAME}
)
if vantage_config is None or vantage_config.param_value is None:
@@ -265,7 +266,7 @@ async def is_vantage_setup_in_db() -> bool:
if prisma_client is None:
return False
vantage_config = await prisma_client.db.litellm_config.find_first(
vantage_config = await ConfigRepository(prisma_client).table.find_first(
where={"param_name": VANTAGE_SETTINGS_PARAM_NAME}
)
@@ -553,7 +554,7 @@ async def delete_vantage_settings(
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
vantage_config = await prisma_client.db.litellm_config.find_first(
vantage_config = await ConfigRepository(prisma_client).table.find_first(
where={"param_name": VANTAGE_SETTINGS_PARAM_NAME}
)
@@ -563,7 +564,7 @@ async def delete_vantage_settings(
detail={"error": "Vantage settings not found"},
)
await prisma_client.db.litellm_config.delete(
await ConfigRepository(prisma_client).table.delete(
where={"param_name": VANTAGE_SETTINGS_PARAM_NAME}
)
@@ -12,6 +12,12 @@ from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.sensitive_data_masker import mask_sensitive_keys
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.repositories.config_repository import ConfigRepository
from litellm.repositories.table_repositories import (
DailyTagSpendRepository,
SSOConfigRepository,
UISettingsRepository,
)
from litellm.types.proxy.management_endpoints.ui_sso import (
DefaultTeamSSOParams,
InProductNudgeResponse,
@@ -665,7 +671,7 @@ async def get_sso_settings():
)
# Get SSO config from dedicated table
sso_db_record = await prisma_client.db.litellm_ssoconfig.find_unique(
sso_db_record = await SSOConfigRepository(prisma_client).table.find_unique(
where={"id": "sso_config"}
)
@@ -836,7 +842,7 @@ async def update_sso_settings(sso_config: SSOConfig):
)
# Save to dedicated SSO table
await prisma_client.db.litellm_ssoconfig.upsert(
await SSOConfigRepository(prisma_client).table.upsert(
where={"id": "sso_config"},
data={
"create": {
@@ -851,7 +857,7 @@ async def update_sso_settings(sso_config: SSOConfig):
# Remove SSO-related env vars from config.environment_variables
try:
env_var_entry = await prisma_client.db.litellm_config.find_unique(
env_var_entry = await ConfigRepository(prisma_client).table.find_unique(
where={"param_name": "environment_variables"}
)
@@ -872,7 +878,7 @@ async def update_sso_settings(sso_config: SSOConfig):
if key not in env_vars_to_remove
}
await prisma_client.db.litellm_config.update(
await ConfigRepository(prisma_client).table.update(
where={"param_name": "environment_variables"},
data={
"param_value": json.dumps(filtered_env_vars, default=str),
@@ -1123,7 +1129,7 @@ async def get_in_product_nudges():
detail={"error": "Database not connected. Please connect a database."},
)
db_record = await prisma_client.db.litellm_dailytagspend.find_first(
db_record = await DailyTagSpendRepository(prisma_client).table.find_first(
where={"tag": "User-Agent: claude-cli"}
)
@@ -1155,7 +1161,7 @@ async def get_ui_settings_cached() -> Dict[str, Any]:
if prisma_client is None:
return {}
db_record = await prisma_client.db.litellm_uisettings.find_unique(
db_record = await UISettingsRepository(prisma_client).table.find_unique(
where={"id": "ui_settings"}
)
ui_settings: Dict[str, Any] = {}
@@ -1196,7 +1202,7 @@ async def get_ui_settings():
ui_settings: Dict[str, Any] = {}
db_record = await prisma_client.db.litellm_uisettings.find_unique(
db_record = await UISettingsRepository(prisma_client).table.find_unique(
where={"id": "ui_settings"}
)
@@ -1309,7 +1315,7 @@ async def update_ui_settings(
# Merge with existing persisted settings so a partial PATCH doesn't
# overwrite fields the caller didn't send.
existing: dict = {}
db_existing = await prisma_client.db.litellm_uisettings.find_unique(
db_existing = await UISettingsRepository(prisma_client).table.find_unique(
where={"id": "ui_settings"}
)
if db_existing and db_existing.ui_settings:
@@ -1318,7 +1324,7 @@ async def update_ui_settings(
ui_settings = {**existing, **incoming}
await prisma_client.db.litellm_uisettings.upsert(
await UISettingsRepository(prisma_client).table.upsert(
where={"id": "ui_settings"},
data={
"create": {

Some files were not shown because too many files have changed in this diff Show More