mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 02:48:35 +00:00
feat(litellm): add models and repository layers (#29686)
This commit is contained in:
@@ -28,6 +28,8 @@ jobs:
|
||||
tests/test_litellm/completion_extras
|
||||
tests/test_litellm/containers
|
||||
tests/test_litellm/experimental_mcp_client
|
||||
tests/test_litellm/models
|
||||
tests/test_litellm/repositories
|
||||
tests/test_litellm/images
|
||||
tests/test_litellm/interactions
|
||||
tests/test_litellm/passthrough
|
||||
|
||||
@@ -240,6 +240,24 @@ graph LR
|
||||
7. `DBSpendUpdateWriter.update_database()` queues spend increments to Redis
|
||||
8. Background job `update_spend` flushes queued spend to PostgreSQL every 60s
|
||||
|
||||
### Data Access Layer (Models & Repositories)
|
||||
|
||||
Database entities and the operations on them live in two packages at the root of `litellm/` so both the gateway (`proxy/`) and the SDK can use them without importing proxy internals:
|
||||
|
||||
- `litellm/models/` holds the canonical Pydantic definitions for every persisted entity (`LiteLLM_VerificationToken`, `LiteLLM_TeamTable`, `LiteLLM_UserTable`, etc.). `proxy/_types.py` re-exports these for backwards compatibility, so existing imports keep working.
|
||||
- `litellm/repositories/` holds the data-access layer. `BaseRepository[T]` provides the generic CRUD (`find_by_id`, `find_many`, `create`, `update`, `delete`, `count`, `exists`); entity repositories such as `VerificationTokenRepository`, `TeamRepository`, and `UserRepository` add domain-specific queries and writes on top of it.
|
||||
|
||||
Conventions to follow when touching this layer:
|
||||
|
||||
| Concern | How it's handled |
|
||||
|---------|------------------|
|
||||
| JSON columns | Prisma `Json` columns are stored as JSON strings. Repositories `json.dumps()` on write and `json.loads()` on read (see `_to_model` and the `_build_*_data` helpers). |
|
||||
| Archive-then-delete | `delete_team` / `delete_token` copy the row into the `LiteLLM_Deleted*` table and delete the original inside a single `prisma_client.db.tx()` transaction. Archive payloads are built explicitly so only columns that exist on the archive table are written. |
|
||||
| Column vs. field names | Where a model field differs from its DB column (for example `org_id` maps to the `organization_id` column), the repository translates in both directions rather than relying on Pydantic to guess. |
|
||||
| Array mutations | Adds use Prisma's atomic `push` (`add_member`, `add_admin`, `add_models`) to avoid read-modify-write races. Removals fall back to read-modify-write because Prisma has no atomic array remove. |
|
||||
|
||||
To add a new entity, define the model under `litellm/models/`, re-export it from `proxy/_types.py` if existing code imports it from there, and add a repository under `litellm/repositories/` (subclass `BaseRepository` for plain CRUD, or add bespoke methods when the entity needs encryption, archiving, or atomic array updates). Mirror the tests in `tests/test_litellm/repositories/`.
|
||||
|
||||
---
|
||||
|
||||
## 2. SDK Request Flow
|
||||
|
||||
@@ -37,6 +37,8 @@ from litellm.proxy._types import (
|
||||
VirtualKeyEvent,
|
||||
WebhookEvent,
|
||||
)
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.repositories.user_repository import UserRepository
|
||||
from litellm.types.integrations.slack_alerting import *
|
||||
|
||||
from ..email_templates.templates import *
|
||||
@@ -1231,7 +1233,7 @@ Model Info:
|
||||
and recipient_user_id is not None
|
||||
and prisma_client is not None
|
||||
):
|
||||
user_row = await prisma_client.db.litellm_usertable.find_unique(
|
||||
user_row = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": recipient_user_id}
|
||||
)
|
||||
|
||||
@@ -1263,7 +1265,7 @@ Model Info:
|
||||
team_id = webhook_event.team_id
|
||||
team_name = "Default Team"
|
||||
if team_id is not None and prisma_client is not None:
|
||||
team_row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
team_row = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
if team_row is not None:
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import List, Optional
|
||||
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
from litellm.proxy._types import WebhookEvent
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
|
||||
# we use this for the email header, please send a test email if you change this. verify it looks good on email
|
||||
LITELLM_LOGO_URL = "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
|
||||
@@ -24,7 +25,7 @@ async def get_all_team_member_emails(team_id: Optional[str] = None) -> list:
|
||||
if prisma_client is None:
|
||||
raise Exception("Not connected to DB!")
|
||||
|
||||
team_row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
team_row = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={
|
||||
"team_id": team_id,
|
||||
}
|
||||
|
||||
@@ -29,13 +29,13 @@ from litellm.exceptions import (
|
||||
validate_rate_limit_type,
|
||||
)
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.prometheus_helpers.bounded_prometheus_series_tracker import (
|
||||
BoundedPrometheusSeriesTracker,
|
||||
)
|
||||
from litellm.integrations.prometheus_helpers import (
|
||||
PrometheusLabelFactoryContext,
|
||||
_get_cached_end_user_id_for_cost_tracking,
|
||||
)
|
||||
from litellm.integrations.prometheus_helpers.bounded_prometheus_series_tracker import (
|
||||
BoundedPrometheusSeriesTracker,
|
||||
)
|
||||
from litellm.litellm_core_utils.core_helpers import (
|
||||
get_litellm_metadata_from_kwargs,
|
||||
get_metadata_variable_name_from_kwargs,
|
||||
@@ -46,6 +46,9 @@ from litellm.proxy._types import (
|
||||
LiteLLM_UserTable,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.repositories.organization_repository import OrganizationRepository
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.repositories.user_repository import UserRepository
|
||||
from litellm.types.integrations.prometheus import *
|
||||
from litellm.types.integrations.prometheus import (
|
||||
_sanitize_prometheus_label_name,
|
||||
@@ -3278,12 +3281,12 @@ class PrometheusLogger(CustomLogger):
|
||||
page_size: int, page: int
|
||||
) -> Tuple[List[LiteLLM_UserTable], Optional[int]]:
|
||||
skip = (page - 1) * page_size
|
||||
users = await prisma_client.db.litellm_usertable.find_many(
|
||||
users = await UserRepository(prisma_client).table.find_many(
|
||||
skip=skip,
|
||||
take=page_size,
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
total_count = await prisma_client.db.litellm_usertable.count()
|
||||
total_count = await UserRepository(prisma_client).table.count()
|
||||
return users, total_count
|
||||
|
||||
await self._initialize_budget_metrics(
|
||||
@@ -3306,13 +3309,13 @@ class PrometheusLogger(CustomLogger):
|
||||
|
||||
async def fetch_orgs(page_size: int, page: int) -> Tuple[list, Optional[int]]:
|
||||
skip = (page - 1) * page_size
|
||||
orgs = await prisma_client.db.litellm_organizationtable.find_many(
|
||||
orgs = await OrganizationRepository(prisma_client).table.find_many(
|
||||
skip=skip,
|
||||
take=page_size,
|
||||
order={"created_at": "desc"},
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
total_count = await prisma_client.db.litellm_organizationtable.count()
|
||||
total_count = await OrganizationRepository(prisma_client).table.count()
|
||||
return orgs, total_count
|
||||
|
||||
await self._initialize_budget_metrics(
|
||||
@@ -3380,14 +3383,14 @@ class PrometheusLogger(CustomLogger):
|
||||
|
||||
try:
|
||||
# Get total user count
|
||||
total_users = await prisma_client.db.litellm_usertable.count()
|
||||
total_users = await UserRepository(prisma_client).table.count()
|
||||
self.litellm_total_users_metric.set(total_users)
|
||||
verbose_logger.debug(
|
||||
f"Prometheus: set litellm_total_users to {total_users}"
|
||||
)
|
||||
|
||||
# Get total team count
|
||||
total_teams = await prisma_client.db.litellm_teamtable.count()
|
||||
total_teams = await TeamRepository(prisma_client).table.count()
|
||||
self.litellm_teams_count_metric.set(total_teams)
|
||||
verbose_logger.debug(
|
||||
f"Prometheus: set litellm_teams_count to {total_teams}"
|
||||
|
||||
@@ -17,6 +17,7 @@ from litellm.proxy.common_utils.resource_ownership import (
|
||||
is_proxy_admin,
|
||||
user_can_access_resource_owner,
|
||||
)
|
||||
from litellm.repositories.table_repositories import SkillsRepository
|
||||
|
||||
# Skills are looked up on every chat completion that has skills enabled
|
||||
# (`SkillsInjectionHook` calls ``fetch_skill_from_db``). 60s LRU/TTL cache
|
||||
@@ -107,7 +108,7 @@ class LiteLLMSkillsHandler:
|
||||
f"LiteLLMSkillsHandler: Creating skill {skill_id} with title={data.display_title}"
|
||||
)
|
||||
|
||||
new_skill = await prisma_client.db.litellm_skillstable.create(data=skill_data)
|
||||
new_skill = await SkillsRepository(prisma_client).table.create(data=skill_data)
|
||||
return _prisma_skill_to_litellm(new_skill)
|
||||
|
||||
@staticmethod
|
||||
@@ -133,7 +134,7 @@ class LiteLLMSkillsHandler:
|
||||
return []
|
||||
find_many_kwargs["where"] = {"created_by": {"in": owner_scopes}}
|
||||
|
||||
skills = await prisma_client.db.litellm_skillstable.find_many(
|
||||
skills = await SkillsRepository(prisma_client).table.find_many(
|
||||
**find_many_kwargs
|
||||
)
|
||||
return [_prisma_skill_to_litellm(s) for s in skills]
|
||||
@@ -150,7 +151,7 @@ class LiteLLMSkillsHandler:
|
||||
return cached
|
||||
|
||||
prisma_client = await LiteLLMSkillsHandler._get_prisma_client()
|
||||
skill = await prisma_client.db.litellm_skillstable.find_unique(
|
||||
skill = await SkillsRepository(prisma_client).table.find_unique(
|
||||
where={"skill_id": skill_id}
|
||||
)
|
||||
_SKILL_CACHE.set_cache(
|
||||
@@ -189,7 +190,7 @@ class LiteLLMSkillsHandler:
|
||||
):
|
||||
raise ValueError(f"Skill not found: {skill_id}")
|
||||
|
||||
await prisma_client.db.litellm_skillstable.delete(where={"skill_id": skill_id})
|
||||
await SkillsRepository(prisma_client).table.delete(where={"skill_id": skill_id})
|
||||
_SKILL_CACHE.set_cache(skill_id, _NEGATIVE_SKILL_SENTINEL)
|
||||
|
||||
return {"id": skill_id, "type": "skill_deleted"}
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
Domain models for LiteLLM backend.
|
||||
"""
|
||||
|
||||
from litellm.models.access_group import LiteLLM_AccessGroupTable
|
||||
from litellm.models.budget import (
|
||||
LiteLLM_BudgetTable,
|
||||
LiteLLM_BudgetTableFull,
|
||||
LiteLLM_TeamMemberTable,
|
||||
)
|
||||
from litellm.models.config import LiteLLM_Config
|
||||
from litellm.models.credentials import (
|
||||
CreateCredentialItem,
|
||||
CredentialBase,
|
||||
CredentialItem,
|
||||
)
|
||||
from litellm.models.end_user import LiteLLM_EndUserTable
|
||||
from litellm.models.managed_files import (
|
||||
LiteLLM_ManagedFileTable,
|
||||
LiteLLM_ManagedObjectTable,
|
||||
LiteLLM_ManagedVectorStoresTable,
|
||||
LiteLLM_ManagedVectorStoreTable,
|
||||
)
|
||||
from litellm.models.mcp_server import LiteLLM_MCPServerTable
|
||||
from litellm.models.model import LiteLLM_ProxyModelTable
|
||||
from litellm.models.object_permission import LiteLLM_ObjectPermissionTable
|
||||
from litellm.models.organization import LiteLLM_OrganizationTable
|
||||
from litellm.models.organization_membership import LiteLLM_OrganizationMembershipTable
|
||||
from litellm.models.project import LiteLLM_ProjectTable
|
||||
from litellm.models.skills import LiteLLM_SkillsTable
|
||||
from litellm.models.spend_logs import LiteLLM_ErrorLogs, LiteLLM_SpendLogs
|
||||
from litellm.models.tag import LiteLLM_TagTable
|
||||
from litellm.models.team import LiteLLM_TeamTable
|
||||
from litellm.models.team_membership import LiteLLM_TeamMembership
|
||||
from litellm.models.user import LiteLLM_UserTable
|
||||
from litellm.models.verification_token import LiteLLM_VerificationToken
|
||||
|
||||
__all__ = [
|
||||
"LiteLLM_AccessGroupTable",
|
||||
"LiteLLM_BudgetTable",
|
||||
"LiteLLM_BudgetTableFull",
|
||||
"LiteLLM_TeamMemberTable",
|
||||
"LiteLLM_Config",
|
||||
"CredentialBase",
|
||||
"CredentialItem",
|
||||
"CreateCredentialItem",
|
||||
"LiteLLM_EndUserTable",
|
||||
"LiteLLM_ManagedFileTable",
|
||||
"LiteLLM_ManagedObjectTable",
|
||||
"LiteLLM_ManagedVectorStoreTable",
|
||||
"LiteLLM_ManagedVectorStoresTable",
|
||||
"LiteLLM_MCPServerTable",
|
||||
"LiteLLM_ProxyModelTable",
|
||||
"LiteLLM_ObjectPermissionTable",
|
||||
"LiteLLM_OrganizationTable",
|
||||
"LiteLLM_OrganizationMembershipTable",
|
||||
"LiteLLM_ProjectTable",
|
||||
"LiteLLM_SkillsTable",
|
||||
"LiteLLM_ErrorLogs",
|
||||
"LiteLLM_SpendLogs",
|
||||
"LiteLLM_TagTable",
|
||||
"LiteLLM_TeamTable",
|
||||
"LiteLLM_TeamMembership",
|
||||
"LiteLLM_UserTable",
|
||||
"LiteLLM_VerificationToken",
|
||||
]
|
||||
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
Access group table model.
|
||||
|
||||
Canonical definition for ``litellm_accessgrouptable``. Re-exported from
|
||||
``litellm.proxy._types`` for backwards compatibility.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
|
||||
|
||||
class LiteLLM_AccessGroupTable(LiteLLMPydanticObjectBase):
|
||||
access_group_id: str
|
||||
access_group_name: str
|
||||
description: Optional[str] = None
|
||||
access_model_names: List[str] = []
|
||||
access_mcp_server_ids: List[str] = []
|
||||
access_agent_ids: List[str] = []
|
||||
assigned_team_ids: List[str] = []
|
||||
assigned_key_ids: List[str] = []
|
||||
created_at: Optional[datetime] = None
|
||||
created_by: Optional[str] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_by: Optional[str] = None
|
||||
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Base model class for domain models.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class DomainModel(BaseModel):
|
||||
"""Base class for all domain models."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
protected_namespaces=(),
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
@classmethod
|
||||
def from_db_record(cls, record: Any) -> "DomainModel":
|
||||
"""Create a domain model from a database record."""
|
||||
if record is None:
|
||||
raise ValueError("Cannot create domain model from None record")
|
||||
if isinstance(record, dict):
|
||||
return cls(**record)
|
||||
if hasattr(record, "model_dump") and callable(record.model_dump):
|
||||
return cls(**record.model_dump())
|
||||
if hasattr(record, "dict") and callable(record.dict):
|
||||
return cls(**record.dict())
|
||||
return cls(**dict(record))
|
||||
|
||||
def to_db_dict(self, exclude_unset: bool = False) -> Dict[str, Any]:
|
||||
"""Convert domain model to a dictionary for database operations."""
|
||||
return self.model_dump(exclude_none=True, exclude_unset=exclude_unset)
|
||||
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
Budget table model.
|
||||
|
||||
Canonical definition for ``litellm_budgettable``. Re-exported from
|
||||
``litellm.proxy._types`` for backwards compatibility.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
|
||||
|
||||
class LiteLLM_BudgetTable(LiteLLMPydanticObjectBase):
|
||||
"""Represents user-controllable params for a LiteLLM_BudgetTable record.
|
||||
|
||||
Budget-write paths use `model_fields.keys()` on this class as an allowlist
|
||||
for user input. Keep server-managed fields (e.g. `budget_reset_at`) on
|
||||
`LiteLLM_BudgetTableFull` so they aren't user-settable.
|
||||
"""
|
||||
|
||||
budget_id: Optional[str] = None
|
||||
soft_budget: Optional[float] = None
|
||||
max_budget: Optional[float] = None
|
||||
max_parallel_requests: Optional[int] = None
|
||||
tpm_limit: Optional[int] = None
|
||||
rpm_limit: Optional[int] = None
|
||||
model_max_budget: Optional[dict] = None
|
||||
budget_duration: Optional[str] = None
|
||||
allowed_models: Optional[List[str]] = (
|
||||
None # per-member model scope; empty = inherit team models
|
||||
)
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class LiteLLM_BudgetTableFull(LiteLLM_BudgetTable):
|
||||
"""LiteLLM_BudgetTable + server-managed fields returned on API responses."""
|
||||
|
||||
budget_reset_at: Optional[datetime] = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable):
|
||||
"""
|
||||
Used to track spend of a user_id within a team_id
|
||||
"""
|
||||
|
||||
spend: Optional[float] = None
|
||||
user_id: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
budget_id: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Config table model.
|
||||
|
||||
Canonical definition for ``litellm_config``. Re-exported from
|
||||
``litellm.proxy._types`` for backwards compatibility.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
|
||||
|
||||
class LiteLLM_Config(LiteLLMPydanticObjectBase):
|
||||
param_name: str
|
||||
param_value: Dict
|
||||
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
Credential table models.
|
||||
|
||||
These are the canonical credential types for the proxy. They live in the model
|
||||
layer; ``litellm.types.utils`` re-exports them for backwards compatibility.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
|
||||
class CredentialBase(BaseModel):
|
||||
credential_name: str
|
||||
credential_info: dict
|
||||
|
||||
|
||||
class CredentialItem(CredentialBase):
|
||||
credential_values: dict
|
||||
|
||||
|
||||
class CreateCredentialItem(CredentialBase):
|
||||
credential_values: Optional[dict] = None
|
||||
model_id: Optional[str] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_credential_params(cls, values):
|
||||
if not values.get("credential_values") and not values.get("model_id"):
|
||||
raise ValueError("Either credential_values or model_id must be set")
|
||||
return values
|
||||
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
End-user table model.
|
||||
|
||||
Canonical definition for ``litellm_endusertable``. Re-exported from
|
||||
``litellm.proxy._types`` for backwards compatibility.
|
||||
"""
|
||||
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import ConfigDict, model_validator
|
||||
|
||||
from litellm.models.budget import LiteLLM_BudgetTable
|
||||
from litellm.models.object_permission import LiteLLM_ObjectPermissionTable
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
|
||||
|
||||
class LiteLLM_EndUserTable(LiteLLMPydanticObjectBase):
|
||||
user_id: str
|
||||
blocked: bool
|
||||
alias: Optional[str] = None
|
||||
spend: float = 0.0
|
||||
allowed_model_region: Optional[Literal["eu", "us"]] = None
|
||||
default_model: Optional[str] = None
|
||||
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
||||
object_permission_id: Optional[str] = None
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_model_info(cls, values):
|
||||
if values.get("spend") is None:
|
||||
values.update({"spend": 0.0})
|
||||
return values
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Managed file, object, and vector store table models.
|
||||
|
||||
Canonical definitions for the ``litellm_managed*`` tables. Re-exported from
|
||||
``litellm.proxy._types`` for backwards compatibility.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
from litellm.types.llms.openai import OpenAIFileObject, ResponsesAPIResponse
|
||||
from litellm.types.utils import LiteLLMBatch, LiteLLMFineTuningJob
|
||||
|
||||
|
||||
class LiteLLM_ManagedFileTable(LiteLLMPydanticObjectBase):
|
||||
unified_file_id: str
|
||||
file_object: Optional[OpenAIFileObject] = None
|
||||
model_mappings: Dict[str, str]
|
||||
flat_model_file_ids: List[str]
|
||||
created_by: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
updated_by: Optional[str] = None
|
||||
storage_backend: Optional[str] = None
|
||||
storage_url: Optional[str] = None
|
||||
|
||||
|
||||
class LiteLLM_ManagedObjectTable(LiteLLMPydanticObjectBase):
|
||||
unified_object_id: str
|
||||
model_object_id: str
|
||||
file_purpose: Literal["batch", "fine-tune", "response", "container"]
|
||||
file_object: Union[LiteLLMBatch, LiteLLMFineTuningJob, ResponsesAPIResponse]
|
||||
created_by: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
|
||||
|
||||
class LiteLLM_ManagedVectorStoreTable(LiteLLMPydanticObjectBase):
|
||||
"""Table for managing vector stores with target_model_names support."""
|
||||
|
||||
unified_resource_id: str
|
||||
resource_object: Optional[Any] = None
|
||||
model_mappings: Dict[str, str]
|
||||
flat_model_resource_ids: List[str]
|
||||
created_by: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
updated_by: Optional[str] = None
|
||||
storage_backend: Optional[str] = None
|
||||
storage_url: Optional[str] = None
|
||||
|
||||
|
||||
class LiteLLM_ManagedVectorStoresTable(LiteLLMPydanticObjectBase):
|
||||
vector_store_id: str
|
||||
custom_llm_provider: str
|
||||
vector_store_name: Optional[str]
|
||||
vector_store_description: Optional[str]
|
||||
vector_store_metadata: Optional[Dict[str, Any]]
|
||||
created_at: Optional[datetime]
|
||||
updated_at: Optional[datetime]
|
||||
litellm_credential_name: Optional[str]
|
||||
litellm_params: Optional[Dict[str, Any]]
|
||||
team_id: Optional[str]
|
||||
user_id: Optional[str]
|
||||
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
MCP server table model.
|
||||
|
||||
Canonical definition for ``litellm_mcpservertable``. Re-exported from
|
||||
``litellm.proxy._types`` for backwards compatibility.
|
||||
"""
|
||||
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
from litellm.types.mcp import MCPAuthType, MCPCredentials, MCPTransportType
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPInfo
|
||||
|
||||
|
||||
class MCPEnvVarScope(str, enum.Enum):
|
||||
"""Scope for an MCP server environment variable.
|
||||
|
||||
- ``global``: value is provided by the admin and used for all users.
|
||||
- ``user``: each user must provide their own value via the per-user
|
||||
env-var endpoint. The admin-supplied ``value`` is treated as a
|
||||
placeholder/hint and is not used at request time.
|
||||
"""
|
||||
|
||||
global_ = "global"
|
||||
user = "user"
|
||||
|
||||
|
||||
class MCPEnvVar(LiteLLMPydanticObjectBase):
|
||||
"""One environment variable for an MCP server.
|
||||
|
||||
Variables can be interpolated into ``static_headers`` using ``${NAME}``
|
||||
syntax. ``scope=global`` values are stored on the server. ``scope=user``
|
||||
values are stored per-user in ``LiteLLM_MCPUserEnvVars`` and supplied by
|
||||
each user.
|
||||
"""
|
||||
|
||||
name: str
|
||||
value: str = ""
|
||||
scope: MCPEnvVarScope = MCPEnvVarScope.global_
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class LiteLLM_MCPServerTable(LiteLLMPydanticObjectBase):
|
||||
"""Represents a LiteLLM_MCPServerTable record"""
|
||||
|
||||
server_id: str
|
||||
server_name: Optional[str] = None
|
||||
alias: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
spec_path: Optional[str] = None
|
||||
transport: MCPTransportType
|
||||
auth_type: Optional[MCPAuthType] = None
|
||||
credentials: Optional[MCPCredentials] = None
|
||||
instructions: Optional[str] = None
|
||||
created_at: Optional[datetime] = None
|
||||
created_by: Optional[str] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_by: Optional[str] = None
|
||||
teams: List[Dict[str, Optional[str]]] = Field(default_factory=list)
|
||||
mcp_access_groups: List[str] = Field(default_factory=list)
|
||||
allowed_tools: List[str] = Field(default_factory=list)
|
||||
tool_name_to_display_name: Optional[Dict[str, str]] = None
|
||||
tool_name_to_description: Optional[Dict[str, str]] = None
|
||||
extra_headers: List[str] = Field(default_factory=list)
|
||||
mcp_info: Optional[MCPInfo] = None
|
||||
static_headers: Optional[Dict[str, str]] = None
|
||||
env_vars: Optional[List[MCPEnvVar]] = None
|
||||
status: Optional[Literal["healthy", "unhealthy", "unknown"]] = Field(
|
||||
default="unknown",
|
||||
description="Health status: 'healthy', 'unhealthy', 'unknown'",
|
||||
)
|
||||
last_health_check: Optional[datetime] = None
|
||||
health_check_error: Optional[str] = None
|
||||
command: Optional[str] = None
|
||||
args: List[str] = Field(default_factory=list)
|
||||
env: Dict[str, str] = Field(default_factory=dict)
|
||||
authorization_url: Optional[str] = None
|
||||
token_url: Optional[str] = None
|
||||
registration_url: Optional[str] = None
|
||||
oauth2_flow: Optional[Literal["client_credentials", "authorization_code"]] = None
|
||||
allow_all_keys: bool = False
|
||||
available_on_public_internet: bool = True
|
||||
delegate_auth_to_upstream: bool = False
|
||||
oauth_passthrough: bool = False
|
||||
is_byok: bool = False
|
||||
byok_description: List[str] = Field(default_factory=list)
|
||||
byok_api_key_help_url: Optional[str] = None
|
||||
has_user_credential: Optional[bool] = None
|
||||
source_url: Optional[str] = None
|
||||
timeout: Optional[float] = None
|
||||
approval_status: Optional[str] = Field(
|
||||
default="active",
|
||||
description="Approval status: 'pending_review', 'active', 'rejected'",
|
||||
)
|
||||
submitted_by: Optional[str] = None
|
||||
submitted_at: Optional[datetime] = None
|
||||
reviewed_at: Optional[datetime] = None
|
||||
review_notes: Optional[str] = None
|
||||
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Proxy model table model.
|
||||
|
||||
Canonical definition for ``litellm_proxymodeltable``. Re-exported from
|
||||
``litellm.proxy._types`` for backwards compatibility.
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ConfigDict, model_validator
|
||||
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
|
||||
|
||||
class LiteLLM_ProxyModelTable(LiteLLMPydanticObjectBase):
|
||||
model_id: str
|
||||
model_name: str
|
||||
litellm_params: dict
|
||||
model_info: Optional[dict] = None
|
||||
blocked: bool = False
|
||||
created_at: Optional[datetime] = None
|
||||
created_by: Optional[str] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_by: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_potential_json_str(cls, values):
|
||||
if isinstance(values.get("litellm_params"), str):
|
||||
try:
|
||||
values["litellm_params"] = json.loads(values["litellm_params"])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
if isinstance(values.get("model_info"), str):
|
||||
try:
|
||||
values["model_info"] = json.loads(values["model_info"])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return values
|
||||
|
||||
@property
|
||||
def is_blocked(self) -> bool:
|
||||
return self.blocked
|
||||
|
||||
@property
|
||||
def team_id(self) -> Optional[str]:
|
||||
if self.model_info:
|
||||
return self.model_info.get("team_id")
|
||||
return None
|
||||
|
||||
@property
|
||||
def team_public_model_name(self) -> Optional[str]:
|
||||
if self.model_info:
|
||||
return self.model_info.get("team_public_model_name")
|
||||
return None
|
||||
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
Object permission table model.
|
||||
|
||||
Canonical definition for ``litellm_objectpermissiontable``. Re-exported from
|
||||
``litellm.proxy._types`` for backwards compatibility.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
|
||||
|
||||
class LiteLLM_ObjectPermissionTable(LiteLLMPydanticObjectBase):
|
||||
"""Represents a LiteLLM_ObjectPermissionTable record"""
|
||||
|
||||
object_permission_id: str
|
||||
mcp_servers: Optional[List[str]] = []
|
||||
mcp_access_groups: Optional[List[str]] = []
|
||||
mcp_tool_permissions: Optional[Dict[str, List[str]]] = None
|
||||
vector_stores: Optional[List[str]] = []
|
||||
agents: Optional[List[str]] = []
|
||||
agent_access_groups: Optional[List[str]] = []
|
||||
models: Optional[List[str]] = []
|
||||
mcp_toolsets: Optional[List[str]] = None
|
||||
blocked_tools: Optional[List[str]] = []
|
||||
search_tools: Optional[List[str]] = []
|
||||
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
Organization table model.
|
||||
|
||||
Canonical definition for ``litellm_organizationtable``. Re-exported from
|
||||
``litellm.proxy._types`` for backwards compatibility.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.models.budget import LiteLLM_BudgetTable
|
||||
from litellm.models.object_permission import LiteLLM_ObjectPermissionTable
|
||||
from litellm.models.user import LiteLLM_UserTable
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
|
||||
|
||||
class LiteLLM_OrganizationTable(LiteLLMPydanticObjectBase):
|
||||
"""Represents user-controllable params for a LiteLLM_OrganizationTable record"""
|
||||
|
||||
organization_id: Optional[str] = None
|
||||
organization_alias: Optional[str] = None
|
||||
budget_id: str
|
||||
spend: float = 0.0
|
||||
metadata: Optional[dict] = None
|
||||
models: List[str] = []
|
||||
model_spend: Optional[dict] = {}
|
||||
created_by: str
|
||||
updated_by: str
|
||||
users: Optional[List[LiteLLM_UserTable]] = None
|
||||
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
|
||||
object_permission_id: Optional[str] = None
|
||||
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Organization membership table model.
|
||||
|
||||
Canonical definition for ``litellm_organizationmembership``. Re-exported from
|
||||
``litellm.proxy._types`` for backwards compatibility.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import ConfigDict, model_validator
|
||||
|
||||
from litellm.models.budget import LiteLLM_BudgetTable
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
|
||||
|
||||
class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase):
|
||||
"""Tracks which organizations a user belongs to and their spend within it."""
|
||||
|
||||
user_id: str
|
||||
organization_id: str
|
||||
user_role: Optional[str] = None
|
||||
spend: float = 0.0
|
||||
budget_id: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
user: Optional[Any] = None
|
||||
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
||||
user_email: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@model_validator(mode="after")
|
||||
def populate_user_email(self) -> "LiteLLM_OrganizationMembershipTable":
|
||||
if self.user_email is None and self.user is not None:
|
||||
if isinstance(self.user, dict):
|
||||
self.user_email = self.user.get("user_email")
|
||||
else:
|
||||
self.user_email = getattr(self.user, "user_email", None)
|
||||
return self
|
||||
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Project table model.
|
||||
|
||||
Canonical definition for ``litellm_projecttable``. Re-exported from
|
||||
``litellm.proxy._types`` for backwards compatibility.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.models.budget import LiteLLM_BudgetTable
|
||||
from litellm.models.object_permission import LiteLLM_ObjectPermissionTable
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
|
||||
|
||||
class LiteLLM_ProjectTable(LiteLLMPydanticObjectBase):
|
||||
"""Database model representation for project"""
|
||||
|
||||
project_id: str
|
||||
project_alias: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
budget_id: Optional[str] = None
|
||||
metadata: Optional[dict] = None
|
||||
models: List[str] = []
|
||||
spend: float = 0.0
|
||||
model_spend: Optional[dict] = None
|
||||
model_rpm_limit: Optional[dict] = None
|
||||
model_tpm_limit: Optional[dict] = None
|
||||
blocked: bool = False
|
||||
object_permission_id: Optional[str] = None
|
||||
created_by: Optional[str] = None
|
||||
updated_by: Optional[str] = None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
|
||||
|
||||
@property
|
||||
def is_blocked(self) -> bool:
|
||||
return self.blocked
|
||||
@@ -0,0 +1,30 @@
|
||||
"""
|
||||
Skills table model.
|
||||
|
||||
Canonical definition for ``litellm_skillstable``. Re-exported from
|
||||
``litellm.proxy._types`` for backwards compatibility.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
|
||||
|
||||
class LiteLLM_SkillsTable(LiteLLMPydanticObjectBase):
|
||||
"""Represents a LiteLLM_SkillsTable record"""
|
||||
|
||||
skill_id: str
|
||||
display_title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
instructions: Optional[str] = None
|
||||
source: str = "custom"
|
||||
latest_version: Optional[str] = None
|
||||
file_content: Optional[bytes] = None
|
||||
file_name: Optional[str] = None
|
||||
file_type: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
created_at: Optional[datetime] = None
|
||||
created_by: Optional[str] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_by: Optional[str] = None
|
||||
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
Spend and error log table models.
|
||||
|
||||
Canonical definitions for ``litellm_spendlogs`` and ``litellm_errorlogs``.
|
||||
Re-exported from ``litellm.proxy._types`` for backwards compatibility.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import Json
|
||||
|
||||
from litellm._uuid import uuid
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
|
||||
|
||||
class LiteLLM_SpendLogs(LiteLLMPydanticObjectBase):
|
||||
request_id: str
|
||||
api_key: str
|
||||
model: Optional[str] = ""
|
||||
api_base: Optional[str] = ""
|
||||
call_type: str
|
||||
spend: Optional[float] = 0.0
|
||||
total_tokens: Optional[int] = 0
|
||||
prompt_tokens: Optional[int] = 0
|
||||
completion_tokens: Optional[int] = 0
|
||||
startTime: Union[str, datetime, None]
|
||||
endTime: Union[str, datetime, None]
|
||||
user: Optional[str] = ""
|
||||
metadata: Optional[Json] = {}
|
||||
cache_hit: Optional[str] = "False"
|
||||
cache_key: Optional[str] = None
|
||||
request_tags: Optional[Json] = None
|
||||
requester_ip_address: Optional[str] = None
|
||||
messages: Optional[Union[str, list, dict]]
|
||||
response: Optional[Union[str, list, dict]]
|
||||
|
||||
|
||||
class LiteLLM_ErrorLogs(LiteLLMPydanticObjectBase):
|
||||
request_id: Optional[str] = str(uuid.uuid4())
|
||||
api_base: Optional[str] = ""
|
||||
model_group: Optional[str] = ""
|
||||
litellm_model_name: Optional[str] = ""
|
||||
model_id: Optional[str] = ""
|
||||
request_kwargs: Optional[dict] = {}
|
||||
exception_type: Optional[str] = ""
|
||||
status_code: Optional[str] = ""
|
||||
exception_string: Optional[str] = ""
|
||||
startTime: Union[str, datetime, None]
|
||||
endTime: Union[str, datetime, None]
|
||||
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
Tag table model.
|
||||
|
||||
Canonical definition for ``litellm_tagtable``. Re-exported from
|
||||
``litellm.proxy._types`` for backwards compatibility.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
from litellm.models.budget import LiteLLM_BudgetTable
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
|
||||
|
||||
class LiteLLM_TagTable(LiteLLMPydanticObjectBase):
|
||||
tag_name: str
|
||||
description: Optional[str] = None
|
||||
models: List[str] = []
|
||||
model_info: Optional[dict] = None
|
||||
spend: float = 0.0
|
||||
budget_id: Optional[str] = None
|
||||
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
||||
created_at: Optional[datetime] = None
|
||||
created_by: Optional[str] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_model_info(cls, values):
|
||||
if values.get("spend") is None:
|
||||
values.update({"spend": 0.0})
|
||||
if values.get("models") is None:
|
||||
values.update({"models": []})
|
||||
return values
|
||||
@@ -0,0 +1,154 @@
|
||||
"""
|
||||
Team table models.
|
||||
|
||||
Canonical definitions for ``litellm_teamtable`` (plus the shared Member and
|
||||
budget-window value types and the team-model alias table). Re-exported from
|
||||
``litellm.proxy._types`` for backwards compatibility.
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from litellm.models.object_permission import LiteLLM_ObjectPermissionTable
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
|
||||
|
||||
class MemberBase(LiteLLMPydanticObjectBase):
|
||||
user_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The unique ID of the user to add. Either user_id or user_email must be provided",
|
||||
)
|
||||
user_email: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The email address of the user to add. Either user_id or user_email must be provided",
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_user_info(cls, values):
|
||||
if not isinstance(values, dict):
|
||||
raise ValueError("input needs to be a dictionary")
|
||||
if values.get("user_id") is None and values.get("user_email") is None:
|
||||
raise ValueError("Either user id or user email must be provided")
|
||||
return values
|
||||
|
||||
|
||||
class Member(MemberBase):
|
||||
role: Literal["admin", "user"] = Field(
|
||||
description="The role of the user within the team. 'admin' users can manage team settings and members, 'user' is a regular team member"
|
||||
)
|
||||
|
||||
|
||||
class BudgetLimitEntry(LiteLLMPydanticObjectBase):
|
||||
"""A single budget window with its own limit and independent reset schedule."""
|
||||
|
||||
budget_duration: str
|
||||
max_budget: float
|
||||
reset_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class LiteLLM_ModelTable(LiteLLMPydanticObjectBase):
|
||||
id: Optional[int] = None
|
||||
model_aliases: Optional[Union[str, dict]] = None
|
||||
created_by: str
|
||||
updated_by: str
|
||||
team: Optional["LiteLLM_TeamTable"] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class TeamBase(LiteLLMPydanticObjectBase):
|
||||
team_alias: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
organization_id: Optional[str] = None
|
||||
admins: list = []
|
||||
members: list = []
|
||||
members_with_roles: List[Member] = []
|
||||
team_member_permissions: Optional[List[str]] = None
|
||||
metadata: Optional[dict] = None
|
||||
tpm_limit: Optional[int] = None
|
||||
rpm_limit: Optional[int] = None
|
||||
max_budget: Optional[float] = None
|
||||
soft_budget: Optional[float] = None
|
||||
budget_duration: Optional[str] = None
|
||||
budget_limits: Optional[List[BudgetLimitEntry]] = None
|
||||
models: list = []
|
||||
blocked: bool = False
|
||||
router_settings: Optional[dict] = None
|
||||
access_group_ids: Optional[List[str]] = None
|
||||
default_team_member_models: Optional[List[str]] = None
|
||||
|
||||
|
||||
class LiteLLM_TeamTable(TeamBase):
|
||||
team_id: str # type: ignore
|
||||
spend: Optional[float] = None
|
||||
max_parallel_requests: Optional[int] = None
|
||||
budget_duration: Optional[str] = None
|
||||
budget_reset_at: Optional[datetime] = None
|
||||
model_id: Optional[int] = None
|
||||
model_spend: Optional[dict] = {}
|
||||
model_max_budget: Optional[dict] = {}
|
||||
policies: Optional[List[str]] = None
|
||||
allow_team_guardrail_config: Optional[bool] = False
|
||||
litellm_model_table: Optional[LiteLLM_ModelTable] = None
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
|
||||
object_permission_id: Optional[str] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
created_at: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_model_info(cls, values):
|
||||
dict_fields = [
|
||||
"metadata",
|
||||
"aliases",
|
||||
"config",
|
||||
"permissions",
|
||||
"model_max_budget",
|
||||
"model_aliases",
|
||||
"router_settings",
|
||||
"budget_limits",
|
||||
]
|
||||
|
||||
if isinstance(values, BaseModel):
|
||||
values = values.model_dump()
|
||||
|
||||
if (
|
||||
isinstance(values.get("members_with_roles"), dict)
|
||||
and not values["members_with_roles"]
|
||||
):
|
||||
values["members_with_roles"] = []
|
||||
|
||||
for field in dict_fields:
|
||||
value = values.get(field)
|
||||
if value is not None and isinstance(value, str):
|
||||
try:
|
||||
values[field] = json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Field {field} should be a valid dictionary")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class LiteLLM_TeamTableCachedObj(LiteLLM_TeamTable):
|
||||
last_refreshed_at: Optional[float] = None
|
||||
|
||||
|
||||
class LiteLLM_DeletedTeamTable(LiteLLM_TeamTable):
|
||||
"""Audit record for deleted teams; mirrors the team plus deletion metadata."""
|
||||
|
||||
id: Optional[str] = None
|
||||
deleted_at: Optional[datetime] = None
|
||||
deleted_by: Optional[str] = None
|
||||
deleted_by_api_key: Optional[str] = None
|
||||
litellm_changed_by: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
LiteLLM_ModelTable.model_rebuild()
|
||||
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Team membership table model.
|
||||
|
||||
Canonical definition for ``litellm_teammembership``. Re-exported from
|
||||
``litellm.proxy._types`` for backwards compatibility.
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
from litellm.models.budget import LiteLLM_BudgetTable, LiteLLM_BudgetTableFull
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
|
||||
|
||||
class LiteLLM_TeamMembership(LiteLLMPydanticObjectBase):
|
||||
user_id: str
|
||||
team_id: str
|
||||
budget_id: Optional[str] = None
|
||||
spend: Optional[float] = 0.0
|
||||
total_spend: Optional[float] = 0.0
|
||||
litellm_budget_table: Optional[
|
||||
Union[LiteLLM_BudgetTableFull, LiteLLM_BudgetTable]
|
||||
] = None
|
||||
|
||||
def safe_get_team_member_rpm_limit(self) -> Optional[int]:
|
||||
if self.litellm_budget_table is not None:
|
||||
return self.litellm_budget_table.rpm_limit
|
||||
return None
|
||||
|
||||
def safe_get_team_member_tpm_limit(self) -> Optional[int]:
|
||||
if self.litellm_budget_table is not None:
|
||||
return self.litellm_budget_table.tpm_limit
|
||||
return None
|
||||
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
User table model.
|
||||
|
||||
Canonical definition for ``litellm_usertable``. Re-exported from
|
||||
``litellm.proxy._types`` for backwards compatibility.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
from litellm.models.object_permission import LiteLLM_ObjectPermissionTable
|
||||
from litellm.models.organization_membership import (
|
||||
LiteLLM_OrganizationMembershipTable,
|
||||
)
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
|
||||
|
||||
class LiteLLM_UserTable(LiteLLMPydanticObjectBase):
|
||||
user_id: str
|
||||
user_alias: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
sso_user_id: Optional[str] = None
|
||||
organization_id: Optional[str] = None
|
||||
object_permission_id: Optional[str] = None
|
||||
password: Optional[str] = Field(default=None, exclude=True)
|
||||
teams: List[str] = []
|
||||
user_role: Optional[str] = None
|
||||
max_budget: Optional[float] = None
|
||||
spend: float = 0.0
|
||||
user_email: Optional[str] = None
|
||||
models: list = []
|
||||
metadata: Optional[dict] = None
|
||||
max_parallel_requests: Optional[int] = None
|
||||
tpm_limit: Optional[int] = None
|
||||
rpm_limit: Optional[int] = None
|
||||
budget_duration: Optional[str] = None
|
||||
budget_reset_at: Optional[datetime] = None
|
||||
allowed_cache_controls: List[str] = []
|
||||
policies: List[str] = []
|
||||
model_spend: Optional[Dict] = {}
|
||||
model_max_budget: Optional[Dict] = {}
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
organization_memberships: Optional[List[LiteLLM_OrganizationMembershipTable]] = None
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_model_info(cls, values):
|
||||
if values.get("spend") is None:
|
||||
values.update({"spend": 0.0})
|
||||
if values.get("models") is None:
|
||||
values.update({"models": []})
|
||||
if values.get("teams") is None:
|
||||
values.update({"teams": []})
|
||||
return values
|
||||
|
||||
def is_over_budget(self) -> bool:
|
||||
if self.max_budget is None:
|
||||
return False
|
||||
return self.spend >= self.max_budget
|
||||
|
||||
def has_model_access(self, model_name: str) -> bool:
|
||||
if not self.models:
|
||||
return True
|
||||
return model_name in self.models
|
||||
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
Verification token table model.
|
||||
|
||||
Canonical definition for ``litellm_verificationtoken``. Re-exported from
|
||||
``litellm.proxy._types`` for backwards compatibility.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from litellm.models.object_permission import LiteLLM_ObjectPermissionTable
|
||||
from litellm.types.llms.base import LiteLLMPydanticObjectBase
|
||||
|
||||
|
||||
class LiteLLM_VerificationToken(LiteLLMPydanticObjectBase):
|
||||
token: Optional[str] = None
|
||||
key_name: Optional[str] = None
|
||||
key_alias: Optional[str] = None
|
||||
spend: float = 0.0
|
||||
max_budget: Optional[float] = None
|
||||
expires: Optional[Union[str, datetime]] = None
|
||||
models: List = []
|
||||
aliases: Dict = {}
|
||||
config: Dict = {}
|
||||
user_id: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
agent_id: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
max_parallel_requests: Optional[int] = None
|
||||
metadata: Dict = {}
|
||||
tpm_limit: Optional[int] = None
|
||||
rpm_limit: Optional[int] = None
|
||||
budget_duration: Optional[str] = None
|
||||
budget_reset_at: Optional[datetime] = None
|
||||
allowed_cache_controls: Optional[list] = []
|
||||
allowed_routes: Optional[list] = []
|
||||
permissions: Dict = {}
|
||||
model_spend: Dict = {}
|
||||
model_max_budget: Dict = {}
|
||||
soft_budget_cooldown: bool = False
|
||||
blocked: Optional[bool] = None
|
||||
litellm_budget_table: Optional[dict] = None
|
||||
budget_id: Optional[str] = None
|
||||
org_id: Optional[str] = None # org id for a given key
|
||||
created_at: Optional[datetime] = None
|
||||
created_by: Optional[str] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_by: Optional[str] = None
|
||||
last_active: Optional[datetime] = None
|
||||
object_permission_id: Optional[str] = None
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
|
||||
access_group_ids: Optional[List[str]] = None
|
||||
rotation_count: Optional[int] = 0
|
||||
auto_rotate: Optional[bool] = False
|
||||
rotation_interval: Optional[str] = None
|
||||
last_rotation_at: Optional[datetime] = None
|
||||
key_rotation_at: Optional[datetime] = None
|
||||
router_settings: Optional[dict] = None
|
||||
budget_limits: Optional[List[dict]] = None
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class LiteLLM_DeletedVerificationToken(LiteLLM_VerificationToken):
|
||||
"""Audit record for deleted keys; mirrors the token plus deletion metadata."""
|
||||
|
||||
id: Optional[str] = None
|
||||
deleted_at: Optional[datetime] = None
|
||||
deleted_by: Optional[str] = None
|
||||
deleted_by_api_key: Optional[str] = None
|
||||
litellm_changed_by: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
@@ -14,8 +14,12 @@ from litellm.proxy._types import (
|
||||
SpecialHeaders,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.auth.ip_address_utils import IPAddressUtils
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.repositories.table_repositories import (
|
||||
AgentsRepository,
|
||||
MCPServerRepository,
|
||||
)
|
||||
|
||||
|
||||
def _parse_mcp_server_names_from_path(
|
||||
@@ -1445,7 +1449,7 @@ class MCPRequestHandler:
|
||||
return None
|
||||
|
||||
if object_permission_id is None:
|
||||
agent_row = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
agent_row = await AgentsRepository(prisma_client).table.find_unique(
|
||||
where={"agent_id": agent_id},
|
||||
)
|
||||
object_permission_id = (
|
||||
@@ -1600,7 +1604,7 @@ class MCPRequestHandler:
|
||||
server_ids: Set[str] = set()
|
||||
if access_groups and prisma_client is not None:
|
||||
try:
|
||||
mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many(
|
||||
mcp_servers = await MCPServerRepository(prisma_client).table.find_many(
|
||||
where={"mcp_access_groups": {"hasSome": access_groups}}
|
||||
)
|
||||
for server in mcp_servers:
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.constants import MCP_PER_USER_TOKEN_EXPIRY_BUFFER_SECONDS
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_MCPServerTable,
|
||||
LiteLLM_ObjectPermissionTable,
|
||||
@@ -25,8 +26,16 @@ from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||
decrypt_value_helper,
|
||||
encrypt_value_helper,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.repositories.object_permission_repository import ObjectPermissionRepository
|
||||
from litellm.repositories.table_repositories import (
|
||||
MCPServerRepository,
|
||||
MCPUserCredentialsRepository,
|
||||
)
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.repositories.verification_token_repository import (
|
||||
VerificationTokenRepository,
|
||||
)
|
||||
from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
from litellm.types.mcp import MCPCredentials
|
||||
|
||||
@@ -354,7 +363,7 @@ async def get_all_mcp_servers(
|
||||
where: Dict[str, Any] = {}
|
||||
if approval_status is not None:
|
||||
where["approval_status"] = approval_status
|
||||
mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many(
|
||||
mcp_servers = await MCPServerRepository(prisma_client).table.find_many(
|
||||
where=where if where else {}
|
||||
)
|
||||
|
||||
@@ -380,7 +389,9 @@ async def get_mcp_server(
|
||||
"""
|
||||
Returns the matching mcp server from the db iff exists
|
||||
"""
|
||||
mcp_server = await prisma_client.db.litellm_mcpservertable.find_unique(
|
||||
mcp_server: Optional[LiteLLM_MCPServerTable] = await MCPServerRepository(
|
||||
prisma_client
|
||||
).table.find_unique(
|
||||
where={
|
||||
"server_id": server_id,
|
||||
}
|
||||
@@ -398,12 +409,12 @@ async def get_mcp_servers(
|
||||
"""
|
||||
Returns the matching mcp servers from the db with the server_ids
|
||||
"""
|
||||
_mcp_servers: List[LiteLLM_MCPServerTable] = (
|
||||
await prisma_client.db.litellm_mcpservertable.find_many(
|
||||
where={
|
||||
"server_id": {"in": server_ids},
|
||||
}
|
||||
)
|
||||
_mcp_servers: List[LiteLLM_MCPServerTable] = await MCPServerRepository(
|
||||
prisma_client
|
||||
).table.find_many(
|
||||
where={
|
||||
"server_id": {"in": server_ids},
|
||||
}
|
||||
)
|
||||
final_mcp_servers: List[LiteLLM_MCPServerTable] = []
|
||||
for _mcp_server in _mcp_servers:
|
||||
@@ -420,15 +431,15 @@ async def get_mcp_servers_by_verificationtoken(
|
||||
"""
|
||||
Returns the mcp servers from the db for the verification token
|
||||
"""
|
||||
verification_token_record: LiteLLM_TeamTable = (
|
||||
await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||
where={
|
||||
"token": token,
|
||||
},
|
||||
include={
|
||||
"object_permission": True,
|
||||
},
|
||||
)
|
||||
verification_token_record: LiteLLM_TeamTable = await VerificationTokenRepository(
|
||||
prisma_client
|
||||
).table.find_unique(
|
||||
where={
|
||||
"token": token,
|
||||
},
|
||||
include={
|
||||
"object_permission": True,
|
||||
},
|
||||
)
|
||||
|
||||
mcp_servers: Optional[List[str]] = []
|
||||
@@ -446,15 +457,15 @@ async def get_mcp_servers_by_team(
|
||||
"""
|
||||
Returns the mcp servers from the db for the team id
|
||||
"""
|
||||
team_record: LiteLLM_TeamTable = (
|
||||
await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={
|
||||
"team_id": team_id,
|
||||
},
|
||||
include={
|
||||
"object_permission": True,
|
||||
},
|
||||
)
|
||||
team_record: LiteLLM_TeamTable = await TeamRepository(
|
||||
prisma_client
|
||||
).table.find_unique(
|
||||
where={
|
||||
"team_id": team_id,
|
||||
},
|
||||
include={
|
||||
"object_permission": True,
|
||||
},
|
||||
)
|
||||
|
||||
mcp_servers: Optional[List[str]] = []
|
||||
@@ -505,16 +516,16 @@ async def get_objectpermissions_for_mcp_server(
|
||||
"""
|
||||
Get all the object permissions records and the associated team and verficiationtoken records that have access to the mcp server
|
||||
"""
|
||||
object_permission_records = (
|
||||
await prisma_client.db.litellm_objectpermissiontable.find_many(
|
||||
where={
|
||||
"mcp_servers": {"has": mcp_server_id},
|
||||
},
|
||||
include={
|
||||
"teams": True,
|
||||
"verification_tokens": True,
|
||||
},
|
||||
)
|
||||
object_permission_records = await ObjectPermissionRepository(
|
||||
prisma_client
|
||||
).table.find_many(
|
||||
where={
|
||||
"mcp_servers": {"has": mcp_server_id},
|
||||
},
|
||||
include={
|
||||
"teams": True,
|
||||
"verification_tokens": True,
|
||||
},
|
||||
)
|
||||
|
||||
return object_permission_records
|
||||
@@ -526,7 +537,7 @@ async def get_virtualkeys_for_mcp_server(
|
||||
"""
|
||||
Get all the virtual keys that have access to the mcp server
|
||||
"""
|
||||
virtual_keys = await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
virtual_keys = await VerificationTokenRepository(prisma_client).table.find_many(
|
||||
where={
|
||||
"mcp_servers": {"has": server_id},
|
||||
},
|
||||
@@ -564,7 +575,7 @@ async def delete_mcp_server(
|
||||
|
||||
Returns the deleted mcp server record if it exists, otherwise None
|
||||
"""
|
||||
deleted_server = await prisma_client.db.litellm_mcpservertable.delete(
|
||||
deleted_server = await MCPServerRepository(prisma_client).table.delete(
|
||||
where={
|
||||
"server_id": server_id,
|
||||
},
|
||||
@@ -600,7 +611,7 @@ async def create_mcp_server(
|
||||
data_dict["created_by"] = touched_by
|
||||
data_dict["updated_by"] = touched_by
|
||||
|
||||
new_mcp_server = await prisma_client.db.litellm_mcpservertable.create(
|
||||
new_mcp_server = await MCPServerRepository(prisma_client).table.create(
|
||||
data=data_dict # type: ignore
|
||||
)
|
||||
|
||||
@@ -635,7 +646,7 @@ async def update_mcp_server(
|
||||
"credentials" in data_dict and data_dict["credentials"] is not None
|
||||
)
|
||||
if data.auth_type or has_credentials:
|
||||
existing = await prisma_client.db.litellm_mcpservertable.find_unique(
|
||||
existing = await MCPServerRepository(prisma_client).table.find_unique(
|
||||
where={"server_id": data.server_id}
|
||||
)
|
||||
|
||||
@@ -678,7 +689,7 @@ async def update_mcp_server(
|
||||
# Add audit fields
|
||||
data_dict["updated_by"] = touched_by
|
||||
|
||||
updated_mcp_server = await prisma_client.db.litellm_mcpservertable.update(
|
||||
updated_mcp_server = await MCPServerRepository(prisma_client).table.update(
|
||||
where={"server_id": data.server_id}, data=data_dict # type: ignore
|
||||
)
|
||||
|
||||
@@ -691,7 +702,7 @@ async def rotate_mcp_server_credentials_master_key(
|
||||
):
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
|
||||
mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many()
|
||||
mcp_servers = await MCPServerRepository(prisma_client).table.find_many()
|
||||
|
||||
updated = 0
|
||||
for mcp_server in mcp_servers:
|
||||
@@ -719,7 +730,7 @@ async def rotate_mcp_server_credentials_master_key(
|
||||
continue
|
||||
|
||||
update_data["updated_by"] = touched_by
|
||||
await prisma_client.db.litellm_mcpservertable.update(
|
||||
await MCPServerRepository(prisma_client).table.update(
|
||||
where={"server_id": mcp_server.server_id},
|
||||
data=update_data,
|
||||
)
|
||||
@@ -781,7 +792,7 @@ async def rotate_mcp_user_credentials_master_key(
|
||||
under the new master key. Rows that are unreadable under both paths
|
||||
are logged and skipped so one corrupt row does not abort the rotation.
|
||||
"""
|
||||
rows = await prisma_client.db.litellm_mcpusercredentials.find_many()
|
||||
rows = await MCPUserCredentialsRepository(prisma_client).table.find_many()
|
||||
rotated = 0
|
||||
skipped = 0
|
||||
for row in rows:
|
||||
@@ -798,7 +809,7 @@ async def rotate_mcp_user_credentials_master_key(
|
||||
re_encrypted = encrypt_value_helper(
|
||||
plaintext, new_encryption_key=new_master_key
|
||||
)
|
||||
await prisma_client.db.litellm_mcpusercredentials.update(
|
||||
await MCPUserCredentialsRepository(prisma_client).table.update(
|
||||
where={
|
||||
"user_id_server_id": {
|
||||
"user_id": row.user_id,
|
||||
@@ -873,7 +884,7 @@ async def store_user_credential(
|
||||
"""Store a user credential for a BYOK MCP server."""
|
||||
|
||||
encoded = encrypt_value_helper(credential)
|
||||
await prisma_client.db.litellm_mcpusercredentials.upsert(
|
||||
await MCPUserCredentialsRepository(prisma_client).table.upsert(
|
||||
where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}},
|
||||
data={
|
||||
"create": {
|
||||
@@ -893,7 +904,7 @@ async def get_user_credential(
|
||||
) -> Optional[str]:
|
||||
"""Return credential for a user+server pair, or None."""
|
||||
|
||||
row = await prisma_client.db.litellm_mcpusercredentials.find_unique(
|
||||
row = await MCPUserCredentialsRepository(prisma_client).table.find_unique(
|
||||
where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
|
||||
)
|
||||
if row is None:
|
||||
@@ -907,7 +918,7 @@ async def has_user_credential(
|
||||
server_id: str,
|
||||
) -> bool:
|
||||
"""Return True if the user has a stored credential for this server."""
|
||||
row = await prisma_client.db.litellm_mcpusercredentials.find_unique(
|
||||
row = await MCPUserCredentialsRepository(prisma_client).table.find_unique(
|
||||
where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
|
||||
)
|
||||
return row is not None
|
||||
@@ -919,7 +930,7 @@ async def delete_user_credential(
|
||||
server_id: str,
|
||||
) -> None:
|
||||
"""Delete the user's stored credential for a BYOK MCP server."""
|
||||
await prisma_client.db.litellm_mcpusercredentials.delete(
|
||||
await MCPUserCredentialsRepository(prisma_client).table.delete(
|
||||
where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
|
||||
)
|
||||
|
||||
@@ -966,7 +977,7 @@ async def store_user_oauth_credential(
|
||||
# Skip the guard when the caller knows the row is already an OAuth2 credential
|
||||
# (e.g. during token refresh), saving an extra DB round-trip.
|
||||
if not skip_byok_guard:
|
||||
existing = await prisma_client.db.litellm_mcpusercredentials.find_unique(
|
||||
existing = await MCPUserCredentialsRepository(prisma_client).table.find_unique(
|
||||
where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
|
||||
)
|
||||
if (
|
||||
@@ -984,7 +995,7 @@ async def store_user_oauth_credential(
|
||||
)
|
||||
|
||||
encoded = encrypt_value_helper(json.dumps(payload))
|
||||
await prisma_client.db.litellm_mcpusercredentials.upsert(
|
||||
await MCPUserCredentialsRepository(prisma_client).table.upsert(
|
||||
where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}},
|
||||
data={
|
||||
"create": {
|
||||
@@ -1025,7 +1036,7 @@ async def get_user_oauth_credential(
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Return the decoded OAuth2 payload dict for a user+server pair, or None."""
|
||||
|
||||
row = await prisma_client.db.litellm_mcpusercredentials.find_unique(
|
||||
row = await MCPUserCredentialsRepository(prisma_client).table.find_unique(
|
||||
where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
|
||||
)
|
||||
if row is None:
|
||||
@@ -1039,7 +1050,7 @@ async def list_user_oauth_credentials(
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Return all OAuth2 credential payloads for a user, tagged with server_id."""
|
||||
|
||||
rows = await prisma_client.db.litellm_mcpusercredentials.find_many(
|
||||
rows = await MCPUserCredentialsRepository(prisma_client).table.find_many(
|
||||
where={"user_id": user_id}
|
||||
)
|
||||
results: List[Dict[str, Any]] = []
|
||||
@@ -1212,7 +1223,7 @@ async def approve_mcp_server(
|
||||
) -> LiteLLM_MCPServerTable:
|
||||
"""Set approval_status=active and record reviewed_at."""
|
||||
now = datetime.now(timezone.utc)
|
||||
updated = await prisma_client.db.litellm_mcpservertable.update(
|
||||
updated = await MCPServerRepository(prisma_client).table.update(
|
||||
where={"server_id": server_id},
|
||||
data={
|
||||
"approval_status": MCPApprovalStatus.active,
|
||||
@@ -1240,7 +1251,7 @@ async def reject_mcp_server(
|
||||
}
|
||||
if review_notes is not None:
|
||||
data["review_notes"] = review_notes
|
||||
updated = await prisma_client.db.litellm_mcpservertable.update(
|
||||
updated = await MCPServerRepository(prisma_client).table.update(
|
||||
where={"server_id": server_id},
|
||||
data=data,
|
||||
)
|
||||
@@ -1257,7 +1268,7 @@ async def get_mcp_submissions(
|
||||
along with a summary count breakdown by approval_status.
|
||||
Mirrors get_guardrail_submissions() from guardrail_endpoints.py.
|
||||
"""
|
||||
rows = await prisma_client.db.litellm_mcpservertable.find_many(
|
||||
rows = await MCPServerRepository(prisma_client).table.find_many(
|
||||
where={"submitted_at": {"not": None}},
|
||||
order={"submitted_at": "desc"},
|
||||
take=500, # safety cap; paginate if needed in a future iteration
|
||||
|
||||
@@ -42,8 +42,8 @@ from litellm.constants import (
|
||||
MCP_TOOL_LISTING_TIMEOUT,
|
||||
)
|
||||
from litellm.exceptions import BlockedPiiEntityError, GuardrailRaisedException
|
||||
from litellm.litellm_core_utils.url_utils import SSRFError, async_safe_get
|
||||
from litellm.experimental_mcp_client.client import MCPClient, MCPSigV4Auth
|
||||
from litellm.litellm_core_utils.url_utils import SSRFError, async_safe_get
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
|
||||
MCPRequestHandler,
|
||||
@@ -85,6 +85,7 @@ from litellm.proxy._types import (
|
||||
from litellm.proxy.auth.ip_address_utils import IPAddressUtils
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import decrypt_value_helper
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
from litellm.repositories.table_repositories import MCPServerRepository
|
||||
from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
from litellm.types.mcp import MCPAuth, MCPStdioConfig
|
||||
from litellm.types.mcp_server.mcp_server_manager import (
|
||||
@@ -3817,7 +3818,7 @@ class MCPServerManager:
|
||||
# Pending/rejected servers are excluded at the DB level so we never load them.
|
||||
from litellm.proxy._experimental.mcp_server.db import LiteLLM_MCPServerTable
|
||||
|
||||
raw_rows = await prisma_client.db.litellm_mcpservertable.find_many(
|
||||
raw_rows = await MCPServerRepository(prisma_client).table.find_many(
|
||||
where={
|
||||
"OR": [
|
||||
{"approval_status": None},
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import List, Optional
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.repositories.table_repositories import MCPToolsetRepository
|
||||
from litellm.types.mcp_server.mcp_toolset import (
|
||||
MCPToolset,
|
||||
NewMCPToolsetRequest,
|
||||
@@ -30,7 +31,7 @@ async def create_mcp_toolset(
|
||||
data_dict["tools"] = json.dumps(data_dict.get("tools", []))
|
||||
data_dict["created_by"] = touched_by
|
||||
data_dict["updated_by"] = touched_by
|
||||
row = await prisma_client.db.litellm_mcptoolsettable.create(data=data_dict)
|
||||
row = await MCPToolsetRepository(prisma_client).table.create(data=data_dict)
|
||||
return _toolset_from_row(row)
|
||||
|
||||
|
||||
@@ -38,7 +39,7 @@ async def get_mcp_toolset(
|
||||
prisma_client: PrismaClient,
|
||||
toolset_id: str,
|
||||
) -> Optional[MCPToolset]:
|
||||
row = await prisma_client.db.litellm_mcptoolsettable.find_unique(
|
||||
row = await MCPToolsetRepository(prisma_client).table.find_unique(
|
||||
where={"toolset_id": toolset_id}
|
||||
)
|
||||
if row is None:
|
||||
@@ -54,7 +55,7 @@ async def list_mcp_toolsets(
|
||||
where = {}
|
||||
if toolset_ids is not None:
|
||||
where = {"toolset_id": {"in": toolset_ids}}
|
||||
rows = await prisma_client.db.litellm_mcptoolsettable.find_many(where=where)
|
||||
rows = await MCPToolsetRepository(prisma_client).table.find_many(where=where)
|
||||
return [_toolset_from_row(r) for r in rows]
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.warning(
|
||||
@@ -69,7 +70,7 @@ async def get_mcp_toolset_by_name(
|
||||
prisma_client: PrismaClient,
|
||||
toolset_name: str,
|
||||
) -> Optional[MCPToolset]:
|
||||
row = await prisma_client.db.litellm_mcptoolsettable.find_first(
|
||||
row = await MCPToolsetRepository(prisma_client).table.find_first(
|
||||
where={"toolset_name": toolset_name}
|
||||
)
|
||||
if row is None:
|
||||
@@ -87,7 +88,7 @@ async def update_mcp_toolset(
|
||||
data_dict["tools"] = json.dumps(data_dict["tools"])
|
||||
data_dict["updated_by"] = touched_by
|
||||
try:
|
||||
row = await prisma_client.db.litellm_mcptoolsettable.update(
|
||||
row = await MCPToolsetRepository(prisma_client).table.update(
|
||||
where={"toolset_id": data.toolset_id},
|
||||
data=data_dict,
|
||||
)
|
||||
@@ -105,7 +106,7 @@ async def delete_mcp_toolset(
|
||||
toolset_id: str,
|
||||
) -> Optional[MCPToolset]:
|
||||
try:
|
||||
row = await prisma_client.db.litellm_mcptoolsettable.delete(
|
||||
row = await MCPToolsetRepository(prisma_client).table.delete(
|
||||
where={"toolset_id": toolset_id}
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
+84
-671
@@ -23,8 +23,6 @@ from litellm.litellm_core_utils.initialize_dynamic_callback_params import (
|
||||
from litellm.types.integrations.slack_alerting import AlertType
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
OpenAIFileObject,
|
||||
ResponsesAPIResponse,
|
||||
)
|
||||
from litellm.types.mcp import (
|
||||
MCPAuthType,
|
||||
@@ -41,8 +39,6 @@ from litellm.types.utils import (
|
||||
EmbeddingResponse,
|
||||
GenericBudgetConfigType,
|
||||
ImageResponse,
|
||||
LiteLLMBatch,
|
||||
LiteLLMFineTuningJob,
|
||||
LiteLLMPydanticObjectBase,
|
||||
ModelResponse,
|
||||
ProviderField,
|
||||
@@ -1014,12 +1010,7 @@ class LiteLLM_ObjectPermissionBase(LiteLLMPydanticObjectBase):
|
||||
search_tools: Optional[List[str]] = None
|
||||
|
||||
|
||||
class BudgetLimitEntry(LiteLLMPydanticObjectBase):
|
||||
"""A single budget window with its own limit and independent reset schedule."""
|
||||
|
||||
budget_duration: str # e.g. "24h", "7d", "30d"
|
||||
max_budget: float # max spend in USD for this window
|
||||
reset_at: Optional[datetime] = None # populated at creation/reset time
|
||||
from litellm.models.team import BudgetLimitEntry as BudgetLimitEntry # noqa: E402
|
||||
|
||||
|
||||
class GenerateRequestBase(LiteLLMPydanticObjectBase):
|
||||
@@ -1217,40 +1208,10 @@ class KeyRequest(LiteLLMPydanticObjectBase):
|
||||
return values
|
||||
|
||||
|
||||
class LiteLLM_ModelTable(LiteLLMPydanticObjectBase):
|
||||
id: Optional[int] = None
|
||||
model_aliases: Optional[Union[str, dict]] = None # json dump the dict
|
||||
created_by: str
|
||||
updated_by: str
|
||||
team: Optional["LiteLLM_TeamTable"] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class LiteLLM_ProxyModelTable(LiteLLMPydanticObjectBase):
|
||||
model_id: str
|
||||
model_name: str
|
||||
litellm_params: dict
|
||||
model_info: dict
|
||||
created_at: Optional[datetime] = None
|
||||
created_by: str
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_by: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_potential_json_str(cls, values):
|
||||
if isinstance(values.get("litellm_params"), str):
|
||||
try:
|
||||
values["litellm_params"] = json.loads(values["litellm_params"])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
if isinstance(values.get("model_info"), str):
|
||||
try:
|
||||
values["model_info"] = json.loads(values["model_info"])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return values
|
||||
from litellm.models.model import ( # noqa: E402
|
||||
LiteLLM_ProxyModelTable as LiteLLM_ProxyModelTable,
|
||||
)
|
||||
from litellm.models.team import LiteLLM_ModelTable as LiteLLM_ModelTable # noqa: E402
|
||||
|
||||
|
||||
# MCP Types
|
||||
@@ -1265,32 +1226,12 @@ class MCPApprovalStatus(str, enum.Enum):
|
||||
rejected = "rejected"
|
||||
|
||||
|
||||
class MCPEnvVarScope(str, enum.Enum):
|
||||
"""Scope for an MCP server environment variable.
|
||||
|
||||
- ``global``: value is provided by the admin and used for all users.
|
||||
- ``user``: each user must provide their own value via the per-user
|
||||
env-var endpoint. The admin-supplied ``value`` is treated as a
|
||||
placeholder/hint and is not used at request time.
|
||||
"""
|
||||
|
||||
global_ = "global"
|
||||
user = "user"
|
||||
|
||||
|
||||
class MCPEnvVar(LiteLLMPydanticObjectBase):
|
||||
"""One environment variable for an MCP server.
|
||||
|
||||
Variables can be interpolated into ``static_headers`` using ``${NAME}``
|
||||
syntax. ``scope=global`` values are stored on the server. ``scope=user``
|
||||
values are stored per-user in ``LiteLLM_MCPUserEnvVars`` and supplied by
|
||||
each user.
|
||||
"""
|
||||
|
||||
name: str
|
||||
value: str = ""
|
||||
scope: MCPEnvVarScope = MCPEnvVarScope.global_
|
||||
description: Optional[str] = None
|
||||
from litellm.models.mcp_server import ( # noqa: E402
|
||||
MCPEnvVar as MCPEnvVar,
|
||||
)
|
||||
from litellm.models.mcp_server import ( # noqa: E402
|
||||
MCPEnvVarScope as MCPEnvVarScope,
|
||||
)
|
||||
|
||||
|
||||
# MCP Proxy Request Types
|
||||
@@ -1443,66 +1384,9 @@ class UpdateMCPServerRequest(LiteLLMPydanticObjectBase):
|
||||
return values
|
||||
|
||||
|
||||
class LiteLLM_MCPServerTable(LiteLLMPydanticObjectBase):
|
||||
"""Represents a LiteLLM_MCPServerTable record"""
|
||||
|
||||
server_id: str
|
||||
server_name: Optional[str] = None
|
||||
alias: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
spec_path: Optional[str] = None
|
||||
transport: MCPTransportType
|
||||
auth_type: Optional[MCPAuthType] = None
|
||||
credentials: Optional[MCPCredentials] = None
|
||||
instructions: Optional[str] = None
|
||||
created_at: Optional[datetime] = None
|
||||
created_by: Optional[str] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_by: Optional[str] = None
|
||||
teams: List[Dict[str, Optional[str]]] = Field(default_factory=list)
|
||||
mcp_access_groups: List[str] = Field(default_factory=list)
|
||||
allowed_tools: List[str] = Field(default_factory=list)
|
||||
tool_name_to_display_name: Optional[Dict[str, str]] = None
|
||||
tool_name_to_description: Optional[Dict[str, str]] = None
|
||||
extra_headers: List[str] = Field(default_factory=list)
|
||||
mcp_info: Optional[MCPInfo] = None
|
||||
static_headers: Optional[Dict[str, str]] = None
|
||||
env_vars: Optional[List[MCPEnvVar]] = None
|
||||
# Health check status
|
||||
status: Optional[Literal["healthy", "unhealthy", "unknown"]] = Field(
|
||||
default="unknown",
|
||||
description="Health status: 'healthy', 'unhealthy', 'unknown'",
|
||||
)
|
||||
last_health_check: Optional[datetime] = None
|
||||
health_check_error: Optional[str] = None
|
||||
# Stdio-specific fields
|
||||
command: Optional[str] = None
|
||||
args: List[str] = Field(default_factory=list)
|
||||
env: Dict[str, str] = Field(default_factory=dict)
|
||||
authorization_url: Optional[str] = None
|
||||
token_url: Optional[str] = None
|
||||
registration_url: Optional[str] = None
|
||||
oauth2_flow: Optional[Literal["client_credentials", "authorization_code"]] = None
|
||||
allow_all_keys: bool = False
|
||||
available_on_public_internet: bool = True
|
||||
delegate_auth_to_upstream: bool = False
|
||||
oauth_passthrough: bool = False
|
||||
is_byok: bool = False
|
||||
byok_description: List[str] = Field(default_factory=list)
|
||||
byok_api_key_help_url: Optional[str] = None
|
||||
has_user_credential: Optional[bool] = None
|
||||
source_url: Optional[str] = None
|
||||
timeout: Optional[float] = None
|
||||
# BYOM submission fields
|
||||
approval_status: Optional[str] = Field(
|
||||
default="active",
|
||||
description="Approval status: 'pending_review', 'active', 'rejected'",
|
||||
)
|
||||
submitted_by: Optional[str] = None
|
||||
submitted_at: Optional[datetime] = None
|
||||
reviewed_at: Optional[datetime] = None
|
||||
review_notes: Optional[str] = None
|
||||
from litellm.models.mcp_server import ( # noqa: E402
|
||||
LiteLLM_MCPServerTable as LiteLLM_MCPServerTable,
|
||||
)
|
||||
|
||||
|
||||
class MakeMCPServersPublicRequest(LiteLLMPydanticObjectBase):
|
||||
@@ -1622,23 +1506,9 @@ class UpdateSkillRequest(LiteLLMPydanticObjectBase):
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class LiteLLM_SkillsTable(LiteLLMPydanticObjectBase):
|
||||
"""Represents a LiteLLM_SkillsTable record"""
|
||||
|
||||
skill_id: str
|
||||
display_title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
instructions: Optional[str] = None
|
||||
source: str = "custom"
|
||||
latest_version: Optional[str] = None
|
||||
file_content: Optional[bytes] = None # Binary content of skill files (zip)
|
||||
file_name: Optional[str] = None # Original filename
|
||||
file_type: Optional[str] = None # MIME type
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
created_at: Optional[datetime] = None
|
||||
created_by: Optional[str] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_by: Optional[str] = None
|
||||
from litellm.models.skills import ( # noqa: E402
|
||||
LiteLLM_SkillsTable as LiteLLM_SkillsTable,
|
||||
)
|
||||
|
||||
|
||||
class ListSkillsRequest(LiteLLMPydanticObjectBase):
|
||||
@@ -1839,33 +1709,8 @@ class DeleteCustomerRequest(LiteLLMPydanticObjectBase):
|
||||
user_ids: List[str]
|
||||
|
||||
|
||||
class MemberBase(LiteLLMPydanticObjectBase):
|
||||
user_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The unique ID of the user to add. Either user_id or user_email must be provided",
|
||||
)
|
||||
user_email: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The email address of the user to add. Either user_id or user_email must be provided",
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_user_info(cls, values):
|
||||
if not isinstance(values, dict):
|
||||
raise ValueError("input needs to be a dictionary")
|
||||
if values.get("user_id") is None and values.get("user_email") is None:
|
||||
raise ValueError("Either user id or user email must be provided")
|
||||
return values
|
||||
|
||||
|
||||
class Member(MemberBase):
|
||||
role: Literal[
|
||||
"admin",
|
||||
"user",
|
||||
] = Field(
|
||||
description="The role of the user within the team. 'admin' users can manage team settings and members, 'user' is a regular team member"
|
||||
)
|
||||
from litellm.models.team import Member as Member # noqa: E402
|
||||
from litellm.models.team import MemberBase as MemberBase # noqa: E402
|
||||
|
||||
|
||||
class OrgMember(MemberBase):
|
||||
@@ -1876,33 +1721,7 @@ class OrgMember(MemberBase):
|
||||
]
|
||||
|
||||
|
||||
class TeamBase(LiteLLMPydanticObjectBase):
|
||||
team_alias: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
organization_id: Optional[str] = None
|
||||
admins: list = []
|
||||
members: list = []
|
||||
members_with_roles: List[Member] = []
|
||||
team_member_permissions: Optional[List[str]] = None
|
||||
metadata: Optional[dict] = None
|
||||
tpm_limit: Optional[int] = None
|
||||
rpm_limit: Optional[int] = None
|
||||
|
||||
# Budget fields
|
||||
max_budget: Optional[float] = None
|
||||
soft_budget: Optional[float] = None
|
||||
budget_duration: Optional[str] = None
|
||||
budget_limits: Optional[List[BudgetLimitEntry]] = (
|
||||
None # multiple concurrent budget windows
|
||||
)
|
||||
|
||||
models: list = []
|
||||
blocked: bool = False
|
||||
router_settings: Optional[dict] = None
|
||||
access_group_ids: Optional[List[str]] = None
|
||||
default_team_member_models: Optional[List[str]] = (
|
||||
None # default allowed_models seeded onto new team members
|
||||
)
|
||||
from litellm.models.team import TeamBase as TeamBase # noqa: E402
|
||||
|
||||
|
||||
class NewTeamRequest(TeamBase):
|
||||
@@ -2100,147 +1919,31 @@ class TeamCallbackMetadata(LiteLLMPydanticObjectBase):
|
||||
return values
|
||||
|
||||
|
||||
class LiteLLM_ObjectPermissionTable(LiteLLMPydanticObjectBase):
|
||||
"""Represents a LiteLLM_ObjectPermissionTable record"""
|
||||
|
||||
object_permission_id: str
|
||||
mcp_servers: Optional[List[str]] = []
|
||||
mcp_access_groups: Optional[List[str]] = []
|
||||
mcp_tool_permissions: Optional[Dict[str, List[str]]] = None
|
||||
"""
|
||||
Mapping - server_id -> list of tools
|
||||
|
||||
Enforces allowed tools for a specific key/team/organization
|
||||
{
|
||||
"1234567890": ["tool_name_1", "tool_name_2"]
|
||||
}
|
||||
"""
|
||||
|
||||
vector_stores: Optional[List[str]] = []
|
||||
agents: Optional[List[str]] = []
|
||||
agent_access_groups: Optional[List[str]] = []
|
||||
mcp_toolsets: Optional[List[str]] = None
|
||||
blocked_tools: Optional[List[str]] = []
|
||||
search_tools: Optional[List[str]] = []
|
||||
|
||||
|
||||
class LiteLLM_TeamTable(TeamBase):
|
||||
team_id: str # type: ignore
|
||||
spend: Optional[float] = None
|
||||
max_parallel_requests: Optional[int] = None
|
||||
budget_duration: Optional[str] = None
|
||||
budget_reset_at: Optional[datetime] = None
|
||||
model_id: Optional[int] = None
|
||||
litellm_model_table: Optional[LiteLLM_ModelTable] = None
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
created_at: Optional[datetime] = None
|
||||
|
||||
#########################################################
|
||||
# Object Permission - MCP, Vector Stores etc.
|
||||
#########################################################
|
||||
object_permission_id: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_model_info(cls, values):
|
||||
dict_fields = [
|
||||
"metadata",
|
||||
"aliases",
|
||||
"config",
|
||||
"permissions",
|
||||
"model_max_budget",
|
||||
"model_aliases",
|
||||
"router_settings",
|
||||
"budget_limits",
|
||||
]
|
||||
|
||||
if isinstance(values, BaseModel):
|
||||
values = values.model_dump()
|
||||
|
||||
if (
|
||||
isinstance(values.get("members_with_roles"), dict)
|
||||
and not values["members_with_roles"]
|
||||
):
|
||||
values["members_with_roles"] = []
|
||||
|
||||
for field in dict_fields:
|
||||
value = values.get(field)
|
||||
if value is not None and isinstance(value, str):
|
||||
try:
|
||||
values[field] = json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Field {field} should be a valid dictionary")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class LiteLLM_TeamTableCachedObj(LiteLLM_TeamTable):
|
||||
last_refreshed_at: Optional[float] = None
|
||||
|
||||
|
||||
class LiteLLM_DeletedTeamTable(LiteLLM_TeamTable):
|
||||
"""
|
||||
Recording of deleted teams for audit purposes. Mirrors LiteLLM_TeamTable
|
||||
plus metadata captured at deletion time.
|
||||
"""
|
||||
|
||||
id: Optional[str] = None
|
||||
deleted_at: Optional[datetime] = None
|
||||
deleted_by: Optional[str] = None
|
||||
deleted_by_api_key: Optional[str] = None
|
||||
litellm_changed_by: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
from litellm.models.object_permission import ( # noqa: E402
|
||||
LiteLLM_ObjectPermissionTable as LiteLLM_ObjectPermissionTable,
|
||||
)
|
||||
from litellm.models.team import ( # noqa: E402
|
||||
LiteLLM_DeletedTeamTable as LiteLLM_DeletedTeamTable,
|
||||
)
|
||||
from litellm.models.team import LiteLLM_TeamTable as LiteLLM_TeamTable # noqa: E402
|
||||
from litellm.models.team import ( # noqa: E402
|
||||
LiteLLM_TeamTableCachedObj as LiteLLM_TeamTableCachedObj,
|
||||
)
|
||||
|
||||
|
||||
class TeamRequest(LiteLLMPydanticObjectBase):
|
||||
teams: List[str]
|
||||
|
||||
|
||||
class LiteLLM_BudgetTable(LiteLLMPydanticObjectBase):
|
||||
"""Represents user-controllable params for a LiteLLM_BudgetTable record.
|
||||
|
||||
Budget-write paths use `model_fields.keys()` on this class as an allowlist
|
||||
for user input. Keep server-managed fields (e.g. `budget_reset_at`) on
|
||||
`LiteLLM_BudgetTableFull` so they aren't user-settable.
|
||||
"""
|
||||
|
||||
budget_id: Optional[str] = None
|
||||
soft_budget: Optional[float] = None
|
||||
max_budget: Optional[float] = None
|
||||
max_parallel_requests: Optional[int] = None
|
||||
tpm_limit: Optional[int] = None
|
||||
rpm_limit: Optional[int] = None
|
||||
model_max_budget: Optional[dict] = None
|
||||
budget_duration: Optional[str] = None
|
||||
allowed_models: Optional[List[str]] = (
|
||||
None # per-member model scope; empty = inherit team models
|
||||
)
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class LiteLLM_BudgetTableFull(LiteLLM_BudgetTable):
|
||||
"""LiteLLM_BudgetTable + server-managed fields returned on API responses."""
|
||||
|
||||
budget_reset_at: Optional[datetime] = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable):
|
||||
"""
|
||||
Used to track spend of a user_id within a team_id
|
||||
"""
|
||||
|
||||
spend: Optional[float] = None
|
||||
user_id: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
budget_id: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
from litellm.models.budget import ( # noqa: E402
|
||||
LiteLLM_BudgetTable as LiteLLM_BudgetTable,
|
||||
)
|
||||
from litellm.models.budget import ( # noqa: E402
|
||||
LiteLLM_BudgetTableFull as LiteLLM_BudgetTableFull,
|
||||
)
|
||||
from litellm.models.budget import ( # noqa: E402
|
||||
LiteLLM_TeamMemberTable as LiteLLM_TeamMemberTable,
|
||||
)
|
||||
|
||||
|
||||
class NewOrganizationRequest(LiteLLM_BudgetTable):
|
||||
@@ -2637,66 +2340,12 @@ class ConfigYAML(LiteLLMPydanticObjectBase):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class LiteLLM_VerificationToken(LiteLLMPydanticObjectBase):
|
||||
token: Optional[str] = None
|
||||
key_name: Optional[str] = None
|
||||
key_alias: Optional[str] = None
|
||||
spend: float = 0.0
|
||||
max_budget: Optional[float] = None
|
||||
expires: Optional[Union[str, datetime]] = None
|
||||
models: List = []
|
||||
aliases: Dict = {}
|
||||
config: Dict = {}
|
||||
user_id: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
agent_id: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
max_parallel_requests: Optional[int] = None
|
||||
metadata: Dict = {}
|
||||
tpm_limit: Optional[int] = None
|
||||
rpm_limit: Optional[int] = None
|
||||
budget_duration: Optional[str] = None
|
||||
budget_reset_at: Optional[datetime] = None
|
||||
allowed_cache_controls: Optional[list] = []
|
||||
allowed_routes: Optional[list] = []
|
||||
permissions: Dict = {}
|
||||
model_spend: Dict = {}
|
||||
model_max_budget: Dict = {}
|
||||
soft_budget_cooldown: bool = False
|
||||
blocked: Optional[bool] = None
|
||||
litellm_budget_table: Optional[dict] = None
|
||||
org_id: Optional[str] = None # org id for a given key
|
||||
created_at: Optional[datetime] = None
|
||||
created_by: Optional[str] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_by: Optional[str] = None
|
||||
last_active: Optional[datetime] = None
|
||||
object_permission_id: Optional[str] = None
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
|
||||
access_group_ids: Optional[List[str]] = None
|
||||
rotation_count: Optional[int] = 0 # Number of times key has been rotated
|
||||
auto_rotate: Optional[bool] = False # Whether this key should be auto-rotated
|
||||
rotation_interval: Optional[str] = None # How often to rotate (e.g., "30d", "90d")
|
||||
last_rotation_at: Optional[datetime] = None # When this key was last rotated
|
||||
key_rotation_at: Optional[datetime] = None # When this key should next be rotated
|
||||
router_settings: Optional[dict] = None
|
||||
budget_limits: Optional[List[dict]] = None # multiple concurrent budget windows
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class LiteLLM_DeletedVerificationToken(LiteLLM_VerificationToken):
|
||||
"""
|
||||
Recording of deleted keys for audit purposes. Mirrors LiteLLM_VerificationToken
|
||||
plus metadata captured at deletion time.
|
||||
"""
|
||||
|
||||
id: Optional[str] = None
|
||||
deleted_at: Optional[datetime] = None
|
||||
deleted_by: Optional[str] = None
|
||||
deleted_by_api_key: Optional[str] = None
|
||||
litellm_changed_by: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
from litellm.models.verification_token import ( # noqa: E402
|
||||
LiteLLM_DeletedVerificationToken as LiteLLM_DeletedVerificationToken,
|
||||
)
|
||||
from litellm.models.verification_token import ( # noqa: E402
|
||||
LiteLLM_VerificationToken as LiteLLM_VerificationToken,
|
||||
)
|
||||
|
||||
|
||||
class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
|
||||
@@ -2935,39 +2584,10 @@ class UserInfoV2Response(LiteLLMPydanticObjectBase):
|
||||
teams: List[str] = [] # Just team IDs, not full team objects
|
||||
|
||||
|
||||
class LiteLLM_Config(LiteLLMPydanticObjectBase):
|
||||
param_name: str
|
||||
param_value: Dict
|
||||
|
||||
|
||||
class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase):
|
||||
"""
|
||||
This is the table that track what organizations a user belongs to and users spend within the organization
|
||||
"""
|
||||
|
||||
user_id: str
|
||||
organization_id: str
|
||||
user_role: Optional[str] = None
|
||||
spend: float = 0.0
|
||||
budget_id: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
user: Optional[Any] = (
|
||||
None # You might want to replace 'Any' with a more specific type if available
|
||||
)
|
||||
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
||||
user_email: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@model_validator(mode="after")
|
||||
def populate_user_email(self) -> "LiteLLM_OrganizationMembershipTable":
|
||||
if self.user_email is None and self.user is not None:
|
||||
if isinstance(self.user, dict):
|
||||
self.user_email = self.user.get("user_email")
|
||||
else:
|
||||
self.user_email = getattr(self.user, "user_email", None)
|
||||
return self
|
||||
from litellm.models.config import LiteLLM_Config as LiteLLM_Config # noqa: E402
|
||||
from litellm.models.organization_membership import ( # noqa: E402
|
||||
LiteLLM_OrganizationMembershipTable as LiteLLM_OrganizationMembershipTable,
|
||||
)
|
||||
|
||||
|
||||
class LiteLLM_OrganizationTableUpdate(LiteLLM_BudgetTable):
|
||||
@@ -2997,61 +2617,10 @@ class LiteLLM_OrganizationTableUpdate(LiteLLM_BudgetTable):
|
||||
return values
|
||||
|
||||
|
||||
class LiteLLM_UserTable(LiteLLMPydanticObjectBase):
|
||||
user_id: str
|
||||
max_budget: Optional[float] = None
|
||||
spend: float = 0.0
|
||||
model_max_budget: Optional[Dict] = {}
|
||||
model_spend: Optional[Dict] = {}
|
||||
user_email: Optional[str] = None
|
||||
user_alias: Optional[str] = None
|
||||
models: list = []
|
||||
tpm_limit: Optional[int] = None
|
||||
rpm_limit: Optional[int] = None
|
||||
user_role: Optional[str] = None
|
||||
organization_memberships: Optional[List[LiteLLM_OrganizationMembershipTable]] = None
|
||||
teams: List[str] = []
|
||||
sso_user_id: Optional[str] = None
|
||||
budget_duration: Optional[str] = None
|
||||
budget_reset_at: Optional[datetime] = None
|
||||
metadata: Optional[dict] = None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_model_info(cls, values):
|
||||
if values.get("spend") is None:
|
||||
values.update({"spend": 0.0})
|
||||
if values.get("models") is None:
|
||||
values.update({"models": []})
|
||||
if values.get("teams") is None:
|
||||
values.update({"teams": []})
|
||||
return values
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class LiteLLM_OrganizationTable(LiteLLMPydanticObjectBase):
|
||||
"""Represents user-controllable params for a LiteLLM_OrganizationTable record"""
|
||||
|
||||
organization_id: Optional[str] = None
|
||||
organization_alias: Optional[str] = None
|
||||
budget_id: str
|
||||
spend: float = 0.0
|
||||
metadata: Optional[dict] = None
|
||||
models: List[str]
|
||||
created_by: str
|
||||
updated_by: str
|
||||
users: Optional[List[LiteLLM_UserTable]] = None
|
||||
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
||||
|
||||
#########################################################
|
||||
# Object Permission - MCP, Vector Stores etc.
|
||||
#########################################################
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
|
||||
object_permission_id: Optional[str] = None
|
||||
from litellm.models.organization import ( # noqa: E402
|
||||
LiteLLM_OrganizationTable as LiteLLM_OrganizationTable,
|
||||
)
|
||||
from litellm.models.user import LiteLLM_UserTable as LiteLLM_UserTable # noqa: E402
|
||||
|
||||
|
||||
class LiteLLM_OrganizationTableWithMembers(LiteLLM_OrganizationTable):
|
||||
@@ -3160,28 +2729,9 @@ class DeleteProjectRequest(LiteLLMPydanticObjectBase):
|
||||
project_ids: List[str]
|
||||
|
||||
|
||||
class LiteLLM_ProjectTable(LiteLLMPydanticObjectBase):
|
||||
"""Database model representation for project"""
|
||||
|
||||
project_id: str
|
||||
project_alias: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
budget_id: Optional[str] = None
|
||||
metadata: Optional[dict] = None
|
||||
models: List[str] = []
|
||||
spend: float = 0.0
|
||||
model_spend: Optional[dict] = None
|
||||
model_rpm_limit: Optional[dict] = None
|
||||
model_tpm_limit: Optional[dict] = None
|
||||
blocked: bool = False
|
||||
object_permission_id: Optional[str] = None
|
||||
created_by: str
|
||||
updated_by: str
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
|
||||
from litellm.models.project import ( # noqa: E402
|
||||
LiteLLM_ProjectTable as LiteLLM_ProjectTable,
|
||||
)
|
||||
|
||||
|
||||
class NewProjectResponse(LiteLLM_ProjectTable):
|
||||
@@ -3207,101 +2757,19 @@ class LiteLLM_UserTableWithKeyCount(LiteLLM_UserTable):
|
||||
key_count: int = 0
|
||||
|
||||
|
||||
class LiteLLM_EndUserTable(LiteLLMPydanticObjectBase):
|
||||
user_id: str
|
||||
blocked: bool
|
||||
alias: Optional[str] = None
|
||||
spend: float = 0.0
|
||||
allowed_model_region: Optional[AllowedModelRegion] = None
|
||||
default_model: Optional[str] = None
|
||||
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
||||
object_permission_id: Optional[str] = None
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_model_info(cls, values):
|
||||
if values.get("spend") is None:
|
||||
values.update({"spend": 0.0})
|
||||
return values
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class LiteLLM_TagTable(LiteLLMPydanticObjectBase):
|
||||
tag_name: str
|
||||
description: Optional[str] = None
|
||||
models: List[str] = []
|
||||
model_info: Optional[dict] = None
|
||||
spend: float = 0.0
|
||||
budget_id: Optional[str] = None
|
||||
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
||||
created_at: Optional[datetime] = None
|
||||
created_by: Optional[str] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_model_info(cls, values):
|
||||
if values.get("spend") is None:
|
||||
values.update({"spend": 0.0})
|
||||
if values.get("models") is None:
|
||||
values.update({"models": []})
|
||||
return values
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class LiteLLM_AccessGroupTable(LiteLLMPydanticObjectBase):
|
||||
access_group_id: str
|
||||
access_group_name: str
|
||||
description: Optional[str] = None
|
||||
access_model_names: List[str] = []
|
||||
access_mcp_server_ids: List[str] = []
|
||||
access_agent_ids: List[str] = []
|
||||
assigned_team_ids: List[str] = []
|
||||
assigned_key_ids: List[str] = []
|
||||
created_at: Optional[datetime] = None
|
||||
created_by: Optional[str] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_by: Optional[str] = None
|
||||
|
||||
|
||||
class LiteLLM_SpendLogs(LiteLLMPydanticObjectBase):
|
||||
request_id: str
|
||||
api_key: str
|
||||
model: Optional[str] = ""
|
||||
api_base: Optional[str] = ""
|
||||
call_type: str
|
||||
spend: Optional[float] = 0.0
|
||||
total_tokens: Optional[int] = 0
|
||||
prompt_tokens: Optional[int] = 0
|
||||
completion_tokens: Optional[int] = 0
|
||||
startTime: Union[str, datetime, None]
|
||||
endTime: Union[str, datetime, None]
|
||||
user: Optional[str] = ""
|
||||
metadata: Optional[Json] = {}
|
||||
cache_hit: Optional[str] = "False"
|
||||
cache_key: Optional[str] = None
|
||||
request_tags: Optional[Json] = None
|
||||
requester_ip_address: Optional[str] = None
|
||||
messages: Optional[Union[str, list, dict]]
|
||||
response: Optional[Union[str, list, dict]]
|
||||
|
||||
|
||||
class LiteLLM_ErrorLogs(LiteLLMPydanticObjectBase):
|
||||
request_id: Optional[str] = str(uuid.uuid4())
|
||||
api_base: Optional[str] = ""
|
||||
model_group: Optional[str] = ""
|
||||
litellm_model_name: Optional[str] = ""
|
||||
model_id: Optional[str] = ""
|
||||
request_kwargs: Optional[dict] = {}
|
||||
exception_type: Optional[str] = ""
|
||||
status_code: Optional[str] = ""
|
||||
exception_string: Optional[str] = ""
|
||||
startTime: Union[str, datetime, None]
|
||||
endTime: Union[str, datetime, None]
|
||||
|
||||
from litellm.models.access_group import ( # noqa: E402
|
||||
LiteLLM_AccessGroupTable as LiteLLM_AccessGroupTable,
|
||||
)
|
||||
from litellm.models.end_user import ( # noqa: E402
|
||||
LiteLLM_EndUserTable as LiteLLM_EndUserTable,
|
||||
)
|
||||
from litellm.models.spend_logs import ( # noqa: E402
|
||||
LiteLLM_ErrorLogs as LiteLLM_ErrorLogs,
|
||||
)
|
||||
from litellm.models.spend_logs import ( # noqa: E402
|
||||
LiteLLM_SpendLogs as LiteLLM_SpendLogs,
|
||||
)
|
||||
from litellm.models.tag import LiteLLM_TagTable as LiteLLM_TagTable # noqa: E402
|
||||
|
||||
AUDIT_ACTIONS = Literal[
|
||||
"created", "updated", "deleted", "blocked", "unblocked", "rotated"
|
||||
@@ -3982,29 +3450,9 @@ class CreatePassThroughEndpoint(LiteLLMPydanticObjectBase):
|
||||
headers: dict
|
||||
|
||||
|
||||
class LiteLLM_TeamMembership(LiteLLMPydanticObjectBase):
|
||||
user_id: str
|
||||
team_id: str
|
||||
budget_id: Optional[str] = None
|
||||
spend: Optional[float] = 0.0
|
||||
total_spend: Optional[float] = 0.0
|
||||
# Union so Pydantic picks Full when data has server-managed fields
|
||||
# (/team/info) and Base when callers/tests construct with only
|
||||
# user-settable fields.
|
||||
litellm_budget_table: Optional[
|
||||
Union[LiteLLM_BudgetTableFull, LiteLLM_BudgetTable]
|
||||
] = None
|
||||
|
||||
def safe_get_team_member_rpm_limit(self) -> Optional[int]:
|
||||
if self.litellm_budget_table is not None:
|
||||
return self.litellm_budget_table.rpm_limit
|
||||
return None
|
||||
|
||||
def safe_get_team_member_tpm_limit(self) -> Optional[int]:
|
||||
if self.litellm_budget_table is not None:
|
||||
return self.litellm_budget_table.tpm_limit
|
||||
return None
|
||||
|
||||
from litellm.models.team_membership import ( # noqa: E402
|
||||
LiteLLM_TeamMembership as LiteLLM_TeamMembership,
|
||||
)
|
||||
|
||||
#### Organization / Team Member Requests ####
|
||||
|
||||
@@ -4922,39 +4370,18 @@ class ToolDiscoveryQueueItem(TypedDict, total=False):
|
||||
user_agent: Optional[str] # HTTP User-Agent of the caller
|
||||
|
||||
|
||||
class LiteLLM_ManagedFileTable(LiteLLMPydanticObjectBase):
|
||||
unified_file_id: str
|
||||
file_object: Optional[OpenAIFileObject] = None
|
||||
model_mappings: Dict[str, str]
|
||||
flat_model_file_ids: List[str]
|
||||
created_by: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
updated_by: Optional[str] = None
|
||||
storage_backend: Optional[str] = None
|
||||
storage_url: Optional[str] = None
|
||||
|
||||
|
||||
class LiteLLM_ManagedObjectTable(LiteLLMPydanticObjectBase):
|
||||
unified_object_id: str
|
||||
model_object_id: str
|
||||
file_purpose: Literal["batch", "fine-tune", "response", "container"]
|
||||
file_object: Union[LiteLLMBatch, LiteLLMFineTuningJob, ResponsesAPIResponse]
|
||||
created_by: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
|
||||
|
||||
class LiteLLM_ManagedVectorStoreTable(LiteLLMPydanticObjectBase):
|
||||
"""Table for managing vector stores with target_model_names support."""
|
||||
|
||||
unified_resource_id: str
|
||||
resource_object: Optional[Any] = None # VectorStoreCreateResponse
|
||||
model_mappings: Dict[str, str]
|
||||
flat_model_resource_ids: List[str]
|
||||
created_by: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
updated_by: Optional[str] = None
|
||||
storage_backend: Optional[str] = None
|
||||
storage_url: Optional[str] = None
|
||||
from litellm.models.managed_files import ( # noqa: E402
|
||||
LiteLLM_ManagedFileTable as LiteLLM_ManagedFileTable,
|
||||
)
|
||||
from litellm.models.managed_files import ( # noqa: E402
|
||||
LiteLLM_ManagedObjectTable as LiteLLM_ManagedObjectTable,
|
||||
)
|
||||
from litellm.models.managed_files import ( # noqa: E402
|
||||
LiteLLM_ManagedVectorStoresTable as LiteLLM_ManagedVectorStoresTable,
|
||||
)
|
||||
from litellm.models.managed_files import ( # noqa: E402
|
||||
LiteLLM_ManagedVectorStoreTable as LiteLLM_ManagedVectorStoreTable,
|
||||
)
|
||||
|
||||
|
||||
class EnterpriseLicenseData(TypedDict, total=False):
|
||||
@@ -4965,20 +4392,6 @@ class EnterpriseLicenseData(TypedDict, total=False):
|
||||
max_teams: int
|
||||
|
||||
|
||||
class LiteLLM_ManagedVectorStoresTable(LiteLLMPydanticObjectBase):
|
||||
vector_store_id: str
|
||||
custom_llm_provider: str
|
||||
vector_store_name: Optional[str]
|
||||
vector_store_description: Optional[str]
|
||||
vector_store_metadata: Optional[Dict[str, Any]]
|
||||
created_at: Optional[datetime]
|
||||
updated_at: Optional[datetime]
|
||||
litellm_credential_name: Optional[str]
|
||||
litellm_params: Optional[Dict[str, Any]]
|
||||
team_id: Optional[str]
|
||||
user_id: Optional[str]
|
||||
|
||||
|
||||
class ResponseLiteLLM_ManagedVectorStore(TypedDict, total=False):
|
||||
vector_store: LiteLLM_ManagedVectorStoresTable
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from litellm.proxy.management_helpers.object_permission_utils import (
|
||||
handle_update_object_permission_common,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.repositories.table_repositories import AgentsRepository
|
||||
from litellm.types.agents import AgentConfig, AgentResponse, PatchAgentRequest
|
||||
|
||||
|
||||
@@ -174,7 +175,7 @@ class AgentRegistry:
|
||||
create_data[rate_field] = _val
|
||||
|
||||
# Create agent in DB
|
||||
created_agent = await prisma_client.db.litellm_agentstable.create(
|
||||
created_agent = await AgentsRepository(prisma_client).table.create(
|
||||
data=create_data,
|
||||
include={"object_permission": True},
|
||||
)
|
||||
@@ -200,7 +201,7 @@ class AgentRegistry:
|
||||
Delete an agent from the database
|
||||
"""
|
||||
try:
|
||||
deleted_agent = await prisma_client.db.litellm_agentstable.delete(
|
||||
deleted_agent = await AgentsRepository(prisma_client).table.delete(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
return dict(deleted_agent)
|
||||
@@ -229,7 +230,7 @@ class AgentRegistry:
|
||||
The patched agent
|
||||
"""
|
||||
try:
|
||||
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
existing_agent = await AgentsRepository(prisma_client).table.find_unique(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
if existing_agent is not None:
|
||||
@@ -282,7 +283,7 @@ class AgentRegistry:
|
||||
if object_permission_id is not None:
|
||||
update_data["object_permission_id"] = object_permission_id
|
||||
# Patch agent in DB
|
||||
patched_agent = await prisma_client.db.litellm_agentstable.update(
|
||||
patched_agent = await AgentsRepository(prisma_client).table.update(
|
||||
where={"agent_id": agent_id},
|
||||
data={
|
||||
**update_data,
|
||||
@@ -368,9 +369,9 @@ class AgentRegistry:
|
||||
update_data[rate_field] = _val
|
||||
|
||||
if agent.get("object_permission") is not None:
|
||||
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
existing_agent = await AgentsRepository(
|
||||
prisma_client
|
||||
).table.find_unique(where={"agent_id": agent_id})
|
||||
existing_object_permission_id = (
|
||||
existing_agent.object_permission_id
|
||||
if existing_agent is not None
|
||||
@@ -386,7 +387,7 @@ class AgentRegistry:
|
||||
update_data["object_permission_id"] = object_permission_id
|
||||
|
||||
# Update agent in DB
|
||||
updated_agent = await prisma_client.db.litellm_agentstable.update(
|
||||
updated_agent = await AgentsRepository(prisma_client).table.update(
|
||||
where={"agent_id": agent_id},
|
||||
data=update_data,
|
||||
include={"object_permission": True},
|
||||
@@ -414,7 +415,7 @@ class AgentRegistry:
|
||||
Get all agents from the database
|
||||
"""
|
||||
try:
|
||||
agents_from_db = await prisma_client.db.litellm_agentstable.find_many(
|
||||
agents_from_db = await AgentsRepository(prisma_client).table.find_many(
|
||||
order={"created_at": "desc"},
|
||||
include={"object_permission": True},
|
||||
)
|
||||
|
||||
@@ -9,11 +9,12 @@ from typing import List, Optional, Set
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.proxy._types import (
|
||||
UI_TEAM_ID,
|
||||
LiteLLM_ObjectPermissionTable,
|
||||
LiteLLM_TeamTable,
|
||||
UI_TEAM_ID,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.repositories.table_repositories import AgentsRepository
|
||||
|
||||
|
||||
class AgentRequestHandler:
|
||||
@@ -298,7 +299,7 @@ class AgentRequestHandler:
|
||||
agent_ids: Set[str] = set()
|
||||
if access_groups and prisma_client is not None:
|
||||
try:
|
||||
agents = await prisma_client.db.litellm_agentstable.find_many(
|
||||
agents = await AgentsRepository(prisma_client).table.find_many(
|
||||
where={"agent_access_groups": {"hasSome": access_groups}}
|
||||
)
|
||||
for agent in agents:
|
||||
|
||||
@@ -17,6 +17,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import _get_masked_values
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
|
||||
from litellm.proxy.a2a.agent_card import merge_agent_card
|
||||
@@ -30,7 +31,6 @@ from litellm.types.agents import (
|
||||
MakeAgentsPublicRequest,
|
||||
PatchAgentRequest,
|
||||
)
|
||||
from litellm.litellm_core_utils.litellm_logging import _get_masked_values
|
||||
from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
from litellm.types.proxy.management_endpoints.common_daily_activity import (
|
||||
DailySpendMetadata,
|
||||
@@ -219,7 +219,7 @@ async def get_agents(
|
||||
if prisma_client is not None:
|
||||
agent_ids = [agent.agent_id for agent in returned_agents]
|
||||
if agent_ids:
|
||||
db_agents = await prisma_client.db.litellm_agentstable.find_many(
|
||||
db_agents = await AgentsRepository(prisma_client).table.find_many(
|
||||
where={"agent_id": {"in": agent_ids}},
|
||||
)
|
||||
spend_map = {a.agent_id: a.spend for a in db_agents}
|
||||
@@ -301,6 +301,7 @@ async def get_agents(
|
||||
from litellm.proxy.agent_endpoints.agent_registry import (
|
||||
global_agent_registry as AGENT_REGISTRY,
|
||||
)
|
||||
from litellm.repositories.table_repositories import AgentsRepository
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -471,7 +472,7 @@ async def get_agent_by_id(
|
||||
try:
|
||||
agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id)
|
||||
if agent is None:
|
||||
agent_row = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
agent_row = await AgentsRepository(prisma_client).table.find_unique(
|
||||
where={"agent_id": agent_id},
|
||||
include={"object_permission": True},
|
||||
)
|
||||
@@ -489,7 +490,7 @@ async def get_agent_by_id(
|
||||
agent = AgentResponse(**agent_dict) # type: ignore
|
||||
else:
|
||||
# Agent found in memory — refresh spend from DB
|
||||
db_row = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
db_row = await AgentsRepository(prisma_client).table.find_unique(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
if db_row is not None:
|
||||
@@ -570,7 +571,7 @@ async def update_agent(
|
||||
|
||||
try:
|
||||
# Check if agent exists
|
||||
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
existing_agent = await AgentsRepository(prisma_client).table.find_unique(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
if existing_agent is not None:
|
||||
@@ -678,7 +679,7 @@ async def patch_agent(
|
||||
|
||||
try:
|
||||
# Check if agent exists
|
||||
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
existing_agent = await AgentsRepository(prisma_client).table.find_unique(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
if existing_agent is not None:
|
||||
@@ -769,7 +770,7 @@ async def delete_agent(
|
||||
|
||||
try:
|
||||
# Check if agent exists
|
||||
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
existing_agent = await AgentsRepository(prisma_client).table.find_unique(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
if existing_agent is not None:
|
||||
@@ -859,7 +860,7 @@ async def make_agent_public(
|
||||
agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id)
|
||||
if agent is None:
|
||||
# check if agent exists in DB
|
||||
agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
agent = await AgentsRepository(prisma_client).table.find_unique(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
if agent is not None:
|
||||
@@ -982,7 +983,7 @@ async def make_agents_public(
|
||||
agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id)
|
||||
if agent is None:
|
||||
# check if agent exists in DB
|
||||
agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
agent = await AgentsRepository(prisma_client).table.find_unique(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
if agent is not None:
|
||||
@@ -1082,7 +1083,7 @@ async def get_agent_daily_activity(
|
||||
if user_api_key_dict.user_id is None:
|
||||
permitted_agent_ids = []
|
||||
else:
|
||||
owned_records = await prisma_client.db.litellm_agentstable.find_many(
|
||||
owned_records = await AgentsRepository(prisma_client).table.find_many(
|
||||
where={"created_by": user_api_key_dict.user_id}
|
||||
)
|
||||
permitted_agent_ids = [a.agent_id for a in owned_records]
|
||||
@@ -1118,7 +1119,7 @@ async def get_agent_daily_activity(
|
||||
if agent_ids_list:
|
||||
where_condition["agent_id"] = {"in": list(agent_ids_list)}
|
||||
|
||||
agent_records = await prisma_client.db.litellm_agentstable.find_many(
|
||||
agent_records = await AgentsRepository(prisma_client).table.find_many(
|
||||
where=where_condition
|
||||
)
|
||||
agent_metadata = {
|
||||
|
||||
+13
-12
@@ -26,6 +26,7 @@ from fastapi.responses import JSONResponse
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.repositories.table_repositories import ClaudeCodePluginRepository
|
||||
from litellm.types.proxy.claude_code_endpoints import (
|
||||
ListPluginsResponse,
|
||||
PluginListItem,
|
||||
@@ -71,7 +72,7 @@ async def get_marketplace():
|
||||
try:
|
||||
prisma_client = await _get_prisma_client()
|
||||
|
||||
plugins = await prisma_client.db.litellm_claudecodeplugintable.find_many(
|
||||
plugins = await ClaudeCodePluginRepository(prisma_client).table.find_many(
|
||||
where={"enabled": True}
|
||||
)
|
||||
|
||||
@@ -268,12 +269,12 @@ async def register_plugin(
|
||||
manifest["namespace"] = request.namespace
|
||||
|
||||
# Check if plugin exists
|
||||
existing = await prisma_client.db.litellm_claudecodeplugintable.find_unique(
|
||||
existing = await ClaudeCodePluginRepository(prisma_client).table.find_unique(
|
||||
where={"name": request.name}
|
||||
)
|
||||
|
||||
if existing:
|
||||
plugin = await prisma_client.db.litellm_claudecodeplugintable.update(
|
||||
plugin = await ClaudeCodePluginRepository(prisma_client).table.update(
|
||||
where={"name": request.name},
|
||||
data={
|
||||
"version": request.version,
|
||||
@@ -285,7 +286,7 @@ async def register_plugin(
|
||||
)
|
||||
action = "updated"
|
||||
else:
|
||||
plugin = await prisma_client.db.litellm_claudecodeplugintable.create(
|
||||
plugin = await ClaudeCodePluginRepository(prisma_client).table.create(
|
||||
data={
|
||||
"name": request.name,
|
||||
"version": request.version,
|
||||
@@ -348,7 +349,7 @@ async def list_plugins(
|
||||
prisma_client = await _get_prisma_client()
|
||||
|
||||
where = {"enabled": True} if enabled_only else {}
|
||||
plugins = await prisma_client.db.litellm_claudecodeplugintable.find_many(
|
||||
plugins = await ClaudeCodePluginRepository(prisma_client).table.find_many(
|
||||
where=where
|
||||
)
|
||||
|
||||
@@ -415,7 +416,7 @@ async def get_plugin(
|
||||
try:
|
||||
prisma_client = await _get_prisma_client()
|
||||
|
||||
plugin = await prisma_client.db.litellm_claudecodeplugintable.find_unique(
|
||||
plugin = await ClaudeCodePluginRepository(prisma_client).table.find_unique(
|
||||
where={"name": plugin_name}
|
||||
)
|
||||
|
||||
@@ -471,7 +472,7 @@ async def enable_plugin(
|
||||
try:
|
||||
prisma_client = await _get_prisma_client()
|
||||
|
||||
plugin = await prisma_client.db.litellm_claudecodeplugintable.find_unique(
|
||||
plugin = await ClaudeCodePluginRepository(prisma_client).table.find_unique(
|
||||
where={"name": plugin_name}
|
||||
)
|
||||
if not plugin:
|
||||
@@ -480,7 +481,7 @@ async def enable_plugin(
|
||||
detail={"error": f"Plugin '{plugin_name}' not found"},
|
||||
)
|
||||
|
||||
await prisma_client.db.litellm_claudecodeplugintable.update(
|
||||
await ClaudeCodePluginRepository(prisma_client).table.update(
|
||||
where={"name": plugin_name},
|
||||
data={"enabled": True, "updated_at": datetime.now(timezone.utc)},
|
||||
)
|
||||
@@ -516,7 +517,7 @@ async def disable_plugin(
|
||||
try:
|
||||
prisma_client = await _get_prisma_client()
|
||||
|
||||
plugin = await prisma_client.db.litellm_claudecodeplugintable.find_unique(
|
||||
plugin = await ClaudeCodePluginRepository(prisma_client).table.find_unique(
|
||||
where={"name": plugin_name}
|
||||
)
|
||||
if not plugin:
|
||||
@@ -525,7 +526,7 @@ async def disable_plugin(
|
||||
detail={"error": f"Plugin '{plugin_name}' not found"},
|
||||
)
|
||||
|
||||
await prisma_client.db.litellm_claudecodeplugintable.update(
|
||||
await ClaudeCodePluginRepository(prisma_client).table.update(
|
||||
where={"name": plugin_name},
|
||||
data={"enabled": False, "updated_at": datetime.now(timezone.utc)},
|
||||
)
|
||||
@@ -561,7 +562,7 @@ async def delete_plugin(
|
||||
try:
|
||||
prisma_client = await _get_prisma_client()
|
||||
|
||||
plugin = await prisma_client.db.litellm_claudecodeplugintable.find_unique(
|
||||
plugin = await ClaudeCodePluginRepository(prisma_client).table.find_unique(
|
||||
where={"name": plugin_name}
|
||||
)
|
||||
if not plugin:
|
||||
@@ -570,7 +571,7 @@ async def delete_plugin(
|
||||
detail={"error": f"Plugin '{plugin_name}' not found"},
|
||||
)
|
||||
|
||||
await prisma_client.db.litellm_claudecodeplugintable.delete(
|
||||
await ClaudeCodePluginRepository(prisma_client).table.delete(
|
||||
where={"name": plugin_name}
|
||||
)
|
||||
|
||||
|
||||
@@ -61,19 +61,33 @@ from litellm.proxy._types import (
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
from litellm.proxy.common_utils.cache_pydantic_utils import CacheCodec
|
||||
from litellm.proxy.common_utils.http_parsing_utils import (
|
||||
_safe_get_request_headers,
|
||||
_safe_get_request_query_params,
|
||||
)
|
||||
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
|
||||
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler
|
||||
from litellm.proxy.guardrails.tool_name_extraction import (
|
||||
TOOL_CAPABLE_CALL_TYPES,
|
||||
extract_request_tool_names,
|
||||
)
|
||||
from litellm.proxy.common_utils.cache_pydantic_utils import CacheCodec
|
||||
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
|
||||
from litellm.proxy.route_llm_request import route_request
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
|
||||
from litellm.repositories.budget_repository import BudgetRepository
|
||||
from litellm.repositories.object_permission_repository import ObjectPermissionRepository
|
||||
from litellm.repositories.organization_repository import OrganizationRepository
|
||||
from litellm.repositories.project_repository import ProjectRepository
|
||||
from litellm.repositories.table_repositories import (
|
||||
AccessGroupRepository,
|
||||
EndUserRepository,
|
||||
JWTKeyMappingRepository,
|
||||
ManagedVectorStoresRepository,
|
||||
TagRepository,
|
||||
TeamMembershipRepository,
|
||||
)
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.repositories.user_repository import UserRepository
|
||||
from litellm.router import Router
|
||||
from litellm.utils import get_utc_datetime
|
||||
|
||||
@@ -957,7 +971,7 @@ async def get_default_end_user_budget(
|
||||
|
||||
# Fetch from database
|
||||
try:
|
||||
budget_record = await prisma_client.db.litellm_budgettable.find_unique(
|
||||
budget_record = await BudgetRepository(prisma_client).table.find_unique(
|
||||
where={"budget_id": litellm.max_end_user_budget_id}
|
||||
)
|
||||
|
||||
@@ -1016,7 +1030,7 @@ async def get_team_member_default_budget(
|
||||
return LiteLLM_BudgetTable(**cached_budget)
|
||||
|
||||
try:
|
||||
budget_record = await prisma_client.db.litellm_budgettable.find_unique(
|
||||
budget_record = await BudgetRepository(prisma_client).table.find_unique(
|
||||
where={"budget_id": budget_id}
|
||||
)
|
||||
|
||||
@@ -1175,7 +1189,7 @@ async def get_end_user_object(
|
||||
|
||||
# Fetch from database
|
||||
try:
|
||||
response = await prisma_client.db.litellm_endusertable.find_unique(
|
||||
response = await EndUserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": end_user_id},
|
||||
include={"litellm_budget_table": True, "object_permission": True},
|
||||
)
|
||||
@@ -1375,7 +1389,7 @@ async def get_tag_objects_batch(
|
||||
# Batch fetch uncached tags from DB in one query
|
||||
if uncached_tags:
|
||||
try:
|
||||
db_tags = await prisma_client.db.litellm_tagtable.find_many(
|
||||
db_tags = await TagRepository(prisma_client).table.find_many(
|
||||
where={"tag_name": {"in": uncached_tags}},
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
@@ -1469,7 +1483,7 @@ async def get_team_membership(
|
||||
|
||||
# else, check db
|
||||
try:
|
||||
response = await prisma_client.db.litellm_teammembership.find_unique(
|
||||
response = await TeamMembershipRepository(prisma_client).table.find_unique(
|
||||
where={"user_id_team_id": {"user_id": user_id, "team_id": team_id}},
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
@@ -1622,7 +1636,7 @@ async def _get_fuzzy_user_object(
|
||||
|
||||
response = None
|
||||
if sso_user_id is not None:
|
||||
response = await prisma_client.db.litellm_usertable.find_unique(
|
||||
response = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"sso_user_id": sso_user_id},
|
||||
include={"organization_memberships": True},
|
||||
)
|
||||
@@ -1630,14 +1644,14 @@ async def _get_fuzzy_user_object(
|
||||
if response is None and user_email is not None:
|
||||
# Use case-insensitive query to handle emails with different casing
|
||||
# This matches the pattern used in _check_duplicate_user_email
|
||||
response = await prisma_client.db.litellm_usertable.find_first(
|
||||
response = await UserRepository(prisma_client).table.find_first(
|
||||
where={"user_email": {"equals": user_email, "mode": "insensitive"}},
|
||||
include={"organization_memberships": True},
|
||||
)
|
||||
|
||||
if response is not None and sso_user_id is not None: # update sso_user_id
|
||||
asyncio.create_task( # background task to update user with sso id
|
||||
prisma_client.db.litellm_usertable.update(
|
||||
UserRepository(prisma_client).table.update(
|
||||
where={"user_id": response.user_id},
|
||||
data={"sso_user_id": sso_user_id},
|
||||
)
|
||||
@@ -1687,7 +1701,7 @@ async def get_user_object(
|
||||
)
|
||||
|
||||
if should_check_db:
|
||||
response = await prisma_client.db.litellm_usertable.find_unique(
|
||||
response = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": user_id}, include={"organization_memberships": True}
|
||||
)
|
||||
|
||||
@@ -1711,7 +1725,7 @@ async def get_user_object(
|
||||
if litellm.default_internal_user_params is not None:
|
||||
new_user_params.update(litellm.default_internal_user_params)
|
||||
|
||||
response = await prisma_client.db.litellm_usertable.create(
|
||||
response = await UserRepository(prisma_client).table.create(
|
||||
data=new_user_params,
|
||||
include={"organization_memberships": True},
|
||||
)
|
||||
@@ -1860,7 +1874,7 @@ async def _delete_cache_key_object(
|
||||
async def _get_team_db_check(
|
||||
team_id: str, prisma_client: PrismaClient, team_id_upsert: Optional[bool] = None
|
||||
):
|
||||
response = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
response = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
|
||||
@@ -1882,7 +1896,7 @@ async def _get_team_db_check(
|
||||
|
||||
|
||||
async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient):
|
||||
return await prisma_client.db.litellm_teamtable.find_unique(
|
||||
return await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
|
||||
@@ -2111,7 +2125,7 @@ async def get_access_object(
|
||||
|
||||
# Not in cache - fetch from DB
|
||||
try:
|
||||
response = await prisma_client.db.litellm_accessgrouptable.find_unique(
|
||||
response = await AccessGroupRepository(prisma_client).table.find_unique(
|
||||
where={"access_group_id": access_group_id}
|
||||
)
|
||||
|
||||
@@ -2193,7 +2207,7 @@ async def get_team_object_by_alias(
|
||||
|
||||
# Query database by team_alias
|
||||
try:
|
||||
teams = await prisma_client.db.litellm_teamtable.find_many(
|
||||
teams = await TeamRepository(prisma_client).table.find_many(
|
||||
where={"team_alias": team_alias}
|
||||
)
|
||||
|
||||
@@ -2301,7 +2315,7 @@ async def get_org_object_by_alias(
|
||||
|
||||
# Query database by organization_alias
|
||||
try:
|
||||
orgs = await prisma_client.db.litellm_organizationtable.find_many(
|
||||
orgs = await OrganizationRepository(prisma_client).table.find_many(
|
||||
where={"organization_alias": org_alias}
|
||||
)
|
||||
|
||||
@@ -2526,7 +2540,7 @@ async def get_jwt_key_mapping_object(
|
||||
|
||||
Returns the hashed token (str) if a matching active mapping is found, else None.
|
||||
"""
|
||||
mapping = await prisma_client.db.litellm_jwtkeymapping.find_first(
|
||||
mapping = await JWTKeyMappingRepository(prisma_client).table.find_first(
|
||||
where={
|
||||
"jwt_claim_name": jwt_claim_name,
|
||||
"jwt_claim_value": jwt_claim_value,
|
||||
@@ -2659,7 +2673,7 @@ async def get_object_permission(
|
||||
|
||||
# else, check db
|
||||
try:
|
||||
response = await prisma_client.db.litellm_objectpermissiontable.find_unique(
|
||||
response = await ObjectPermissionRepository(prisma_client).table.find_unique(
|
||||
where={"object_permission_id": object_permission_id}
|
||||
)
|
||||
|
||||
@@ -2715,7 +2729,7 @@ async def get_managed_vector_store_rows_by_uuids(
|
||||
if not cache_misses:
|
||||
return result
|
||||
|
||||
rows = await prisma_client.db.litellm_managedvectorstorestable.find_many(
|
||||
rows = await ManagedVectorStoresRepository(prisma_client).table.find_many(
|
||||
where={"vector_store_id": {"in": cache_misses}},
|
||||
take=len(cache_misses),
|
||||
)
|
||||
@@ -2790,7 +2804,7 @@ async def get_org_object(
|
||||
if include_budget_table:
|
||||
query_kwargs["include"] = {"litellm_budget_table": True}
|
||||
|
||||
response = await prisma_client.db.litellm_organizationtable.find_unique(
|
||||
response = await OrganizationRepository(prisma_client).table.find_unique(
|
||||
**query_kwargs
|
||||
)
|
||||
|
||||
@@ -4180,7 +4194,7 @@ async def get_project_object(
|
||||
return deserialized_project
|
||||
|
||||
# Fetch from DB
|
||||
project_row = await prisma_client.db.litellm_projecttable.find_unique(
|
||||
project_row = await ProjectRepository(prisma_client).table.find_unique(
|
||||
where={"project_id": project_id},
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
@@ -4480,10 +4494,10 @@ async def vector_store_access_check(
|
||||
#########################################################
|
||||
# Check if the key can access the vector store
|
||||
if valid_token is not None and valid_token.object_permission_id is not None:
|
||||
key_object_permission = (
|
||||
await prisma_client.db.litellm_objectpermissiontable.find_unique(
|
||||
where={"object_permission_id": valid_token.object_permission_id},
|
||||
)
|
||||
key_object_permission = await ObjectPermissionRepository(
|
||||
prisma_client
|
||||
).table.find_unique(
|
||||
where={"object_permission_id": valid_token.object_permission_id},
|
||||
)
|
||||
if key_object_permission is not None:
|
||||
_can_object_call_vector_stores(
|
||||
@@ -4494,10 +4508,10 @@ async def vector_store_access_check(
|
||||
|
||||
# Check if the team can access the vector store
|
||||
if team_object is not None and team_object.object_permission_id is not None:
|
||||
team_object_permission = (
|
||||
await prisma_client.db.litellm_objectpermissiontable.find_unique(
|
||||
where={"object_permission_id": team_object.object_permission_id},
|
||||
)
|
||||
team_object_permission = await ObjectPermissionRepository(
|
||||
prisma_client
|
||||
).table.find_unique(
|
||||
where={"object_permission_id": team_object.object_permission_id},
|
||||
)
|
||||
if team_object_permission is not None:
|
||||
_can_object_call_vector_stores(
|
||||
|
||||
@@ -14,11 +14,11 @@ import os
|
||||
import re
|
||||
from typing import Any, List, Literal, Optional, Set, Tuple, Union, cast
|
||||
|
||||
import jwt
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from fastapi import HTTPException, status
|
||||
import jwt
|
||||
from jwt.api_jwk import PyJWK
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
@@ -50,6 +50,7 @@ from litellm.proxy.auth.auth_checks import can_team_access_model
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||
from litellm.repositories.user_repository import UserRepository
|
||||
|
||||
from .auth_checks import (
|
||||
_allowed_routes_check,
|
||||
@@ -1790,7 +1791,7 @@ class JWTAuthManager:
|
||||
# Update user role
|
||||
new_role = jwt_handler.map_jwt_role_to_litellm_role(jwt_valid_token)
|
||||
if new_role and user_object.user_role != new_role.value:
|
||||
await prisma_client.db.litellm_usertable.update(
|
||||
await UserRepository(prisma_client).table.update(
|
||||
where={"user_id": user_object.user_id},
|
||||
data={"user_role": new_role.value},
|
||||
)
|
||||
|
||||
@@ -34,6 +34,7 @@ from litellm.proxy.utils import (
|
||||
hash_password,
|
||||
verify_password,
|
||||
)
|
||||
from litellm.repositories.user_repository import UserRepository
|
||||
from litellm.secret_managers.main import get_secret_bool
|
||||
from litellm.types.proxy.ui_sso import ReturnedUITokenObject
|
||||
|
||||
@@ -45,7 +46,7 @@ async def _rehash_password_if_needed(user_id: str, password: str, stored: str) -
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is not None:
|
||||
await prisma_client.db.litellm_usertable.update(
|
||||
await UserRepository(prisma_client).table.update(
|
||||
where={"user_id": user_id},
|
||||
data={"password": hash_password(password)},
|
||||
)
|
||||
@@ -151,7 +152,7 @@ async def authenticate_user( # noqa: PLR0915
|
||||
if prisma_client is not None:
|
||||
_user_row = cast(
|
||||
Optional[LiteLLM_UserTable],
|
||||
await prisma_client.db.litellm_usertable.find_first(
|
||||
await UserRepository(prisma_client).table.find_first(
|
||||
where={"user_email": {"equals": username, "mode": "insensitive"}}
|
||||
),
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.credential_accessor import CredentialAccessor
|
||||
from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth
|
||||
from litellm.repositories.object_permission_repository import ObjectPermissionRepository
|
||||
from litellm.router import Router
|
||||
from litellm.router_utils.fallback_event_handlers import get_fallback_model_group
|
||||
from litellm.types.router import CredentialLiteLLMParams, LiteLLM_Params
|
||||
@@ -86,7 +87,7 @@ async def get_mcp_server_ids(
|
||||
|
||||
# Make a direct SQL query to get just the mcp_servers
|
||||
try:
|
||||
result = await prisma_client.db.litellm_objectpermissiontable.find_unique(
|
||||
result = await ObjectPermissionRepository(prisma_client).table.find_unique(
|
||||
where={"object_permission_id": user_api_key_dict.object_permission_id},
|
||||
)
|
||||
if result and result.mcp_servers:
|
||||
|
||||
@@ -21,9 +21,9 @@ from fastapi.security.api_key import APIKeyHeader
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
from litellm._service_logger import ServiceLogging
|
||||
from litellm.constants import LITELLM_PROXY_MASTER_KEY_ALIAS
|
||||
from litellm.integrations.otel.model.config import is_otel_v2_enabled
|
||||
from litellm.integrations.otel.runtime import phase_span, seed_request_identity
|
||||
from litellm.constants import LITELLM_PROXY_MASTER_KEY_ALIAS
|
||||
from litellm.litellm_core_utils.dd_tracing import tracer
|
||||
from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value
|
||||
from litellm.proxy._types import *
|
||||
@@ -65,7 +65,6 @@ from litellm.proxy.auth.oauth2_check import Oauth2Handler
|
||||
from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
from litellm.proxy.common_utils.cache_coordinator import EventDrivenCacheCoordinator
|
||||
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
|
||||
from litellm.proxy.common_utils.http_parsing_utils import (
|
||||
_read_request_body,
|
||||
_safe_get_request_headers,
|
||||
@@ -73,12 +72,14 @@ from litellm.proxy.common_utils.http_parsing_utils import (
|
||||
populate_request_with_path_params,
|
||||
)
|
||||
from litellm.proxy.common_utils.realtime_utils import _realtime_request_body
|
||||
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
|
||||
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
|
||||
from litellm.proxy.utils import (
|
||||
PrismaClient,
|
||||
ProxyLogging,
|
||||
normalize_route_for_root_path,
|
||||
)
|
||||
from litellm.repositories.table_repositories import TeamMembershipRepository
|
||||
from litellm.secret_managers.main import get_secret_bool
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
@@ -1797,7 +1798,9 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
_team_id = valid_token.team_id
|
||||
|
||||
if _user_id is not None and _team_id is not None:
|
||||
_db_member = await prisma_client.db.litellm_teammembership.find_first(
|
||||
_db_member = await TeamMembershipRepository(
|
||||
prisma_client
|
||||
).table.find_first(
|
||||
where={
|
||||
"user_id": _user_id,
|
||||
"team_id": _team_id,
|
||||
|
||||
@@ -8,7 +8,6 @@ from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
|
||||
from litellm.constants import (
|
||||
EXPIRED_UI_SESSION_KEY_CLEANUP_JOB_NAME,
|
||||
LITELLM_EXPIRED_UI_SESSION_KEY_CLEANUP_BATCH_SIZE,
|
||||
@@ -16,11 +15,15 @@ from litellm.constants import (
|
||||
UI_SESSION_TOKEN_TEAM_ID,
|
||||
)
|
||||
from litellm.proxy._types import KeyRequest, LiteLLM_VerificationToken, UserAPIKeyAuth
|
||||
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
|
||||
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
delete_verification_tokens,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.repositories.verification_token_repository import (
|
||||
VerificationTokenRepository,
|
||||
)
|
||||
|
||||
|
||||
class ExpiredUISessionKeyCleanupManager:
|
||||
@@ -147,7 +150,7 @@ class ExpiredUISessionKeyCleanupManager:
|
||||
Find expired LiteLLM dashboard session keys.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
return await self.prisma_client.db.litellm_verificationtoken.find_many(
|
||||
return await VerificationTokenRepository(self.prisma_client).table.find_many(
|
||||
where={
|
||||
"team_id": UI_SESSION_TOKEN_TEAM_ID,
|
||||
"expires": {"lt": now},
|
||||
|
||||
@@ -24,6 +24,12 @@ from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
regenerate_key_fn,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.repositories.table_repositories import (
|
||||
DeprecatedVerificationTokenRepository,
|
||||
)
|
||||
from litellm.repositories.verification_token_repository import (
|
||||
VerificationTokenRepository,
|
||||
)
|
||||
|
||||
|
||||
class KeyRotationManager:
|
||||
@@ -124,20 +130,20 @@ class KeyRotationManager:
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
keys_with_rotation = (
|
||||
await self.prisma_client.db.litellm_verificationtoken.find_many(
|
||||
where={
|
||||
"auto_rotate": True, # Only keys marked for auto rotation
|
||||
"OR": [
|
||||
{
|
||||
"key_rotation_at": None
|
||||
}, # Keys that need initial rotation time setup
|
||||
{
|
||||
"key_rotation_at": {"lte": now}
|
||||
}, # Keys where rotation time has passed
|
||||
],
|
||||
}
|
||||
)
|
||||
keys_with_rotation = await VerificationTokenRepository(
|
||||
self.prisma_client
|
||||
).table.find_many(
|
||||
where={
|
||||
"auto_rotate": True, # Only keys marked for auto rotation
|
||||
"OR": [
|
||||
{
|
||||
"key_rotation_at": None
|
||||
}, # Keys that need initial rotation time setup
|
||||
{
|
||||
"key_rotation_at": {"lte": now}
|
||||
}, # Keys where rotation time has passed
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
return keys_with_rotation
|
||||
@@ -148,9 +154,9 @@ class KeyRotationManager:
|
||||
"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
result = await self.prisma_client.db.litellm_deprecatedverificationtoken.delete_many(
|
||||
where={"revoke_at": {"lt": now}}
|
||||
)
|
||||
result = await DeprecatedVerificationTokenRepository(
|
||||
self.prisma_client
|
||||
).table.delete_many(where={"revoke_at": {"lt": now}})
|
||||
if result > 0:
|
||||
verbose_proxy_logger.debug(
|
||||
"Cleaned up %s expired deprecated key(s)", result
|
||||
@@ -206,7 +212,7 @@ class KeyRotationManager:
|
||||
# Calculate next rotation time using helper function
|
||||
now = datetime.now(timezone.utc)
|
||||
next_rotation_time = _calculate_key_rotation_time(key.rotation_interval)
|
||||
await self.prisma_client.db.litellm_verificationtoken.update(
|
||||
await VerificationTokenRepository(self.prisma_client).table.update(
|
||||
where={"token": response.token_id},
|
||||
data={
|
||||
"rotation_count": (key.rotation_count or 0) + 1,
|
||||
|
||||
@@ -14,6 +14,16 @@ from litellm.proxy._types import (
|
||||
LiteLLM_VerificationToken,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||
from litellm.repositories.organization_repository import OrganizationRepository
|
||||
from litellm.repositories.table_repositories import (
|
||||
EndUserRepository,
|
||||
TagRepository,
|
||||
TeamMembershipRepository,
|
||||
)
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.repositories.verification_token_repository import (
|
||||
VerificationTokenRepository,
|
||||
)
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
|
||||
@@ -159,7 +169,7 @@ class ResetBudgetJob:
|
||||
"""
|
||||
return await self._cascade_reset_spend_for_budget_link(
|
||||
budgets_to_reset=budgets_to_reset,
|
||||
table=self.prisma_client.db.litellm_teammembership,
|
||||
table=TeamMembershipRepository(self.prisma_client).table,
|
||||
counter_key_fn=lambda m: f"spend:team_member:{m.user_id}:{m.team_id}",
|
||||
log_subject="team memberships",
|
||||
cache_key_fn=lambda m: f"{m.team_id}_{m.user_id}",
|
||||
@@ -176,7 +186,7 @@ class ResetBudgetJob:
|
||||
"""
|
||||
return await self._cascade_reset_spend_for_budget_link(
|
||||
budgets_to_reset=budgets_to_reset,
|
||||
table=self.prisma_client.db.litellm_verificationtoken,
|
||||
table=VerificationTokenRepository(self.prisma_client).table,
|
||||
counter_key_fn=lambda k: f"spend:key:{k.token}",
|
||||
log_subject="keys",
|
||||
extra_where={"budget_duration": None, "spend": {"gt": 0}},
|
||||
@@ -191,7 +201,7 @@ class ResetBudgetJob:
|
||||
"""
|
||||
return await self._cascade_reset_spend_for_budget_link(
|
||||
budgets_to_reset=budgets_to_reset,
|
||||
table=self.prisma_client.db.litellm_organizationtable,
|
||||
table=OrganizationRepository(self.prisma_client).table,
|
||||
counter_key_fn=lambda o: f"spend:org:{o.organization_id}",
|
||||
log_subject="orgs",
|
||||
extra_where={"spend": {"gt": 0}},
|
||||
@@ -217,7 +227,7 @@ class ResetBudgetJob:
|
||||
"""
|
||||
return await self._cascade_reset_spend_for_budget_link(
|
||||
budgets_to_reset=budgets_to_reset,
|
||||
table=self.prisma_client.db.litellm_tagtable,
|
||||
table=TagRepository(self.prisma_client).table,
|
||||
counter_key_fn=lambda t: f"spend:tag:{t.tag_name}",
|
||||
log_subject="tags",
|
||||
extra_where={"spend": {"gt": 0}},
|
||||
@@ -406,7 +416,7 @@ class ResetBudgetJob:
|
||||
rely on the default budget (litellm.max_end_user_budget_id) applied
|
||||
in-memory during auth checks.
|
||||
"""
|
||||
rows = await self.prisma_client.db.litellm_endusertable.find_many(
|
||||
rows = await EndUserRepository(self.prisma_client).table.find_many(
|
||||
where={
|
||||
"budget_id": None,
|
||||
"spend": {"gt": 0},
|
||||
@@ -824,7 +834,7 @@ class ResetBudgetJob:
|
||||
):
|
||||
changed = True
|
||||
if changed:
|
||||
await self.prisma_client.db.litellm_verificationtoken.update(
|
||||
await VerificationTokenRepository(self.prisma_client).table.update(
|
||||
where={"token": row["token"]},
|
||||
data={"budget_limits": json.dumps(windows)}, # type: ignore[arg-type]
|
||||
)
|
||||
@@ -852,7 +862,7 @@ class ResetBudgetJob:
|
||||
):
|
||||
changed = True
|
||||
if changed:
|
||||
await self.prisma_client.db.litellm_teamtable.update(
|
||||
await TeamRepository(self.prisma_client).table.update(
|
||||
where={"team_id": row["team_id"]},
|
||||
data={"budget_limits": json.dumps(windows)}, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from litellm.proxy.common_utils.resource_ownership import (
|
||||
is_proxy_admin,
|
||||
user_can_access_resource_owner,
|
||||
)
|
||||
from litellm.repositories.table_repositories import ManagedObjectRepository
|
||||
from litellm.responses.utils import ResponsesAPIRequestUtils
|
||||
|
||||
CONTAINER_OBJECT_PURPOSE = "container"
|
||||
@@ -213,7 +214,7 @@ async def record_container_owner(
|
||||
)
|
||||
return response
|
||||
|
||||
table = prisma_client.db.litellm_managedobjecttable
|
||||
table = ManagedObjectRepository(prisma_client).table
|
||||
existing = await table.find_unique(where={"model_object_id": model_object_id})
|
||||
if existing is not None:
|
||||
if getattr(existing, "file_purpose", None) != CONTAINER_OBJECT_PURPOSE:
|
||||
@@ -273,7 +274,7 @@ async def _get_container_owner(
|
||||
if prisma_client is None:
|
||||
return None
|
||||
|
||||
row = await prisma_client.db.litellm_managedobjecttable.find_first(
|
||||
row = await ManagedObjectRepository(prisma_client).table.find_first(
|
||||
where={
|
||||
"model_object_id": model_object_id,
|
||||
"file_purpose": CONTAINER_OBJECT_PURPOSE,
|
||||
@@ -319,7 +320,7 @@ async def _get_stored_container_id(
|
||||
if prisma_client is None:
|
||||
return None
|
||||
|
||||
row = await prisma_client.db.litellm_managedobjecttable.find_first(
|
||||
row = await ManagedObjectRepository(prisma_client).table.find_first(
|
||||
where={
|
||||
"model_object_id": model_object_id,
|
||||
"file_purpose": CONTAINER_OBJECT_PURPOSE,
|
||||
@@ -411,7 +412,7 @@ async def _get_allowed_container_ids(
|
||||
if prisma_client is None:
|
||||
return set()
|
||||
|
||||
rows = await prisma_client.db.litellm_managedobjecttable.find_many(
|
||||
rows = await ManagedObjectRepository(prisma_client).table.find_many(
|
||||
where={
|
||||
"file_purpose": CONTAINER_OBJECT_PURPOSE,
|
||||
"created_by": {"in": owner_scopes},
|
||||
|
||||
@@ -14,6 +14,7 @@ from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper
|
||||
from litellm.proxy.utils import handle_exception_on_proxy, jsonify_object
|
||||
from litellm.repositories.credentials_repository import CredentialsRepository
|
||||
from litellm.types.utils import CreateCredentialItem, CredentialItem
|
||||
|
||||
router = APIRouter()
|
||||
@@ -96,7 +97,7 @@ async def create_credential(
|
||||
)
|
||||
credentials_dict = encrypted_credential.model_dump()
|
||||
credentials_dict_jsonified = jsonify_object(credentials_dict)
|
||||
await prisma_client.db.litellm_credentialstable.create(
|
||||
await CredentialsRepository(prisma_client).create(
|
||||
data={
|
||||
**credentials_dict_jsonified,
|
||||
"created_by": user_api_key_dict.user_id,
|
||||
@@ -245,9 +246,7 @@ async def delete_credential(
|
||||
status_code=500,
|
||||
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
||||
)
|
||||
await prisma_client.db.litellm_credentialstable.delete(
|
||||
where={"credential_name": credential_name}
|
||||
)
|
||||
await CredentialsRepository(prisma_client).delete_by_name(credential_name)
|
||||
|
||||
## DELETE FROM LITELLM ##
|
||||
litellm.credential_list = [
|
||||
@@ -326,15 +325,14 @@ async def update_credential(
|
||||
status_code=500,
|
||||
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
||||
)
|
||||
db_credential = await prisma_client.db.litellm_credentialstable.find_unique(
|
||||
where={"credential_name": credential_name},
|
||||
)
|
||||
credentials_repository = CredentialsRepository(prisma_client)
|
||||
db_credential = await credentials_repository.find_by_name(credential_name)
|
||||
if db_credential is None:
|
||||
raise HTTPException(status_code=404, detail="Credential not found in DB.")
|
||||
merged_credential = update_db_credential(db_credential, credential)
|
||||
credential_object_jsonified = jsonify_object(merged_credential.model_dump())
|
||||
await prisma_client.db.litellm_credentialstable.update(
|
||||
where={"credential_name": credential_name},
|
||||
await credentials_repository.update_by_name(
|
||||
credential_name,
|
||||
data={
|
||||
**credential_object_jsonified,
|
||||
"updated_by": user_api_key_dict.user_id,
|
||||
|
||||
@@ -20,6 +20,16 @@ from typing import TYPE_CHECKING, ClassVar, Optional
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
from litellm.repositories.organization_repository import OrganizationRepository
|
||||
from litellm.repositories.table_repositories import (
|
||||
SpendLogsRepository,
|
||||
TeamMembershipRepository,
|
||||
)
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.repositories.user_repository import UserRepository
|
||||
from litellm.repositories.verification_token_repository import (
|
||||
VerificationTokenRepository,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.caching.dual_cache import DualCache
|
||||
@@ -83,25 +93,25 @@ class SpendCounterReseed:
|
||||
try:
|
||||
if counter_key.startswith("spend:key:"):
|
||||
token = counter_key[len("spend:key:") :]
|
||||
row = await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||
where={"token": token}
|
||||
)
|
||||
row = await VerificationTokenRepository(
|
||||
prisma_client
|
||||
).table.find_unique(where={"token": token})
|
||||
elif counter_key.startswith("spend:team_member:"):
|
||||
suffix = counter_key[len("spend:team_member:") :]
|
||||
if ":" not in suffix:
|
||||
return None
|
||||
user_id, team_id = suffix.rsplit(":", 1)
|
||||
row = await prisma_client.db.litellm_teammembership.find_unique(
|
||||
row = await TeamMembershipRepository(prisma_client).table.find_unique(
|
||||
where={"user_id_team_id": {"user_id": user_id, "team_id": team_id}}
|
||||
)
|
||||
elif counter_key.startswith("spend:team:"):
|
||||
team_id = counter_key[len("spend:team:") :]
|
||||
row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
row = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
elif counter_key.startswith("spend:user:"):
|
||||
user_id = counter_key[len("spend:user:") :]
|
||||
row = await prisma_client.db.litellm_usertable.find_unique(
|
||||
row = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": user_id}
|
||||
)
|
||||
elif counter_key.startswith("spend:end_user:"):
|
||||
@@ -110,7 +120,7 @@ class SpendCounterReseed:
|
||||
return None
|
||||
elif counter_key.startswith("spend:org:"):
|
||||
org_id = counter_key[len("spend:org:") :]
|
||||
row = await prisma_client.db.litellm_organizationtable.find_unique(
|
||||
row = await OrganizationRepository(prisma_client).table.find_unique(
|
||||
where={"organization_id": org_id}
|
||||
)
|
||||
else:
|
||||
@@ -243,7 +253,7 @@ class SpendCounterReseed:
|
||||
return None
|
||||
|
||||
try:
|
||||
response = await prisma_client.db.litellm_spendlogs.group_by(
|
||||
response = await SpendLogsRepository(prisma_client).table.group_by(
|
||||
by=[group_field],
|
||||
where=where, # type: ignore[arg-type]
|
||||
sum={"spend": True},
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Any, Dict, List, Set
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.repositories.table_repositories import SpendLogToolIndexRepository
|
||||
|
||||
|
||||
def _add_tool_calls_to_set(tool_calls: Any, out: Set[str]) -> None:
|
||||
@@ -141,7 +142,7 @@ async def process_spend_logs_tool_usage(
|
||||
}
|
||||
)
|
||||
if index_data:
|
||||
await prisma_client.db.litellm_spendlogtoolindex.create_many(
|
||||
await SpendLogToolIndexRepository(prisma_client).table.create_many(
|
||||
data=index_data,
|
||||
skip_duplicates=True,
|
||||
)
|
||||
|
||||
@@ -11,6 +11,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import ToolDiscoveryQueueItem
|
||||
from litellm.repositories.object_permission_repository import ObjectPermissionRepository
|
||||
from litellm.repositories.table_repositories import ToolRepository
|
||||
from litellm.types.tool_management import (
|
||||
LiteLLM_ToolTableRow,
|
||||
ToolPolicyOverrideRow,
|
||||
@@ -84,7 +86,7 @@ async def batch_upsert_tools(
|
||||
if not data:
|
||||
return
|
||||
now = datetime.now(timezone.utc)
|
||||
table = prisma_client.db.litellm_tooltable
|
||||
table = ToolRepository(prisma_client).table
|
||||
for item in data:
|
||||
tool_name = item.get("tool_name", "")
|
||||
origin = item.get("origin") or "user_defined"
|
||||
@@ -134,7 +136,7 @@ async def list_tools(
|
||||
"""Return all tools, optionally filtered by input_policy."""
|
||||
try:
|
||||
where = {"input_policy": input_policy} if input_policy is not None else {}
|
||||
rows = await prisma_client.db.litellm_tooltable.find_many(
|
||||
rows = await ToolRepository(prisma_client).table.find_many(
|
||||
where=where,
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
@@ -150,7 +152,7 @@ async def get_tool(
|
||||
) -> Optional[LiteLLM_ToolTableRow]:
|
||||
"""Return a single tool row by tool_name."""
|
||||
try:
|
||||
row = await prisma_client.db.litellm_tooltable.find_unique(
|
||||
row = await ToolRepository(prisma_client).table.find_unique(
|
||||
where={"tool_name": tool_name},
|
||||
)
|
||||
if row is None:
|
||||
@@ -192,7 +194,7 @@ async def update_tool_policy(
|
||||
if output_policy is not None:
|
||||
update_data["output_policy"] = output_policy
|
||||
|
||||
await prisma_client.db.litellm_tooltable.upsert(
|
||||
await ToolRepository(prisma_client).table.upsert(
|
||||
where={"tool_name": tool_name},
|
||||
data={
|
||||
"create": create_data,
|
||||
@@ -217,7 +219,7 @@ async def get_tools_by_names(
|
||||
if not tool_names:
|
||||
return {}
|
||||
try:
|
||||
rows = await prisma_client.db.litellm_tooltable.find_many(
|
||||
rows = await ToolRepository(prisma_client).table.find_many(
|
||||
where={"tool_name": {"in": tool_names}},
|
||||
)
|
||||
return {
|
||||
@@ -244,7 +246,7 @@ async def list_overrides_for_tool(
|
||||
"""
|
||||
out: List[ToolPolicyOverrideRow] = []
|
||||
try:
|
||||
perms = await prisma_client.db.litellm_objectpermissiontable.find_many(
|
||||
perms = await ObjectPermissionRepository(prisma_client).table.find_many(
|
||||
where={"blocked_tools": {"has": tool_name}},
|
||||
include={
|
||||
"verification_tokens": True,
|
||||
@@ -307,7 +309,7 @@ class ToolPolicyRegistry:
|
||||
async def sync_tool_policy_from_db(self, prisma_client: "PrismaClient") -> None:
|
||||
"""Load all tool policies and object-permission blocked_tools from DB."""
|
||||
try:
|
||||
tools = await prisma_client.db.litellm_tooltable.find_many()
|
||||
tools = await ToolRepository(prisma_client).table.find_many()
|
||||
self._tool_input_policies = {
|
||||
row.tool_name: getattr(row, "input_policy", "untrusted") or "untrusted"
|
||||
for row in tools
|
||||
@@ -317,7 +319,7 @@ class ToolPolicyRegistry:
|
||||
for row in tools
|
||||
}
|
||||
|
||||
perms = await prisma_client.db.litellm_objectpermissiontable.find_many()
|
||||
perms = await ObjectPermissionRepository(prisma_client).table.find_many()
|
||||
self._blocked_tools_by_op_id = {}
|
||||
for row in perms:
|
||||
op_id = getattr(row, "object_permission_id", None)
|
||||
@@ -388,7 +390,7 @@ async def add_tool_to_object_permission_blocked(
|
||||
if not object_permission_id or not tool_name:
|
||||
return False
|
||||
try:
|
||||
row = await prisma_client.db.litellm_objectpermissiontable.find_unique(
|
||||
row = await ObjectPermissionRepository(prisma_client).table.find_unique(
|
||||
where={"object_permission_id": object_permission_id},
|
||||
)
|
||||
if row is None:
|
||||
@@ -397,7 +399,7 @@ async def add_tool_to_object_permission_blocked(
|
||||
if tool_name in current:
|
||||
return True
|
||||
current.append(tool_name)
|
||||
await prisma_client.db.litellm_objectpermissiontable.update(
|
||||
await ObjectPermissionRepository(prisma_client).table.update(
|
||||
where={"object_permission_id": object_permission_id},
|
||||
data={"blocked_tools": current},
|
||||
)
|
||||
@@ -418,7 +420,7 @@ async def remove_tool_from_object_permission_blocked(
|
||||
if not object_permission_id or not tool_name:
|
||||
return False
|
||||
try:
|
||||
row = await prisma_client.db.litellm_objectpermissiontable.find_unique(
|
||||
row = await ObjectPermissionRepository(prisma_client).table.find_unique(
|
||||
where={"object_permission_id": object_permission_id},
|
||||
)
|
||||
if row is None:
|
||||
@@ -427,7 +429,7 @@ async def remove_tool_from_object_permission_blocked(
|
||||
if tool_name not in current:
|
||||
return False
|
||||
current = [t for t in current if t != tool_name]
|
||||
await prisma_client.db.litellm_objectpermissiontable.update(
|
||||
await ObjectPermissionRepository(prisma_client).table.update(
|
||||
where={"object_permission_id": object_permission_id},
|
||||
data={"blocked_tools": current},
|
||||
)
|
||||
|
||||
@@ -13,21 +13,21 @@ from urllib.parse import urlparse
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.proxy.common_utils.path_utils import safe_join
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
|
||||
from litellm.proxy.common_utils.path_utils import safe_join
|
||||
from litellm.proxy.guardrails.guardrail_hooks.custom_code.sandbox import (
|
||||
build_sandbox_globals,
|
||||
compile_sandboxed,
|
||||
)
|
||||
from litellm.proxy.guardrails.guardrail_registry import GuardrailRegistry
|
||||
from litellm.proxy.guardrails.usage_endpoints import router as guardrails_usage_router
|
||||
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
|
||||
from litellm.repositories.table_repositories import GuardrailsRepository
|
||||
from litellm.types.guardrails import (
|
||||
PII_ENTITY_CATEGORIES_MAP,
|
||||
ApplyGuardrailRequest,
|
||||
@@ -373,7 +373,7 @@ async def create_guardrail(
|
||||
# Configuration error — roll back the DB write so the guardrail isn't orphaned
|
||||
if prisma_client is not None:
|
||||
try:
|
||||
await prisma_client.db.litellm_guardrailstable.delete(
|
||||
await GuardrailsRepository(prisma_client).table.delete(
|
||||
where={"guardrail_id": guardrail_id}
|
||||
)
|
||||
except Exception as rollback_err:
|
||||
@@ -705,7 +705,7 @@ async def register_guardrail(
|
||||
)
|
||||
|
||||
try:
|
||||
existing = await prisma_client.db.litellm_guardrailstable.find_unique(
|
||||
existing = await GuardrailsRepository(prisma_client).table.find_unique(
|
||||
where={"guardrail_name": request.guardrail_name}
|
||||
)
|
||||
if existing is not None:
|
||||
@@ -732,7 +732,7 @@ async def register_guardrail(
|
||||
guardrail_info_str = safe_dumps(guardrail_info)
|
||||
|
||||
try:
|
||||
created = await prisma_client.db.litellm_guardrailstable.create(
|
||||
created = await GuardrailsRepository(prisma_client).table.create(
|
||||
data={
|
||||
"guardrail_name": request.guardrail_name,
|
||||
"litellm_params": litellm_params_str,
|
||||
@@ -874,7 +874,7 @@ async def list_guardrail_submissions(
|
||||
where_clause["team_id"] = {"in": visible_team_ids}
|
||||
|
||||
# Single query: fetch team guardrails visible to the caller
|
||||
all_team_rows = await prisma_client.db.litellm_guardrailstable.find_many(
|
||||
all_team_rows = await GuardrailsRepository(prisma_client).table.find_many(
|
||||
where=where_clause,
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
@@ -945,7 +945,7 @@ async def get_guardrail_submission(
|
||||
is_admin = user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
|
||||
|
||||
try:
|
||||
row = await prisma_client.db.litellm_guardrailstable.find_unique(
|
||||
row = await GuardrailsRepository(prisma_client).table.find_unique(
|
||||
where={"guardrail_id": guardrail_id}
|
||||
)
|
||||
if row is None:
|
||||
@@ -986,7 +986,7 @@ async def approve_guardrail_submission(
|
||||
raise HTTPException(status_code=500, detail="Prisma client not initialized")
|
||||
|
||||
try:
|
||||
row = await prisma_client.db.litellm_guardrailstable.find_unique(
|
||||
row = await GuardrailsRepository(prisma_client).table.find_unique(
|
||||
where={"guardrail_id": guardrail_id}
|
||||
)
|
||||
if row is None:
|
||||
@@ -1000,7 +1000,7 @@ async def approve_guardrail_submission(
|
||||
)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
await prisma_client.db.litellm_guardrailstable.update(
|
||||
await GuardrailsRepository(prisma_client).table.update(
|
||||
where={"guardrail_id": guardrail_id},
|
||||
data={"status": "active", "reviewed_at": now, "updated_at": now},
|
||||
)
|
||||
@@ -1072,7 +1072,7 @@ async def reject_guardrail_submission(
|
||||
raise HTTPException(status_code=500, detail="Prisma client not initialized")
|
||||
|
||||
try:
|
||||
row = await prisma_client.db.litellm_guardrailstable.find_unique(
|
||||
row = await GuardrailsRepository(prisma_client).table.find_unique(
|
||||
where={"guardrail_id": guardrail_id}
|
||||
)
|
||||
if row is None:
|
||||
@@ -1086,7 +1086,7 @@ async def reject_guardrail_submission(
|
||||
)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
await prisma_client.db.litellm_guardrailstable.update(
|
||||
await GuardrailsRepository(prisma_client).table.update(
|
||||
where={"guardrail_id": guardrail_id},
|
||||
data={"status": "rejected", "reviewed_at": now, "updated_at": now},
|
||||
)
|
||||
@@ -2288,10 +2288,10 @@ async def apply_guardrail(
|
||||
"""
|
||||
import traceback
|
||||
|
||||
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
|
||||
from litellm.litellm_core_utils.thread_pool_executor import (
|
||||
executor as thread_pool_executor,
|
||||
)
|
||||
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
|
||||
from litellm.proxy.proxy_server import (
|
||||
general_settings,
|
||||
proxy_config,
|
||||
|
||||
@@ -11,12 +11,15 @@ from litellm._logging import verbose_proxy_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy.guardrails.guardrail_hooks.grayswan import GraySwanGuardrail
|
||||
from litellm.proxy.guardrails.guardrail_hooks.grayswan import (
|
||||
GraySwanGuardrail,
|
||||
)
|
||||
from litellm.proxy.guardrails.guardrail_hooks.grayswan import (
|
||||
initialize_guardrail as initialize_grayswan,
|
||||
)
|
||||
from litellm.proxy.types_utils.utils import get_instance_fn
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.repositories.table_repositories import GuardrailsRepository
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.guardrails import (
|
||||
Guardrail,
|
||||
@@ -26,6 +29,9 @@ from litellm.types.guardrails import (
|
||||
SupportedGuardrailIntegrations,
|
||||
)
|
||||
|
||||
from .guardrail_hooks.llm_as_a_judge import (
|
||||
initialize_guardrail as initialize_llm_as_a_judge,
|
||||
)
|
||||
from .guardrail_initializers import (
|
||||
initialize_bedrock,
|
||||
initialize_hide_secrets,
|
||||
@@ -34,9 +40,6 @@ from .guardrail_initializers import (
|
||||
initialize_presidio,
|
||||
initialize_tool_permission,
|
||||
)
|
||||
from .guardrail_hooks.llm_as_a_judge import (
|
||||
initialize_guardrail as initialize_llm_as_a_judge,
|
||||
)
|
||||
|
||||
guardrail_initializer_registry = {
|
||||
SupportedGuardrailIntegrations.BEDROCK.value: initialize_bedrock,
|
||||
@@ -257,7 +260,7 @@ class GuardrailRegistry:
|
||||
guardrail_info: str = safe_dumps(guardrail.get("guardrail_info", {}))
|
||||
|
||||
# Create guardrail in DB
|
||||
created_guardrail = await prisma_client.db.litellm_guardrailstable.create(
|
||||
created_guardrail = await GuardrailsRepository(prisma_client).table.create(
|
||||
data={
|
||||
"guardrail_name": guardrail_name,
|
||||
"litellm_params": litellm_params,
|
||||
@@ -283,7 +286,7 @@ class GuardrailRegistry:
|
||||
"""
|
||||
try:
|
||||
# Delete from DB
|
||||
await prisma_client.db.litellm_guardrailstable.delete(
|
||||
await GuardrailsRepository(prisma_client).table.delete(
|
||||
where={"guardrail_id": guardrail_id}
|
||||
)
|
||||
|
||||
@@ -311,7 +314,7 @@ class GuardrailRegistry:
|
||||
guardrail_info: str = safe_dumps(guardrail.get("guardrail_info", {}))
|
||||
|
||||
# Update in DB
|
||||
updated_guardrail = await prisma_client.db.litellm_guardrailstable.update(
|
||||
updated_guardrail = await GuardrailsRepository(prisma_client).table.update(
|
||||
where={"guardrail_id": guardrail_id},
|
||||
data={
|
||||
"guardrail_name": guardrail_name,
|
||||
@@ -335,11 +338,11 @@ class GuardrailRegistry:
|
||||
Only rows with status == "active" are returned (pending_review and rejected are excluded).
|
||||
"""
|
||||
try:
|
||||
guardrails_from_db = (
|
||||
await prisma_client.db.litellm_guardrailstable.find_many(
|
||||
where={"status": "active"},
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
guardrails_from_db = await GuardrailsRepository(
|
||||
prisma_client
|
||||
).table.find_many(
|
||||
where={"status": "active"},
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
|
||||
guardrails: List[Guardrail] = []
|
||||
@@ -357,7 +360,7 @@ class GuardrailRegistry:
|
||||
Get a guardrail by its ID from the database
|
||||
"""
|
||||
try:
|
||||
guardrail = await prisma_client.db.litellm_guardrailstable.find_unique(
|
||||
guardrail = await GuardrailsRepository(prisma_client).table.find_unique(
|
||||
where={"guardrail_id": guardrail_id}
|
||||
)
|
||||
|
||||
@@ -375,7 +378,7 @@ class GuardrailRegistry:
|
||||
Get a guardrail by its name from the database
|
||||
"""
|
||||
try:
|
||||
guardrail = await prisma_client.db.litellm_guardrailstable.find_unique(
|
||||
guardrail = await GuardrailsRepository(prisma_client).table.find_unique(
|
||||
where={"guardrail_name": guardrail_name}
|
||||
)
|
||||
|
||||
|
||||
@@ -12,6 +12,14 @@ from pydantic import BaseModel
|
||||
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.repositories.table_repositories import (
|
||||
DailyGuardrailMetricsRepository,
|
||||
DailyPolicyMetricsRepository,
|
||||
GuardrailsRepository,
|
||||
PolicyRepository,
|
||||
SpendLogGuardrailIndexRepository,
|
||||
SpendLogsRepository,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -272,10 +280,10 @@ async def guardrails_usage_overview(
|
||||
|
||||
try:
|
||||
# Guardrails from DB
|
||||
guardrails = await prisma_client.db.litellm_guardrailstable.find_many()
|
||||
guardrails = await GuardrailsRepository(prisma_client).table.find_many()
|
||||
|
||||
# Daily metrics in range
|
||||
metrics = await prisma_client.db.litellm_dailyguardrailmetrics.find_many(
|
||||
metrics = await DailyGuardrailMetricsRepository(prisma_client).table.find_many(
|
||||
where={"date": {"gte": start, "lte": end}}
|
||||
)
|
||||
|
||||
@@ -283,9 +291,9 @@ async def guardrails_usage_overview(
|
||||
start_prev = (
|
||||
datetime.strptime(start, "%Y-%m-%d") - timedelta(days=7)
|
||||
).strftime("%Y-%m-%d")
|
||||
metrics_prev = await prisma_client.db.litellm_dailyguardrailmetrics.find_many(
|
||||
where={"date": {"gte": start_prev, "lt": start}}
|
||||
)
|
||||
metrics_prev = await DailyGuardrailMetricsRepository(
|
||||
prisma_client
|
||||
).table.find_many(where={"date": {"gte": start_prev, "lt": start}})
|
||||
|
||||
agg = _aggregate_daily_metrics(metrics, "guardrail_id")
|
||||
prev_agg = _prev_fail_rates(metrics_prev, "guardrail_id")
|
||||
@@ -335,7 +343,7 @@ async def guardrails_usage_detail(
|
||||
end = end_date or now.strftime("%Y-%m-%d")
|
||||
start = start_date or (now - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||
|
||||
guardrail = await prisma_client.db.litellm_guardrailstable.find_unique(
|
||||
guardrail = await GuardrailsRepository(prisma_client).table.find_unique(
|
||||
where={"guardrail_id": guardrail_id}
|
||||
)
|
||||
if not guardrail:
|
||||
@@ -349,13 +357,13 @@ async def guardrails_usage_detail(
|
||||
)
|
||||
metric_ids = [i for i in (logical_id, guardrail_id) if i]
|
||||
|
||||
metrics = await prisma_client.db.litellm_dailyguardrailmetrics.find_many(
|
||||
metrics = await DailyGuardrailMetricsRepository(prisma_client).table.find_many(
|
||||
where={
|
||||
"guardrail_id": {"in": metric_ids},
|
||||
"date": {"gte": start, "lte": end},
|
||||
}
|
||||
)
|
||||
metrics_prev = await prisma_client.db.litellm_dailyguardrailmetrics.find_many(
|
||||
metrics_prev = await DailyGuardrailMetricsRepository(prisma_client).table.find_many(
|
||||
where={
|
||||
"guardrail_id": {"in": metric_ids},
|
||||
"date": {"lt": start},
|
||||
@@ -574,7 +582,7 @@ async def guardrails_usage_logs(
|
||||
# Query by both so we match regardless of which was written.
|
||||
effective_guardrail_ids: List[str] = [guardrail_id] if guardrail_id else []
|
||||
if guardrail_id:
|
||||
guardrail = await prisma_client.db.litellm_guardrailstable.find_unique(
|
||||
guardrail = await GuardrailsRepository(prisma_client).table.find_unique(
|
||||
where={"guardrail_id": guardrail_id}
|
||||
)
|
||||
if guardrail:
|
||||
@@ -585,19 +593,23 @@ async def guardrails_usage_logs(
|
||||
where = _build_usage_logs_where(
|
||||
effective_guardrail_ids or None, policy_id, start_date, end_date
|
||||
)
|
||||
index_rows = await prisma_client.db.litellm_spendlogguardrailindex.find_many(
|
||||
index_rows = await SpendLogGuardrailIndexRepository(
|
||||
prisma_client
|
||||
).table.find_many(
|
||||
where=where,
|
||||
order={"start_time": "desc"},
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size + 1,
|
||||
)
|
||||
total = await prisma_client.db.litellm_spendlogguardrailindex.count(where=where)
|
||||
total = await SpendLogGuardrailIndexRepository(prisma_client).table.count(
|
||||
where=where
|
||||
)
|
||||
request_ids = [r.request_id for r in index_rows[:page_size]]
|
||||
if not request_ids:
|
||||
return UsageLogsResponse(
|
||||
logs=[], total=total, page=page, page_size=page_size
|
||||
)
|
||||
spend_logs = await prisma_client.db.litellm_spendlogs.find_many(
|
||||
spend_logs = await SpendLogsRepository(prisma_client).table.find_many(
|
||||
where={"request_id": {"in": request_ids}}
|
||||
)
|
||||
log_by_id = {s.request_id: s for s in spend_logs}
|
||||
@@ -645,11 +657,13 @@ async def policies_usage_overview(
|
||||
start = start_date or (now - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||
|
||||
try:
|
||||
policies = await prisma_client.db.litellm_policytable.find_many()
|
||||
metrics = await prisma_client.db.litellm_dailypolicymetrics.find_many(
|
||||
policies = await PolicyRepository(prisma_client).table.find_many()
|
||||
metrics = await DailyPolicyMetricsRepository(prisma_client).table.find_many(
|
||||
where={"date": {"gte": start, "lte": end}}
|
||||
)
|
||||
metrics_prev = await prisma_client.db.litellm_dailypolicymetrics.find_many(
|
||||
metrics_prev = await DailyPolicyMetricsRepository(
|
||||
prisma_client
|
||||
).table.find_many(
|
||||
where={
|
||||
"date": {
|
||||
"gte": (
|
||||
|
||||
@@ -10,6 +10,10 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.repositories.table_repositories import (
|
||||
DailyGuardrailMetricsRepository,
|
||||
SpendLogGuardrailIndexRepository,
|
||||
)
|
||||
|
||||
|
||||
def _guardrail_status_to_action(status: Optional[str]) -> str:
|
||||
@@ -132,7 +136,7 @@ async def process_spend_logs_guardrail_usage(
|
||||
}
|
||||
)
|
||||
try:
|
||||
await prisma_client.db.litellm_spendlogguardrailindex.create_many(
|
||||
await SpendLogGuardrailIndexRepository(prisma_client).table.create_many(
|
||||
data=index_data,
|
||||
skip_duplicates=True,
|
||||
)
|
||||
@@ -146,7 +150,7 @@ async def process_spend_logs_guardrail_usage(
|
||||
n = int(agg["requests_evaluated"])
|
||||
if n == 0:
|
||||
continue
|
||||
await prisma_client.db.litellm_dailyguardrailmetrics.upsert(
|
||||
await DailyGuardrailMetricsRepository(prisma_client).table.upsert(
|
||||
where={
|
||||
"guardrail_id_date": {
|
||||
"guardrail_id": guardrail_id,
|
||||
|
||||
@@ -3,7 +3,6 @@ Hooks that are triggered when a litellm user event occurs
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from litellm._uuid import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
@@ -11,6 +10,7 @@ from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.proxy._types import (
|
||||
AUDIT_ACTIONS,
|
||||
CommonProxyErrors,
|
||||
@@ -24,6 +24,7 @@ from litellm.proxy._types import (
|
||||
WebhookEvent,
|
||||
)
|
||||
from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update
|
||||
from litellm.repositories.user_repository import UserRepository
|
||||
|
||||
|
||||
class UserManagementEventHooks:
|
||||
@@ -57,7 +58,7 @@ class UserManagementEventHooks:
|
||||
try:
|
||||
if prisma_client is None:
|
||||
raise Exception(CommonProxyErrors.db_not_connected_error.value)
|
||||
user_row: BaseModel = await prisma_client.db.litellm_usertable.find_first(
|
||||
user_row: BaseModel = await UserRepository(prisma_client).table.find_first(
|
||||
where={"user_id": response.user_id}
|
||||
)
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from litellm.proxy.auth.auth_checks import (
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler
|
||||
from litellm.proxy.utils import get_prisma_client_or_throw
|
||||
from litellm.repositories.table_repositories import AccessGroupRepository
|
||||
from litellm.types.access_group import (
|
||||
AccessGroupCreateRequest,
|
||||
AccessGroupResponse,
|
||||
@@ -386,7 +387,7 @@ async def list_access_groups(
|
||||
CommonProxyErrors.db_not_connected_error.value
|
||||
)
|
||||
|
||||
records = await prisma_client.db.litellm_accessgrouptable.find_many(
|
||||
records = await AccessGroupRepository(prisma_client).table.find_many(
|
||||
order={"created_at": "desc"}
|
||||
)
|
||||
return [_record_to_response(r) for r in records]
|
||||
@@ -405,7 +406,7 @@ async def get_access_group(
|
||||
CommonProxyErrors.db_not_connected_error.value
|
||||
)
|
||||
|
||||
record = await prisma_client.db.litellm_accessgrouptable.find_unique(
|
||||
record = await AccessGroupRepository(prisma_client).table.find_unique(
|
||||
where={"access_group_id": access_group_id}
|
||||
)
|
||||
if record is None:
|
||||
|
||||
@@ -16,11 +16,12 @@ import math
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
|
||||
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
|
||||
from litellm.proxy.utils import jsonify_object
|
||||
from litellm.repositories.budget_repository import BudgetRepository
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -98,7 +99,7 @@ async def new_budget(
|
||||
budget_obj_json = budget_obj.model_dump(exclude_none=True)
|
||||
budget_obj_jsonified = jsonify_object(budget_obj_json) # json dump any dictionaries
|
||||
try:
|
||||
response = await prisma_client.db.litellm_budgettable.create(
|
||||
response = await BudgetRepository(prisma_client).table.create(
|
||||
data={
|
||||
**budget_obj_jsonified, # type: ignore
|
||||
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
@@ -182,7 +183,7 @@ async def update_budget(
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail={"error": str(e)})
|
||||
|
||||
response = await prisma_client.db.litellm_budgettable.update(
|
||||
response = await BudgetRepository(prisma_client).table.update(
|
||||
where={"budget_id": budget_obj.budget_id},
|
||||
data={
|
||||
**budget_obj.model_dump(exclude_unset=True), # type: ignore
|
||||
@@ -217,7 +218,7 @@ async def info_budget(data: BudgetRequest):
|
||||
"error": f"Specify list of budget id's to query. Passed in={data.budgets}"
|
||||
},
|
||||
)
|
||||
response = await prisma_client.db.litellm_budgettable.find_many(
|
||||
response = await BudgetRepository(prisma_client).table.find_many(
|
||||
where={"budget_id": {"in": data.budgets}},
|
||||
)
|
||||
|
||||
@@ -261,7 +262,7 @@ async def budget_settings(
|
||||
)
|
||||
|
||||
## get budget item from db
|
||||
db_budget_row = await prisma_client.db.litellm_budgettable.find_first(
|
||||
db_budget_row = await BudgetRepository(prisma_client).table.find_first(
|
||||
where={"budget_id": budget_id}
|
||||
)
|
||||
|
||||
@@ -327,7 +328,7 @@ async def list_budget(
|
||||
},
|
||||
)
|
||||
|
||||
response = await prisma_client.db.litellm_budgettable.find_many()
|
||||
response = await BudgetRepository(prisma_client).table.find_many()
|
||||
|
||||
return response
|
||||
|
||||
@@ -366,7 +367,7 @@ async def delete_budget(
|
||||
},
|
||||
)
|
||||
|
||||
response = await prisma_client.db.litellm_budgettable.delete(
|
||||
response = await BudgetRepository(prisma_client).table.delete(
|
||||
where={"budget_id": data.id}
|
||||
)
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from litellm.proxy._types import (
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.repositories.table_repositories import CacheConfigRepository
|
||||
from litellm.types.management_endpoints import (
|
||||
CACHE_SETTINGS_FIELDS,
|
||||
REDIS_TYPE_DESCRIPTIONS,
|
||||
@@ -159,7 +160,7 @@ class CacheSettingsManager:
|
||||
import json
|
||||
|
||||
try:
|
||||
cache_config = await prisma_client.db.litellm_cacheconfig.find_unique(
|
||||
cache_config = await CacheConfigRepository(prisma_client).table.find_unique(
|
||||
where={"id": "cache_config"}
|
||||
)
|
||||
if cache_config is not None and cache_config.cache_settings:
|
||||
@@ -274,7 +275,7 @@ async def get_cache_settings(
|
||||
# Try to get cache settings from database
|
||||
current_values = {}
|
||||
if prisma_client is not None:
|
||||
cache_config = await prisma_client.db.litellm_cacheconfig.find_unique(
|
||||
cache_config = await CacheConfigRepository(prisma_client).table.find_unique(
|
||||
where={"id": "cache_config"}
|
||||
)
|
||||
if cache_config is not None and cache_config.cache_settings:
|
||||
@@ -417,7 +418,7 @@ async def update_cache_settings(
|
||||
|
||||
# Snapshot the prior settings (key set only — values get redacted in
|
||||
# the audit row) so the audit-log entry shows which fields changed.
|
||||
existing_row = await prisma_client.db.litellm_cacheconfig.find_unique(
|
||||
existing_row = await CacheConfigRepository(prisma_client).table.find_unique(
|
||||
where={"id": "cache_config"}
|
||||
)
|
||||
before_settings: Optional[Dict[str, Any]] = None
|
||||
@@ -434,7 +435,7 @@ async def update_cache_settings(
|
||||
)
|
||||
|
||||
# Save to database
|
||||
await prisma_client.db.litellm_cacheconfig.upsert(
|
||||
await CacheConfigRepository(prisma_client).table.upsert(
|
||||
where={"id": "cache_config"},
|
||||
data={
|
||||
"create": {
|
||||
|
||||
@@ -8,6 +8,10 @@ from fastapi import HTTPException, status
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import CommonProxyErrors
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.repositories.table_repositories import DeletedVerificationTokenRepository
|
||||
from litellm.repositories.verification_token_repository import (
|
||||
VerificationTokenRepository,
|
||||
)
|
||||
from litellm.types.proxy.management_endpoints.common_daily_activity import (
|
||||
BreakdownMetrics,
|
||||
DailySpendData,
|
||||
@@ -346,7 +350,7 @@ async def get_api_key_metadata(
|
||||
This ensures that key_alias and team_id are preserved in historical activity logs
|
||||
even after a key is deleted or regenerated.
|
||||
"""
|
||||
key_records = await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
key_records = await VerificationTokenRepository(prisma_client).table.find_many(
|
||||
where={"token": {"in": list(api_keys)}}
|
||||
)
|
||||
result = {
|
||||
@@ -357,11 +361,11 @@ async def get_api_key_metadata(
|
||||
missing_keys = api_keys - set(result.keys())
|
||||
if missing_keys:
|
||||
try:
|
||||
deleted_key_records = (
|
||||
await prisma_client.db.litellm_deletedverificationtoken.find_many(
|
||||
where={"token": {"in": list(missing_keys)}},
|
||||
order={"deleted_at": "desc"},
|
||||
)
|
||||
deleted_key_records = await DeletedVerificationTokenRepository(
|
||||
prisma_client
|
||||
).table.find_many(
|
||||
where={"token": {"in": list(missing_keys)}},
|
||||
order={"deleted_at": "desc"},
|
||||
)
|
||||
# Use the most recent deleted record for each token (ordered by deleted_at desc)
|
||||
for k in deleted_key_records:
|
||||
|
||||
@@ -17,10 +17,13 @@ from litellm.proxy._types import (
|
||||
NewProjectRequest,
|
||||
UpdateProjectRequest,
|
||||
UserAPIKeyAuth,
|
||||
user_api_key_has_admin_view as _user_has_admin_view, # noqa: F401 re-exported
|
||||
)
|
||||
from litellm.proxy._types import ( # noqa: F401 re-exported
|
||||
user_api_key_has_admin_view as _user_has_admin_view,
|
||||
)
|
||||
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
|
||||
from litellm.proxy.utils import _premium_user_check
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy._types import NewProjectRequest, UpdateProjectRequest
|
||||
@@ -205,7 +208,7 @@ async def _user_has_admin_privileges(
|
||||
# Check if user is team admin for any team
|
||||
if user_obj.teams is not None and len(user_obj.teams) > 0:
|
||||
# Get all teams user is in
|
||||
teams = await prisma_client.db.litellm_teamtable.find_many(
|
||||
teams = await TeamRepository(prisma_client).table.find_many(
|
||||
where={"team_id": {"in": user_obj.teams}}
|
||||
)
|
||||
|
||||
@@ -282,7 +285,7 @@ async def _team_admin_can_invite_user(
|
||||
if not target_user_obj.teams or len(target_user_obj.teams) == 0:
|
||||
return False
|
||||
|
||||
teams = await prisma_client.db.litellm_teamtable.find_many(
|
||||
teams = await TeamRepository(prisma_client).table.find_many(
|
||||
where={"team_id": {"in": admin_user_obj.teams}}
|
||||
)
|
||||
admin_team_ids = [
|
||||
|
||||
@@ -30,6 +30,7 @@ from litellm.proxy._types import (
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.repositories.table_repositories import ConfigOverridesRepository
|
||||
from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
from litellm.types.proxy.management_endpoints.config_overrides import (
|
||||
ConfigOverrideSettingsResponse,
|
||||
@@ -254,7 +255,7 @@ async def update_hashicorp_vault_config(
|
||||
|
||||
# Merge ALL fields the user didn't send: try DB first, fall back to env vars.
|
||||
# Omitted field = keep existing; empty string = clear/remove the field.
|
||||
existing_record = await prisma_client.db.litellm_configoverrides.find_unique(
|
||||
existing_record = await ConfigOverridesRepository(prisma_client).table.find_unique(
|
||||
where={"config_type": "hashicorp_vault"}
|
||||
)
|
||||
existing_decrypted: Optional[Dict[str, Any]] = None
|
||||
@@ -321,7 +322,7 @@ async def update_hashicorp_vault_config(
|
||||
# Only persist to DB after successful init
|
||||
encrypted_data = proxy_config._encrypt_env_variables(config_data)
|
||||
config_value = safe_dumps(encrypted_data)
|
||||
await prisma_client.db.litellm_configoverrides.upsert(
|
||||
await ConfigOverridesRepository(prisma_client).table.upsert(
|
||||
where={"config_type": "hashicorp_vault"},
|
||||
data={
|
||||
"create": {
|
||||
@@ -391,7 +392,7 @@ async def get_hashicorp_vault_config(
|
||||
field_schema = _build_field_schema(HashicorpVaultConfig)
|
||||
|
||||
# Try to load from DB
|
||||
db_record = await prisma_client.db.litellm_configoverrides.find_unique(
|
||||
db_record = await ConfigOverridesRepository(prisma_client).table.find_unique(
|
||||
where={"config_type": "hashicorp_vault"}
|
||||
)
|
||||
|
||||
@@ -448,7 +449,7 @@ async def delete_hashicorp_vault_config(
|
||||
|
||||
# Capture the prior config before delete so the audit-log row can
|
||||
# show *what* was removed (keys only — values get redacted).
|
||||
existing_record = await prisma_client.db.litellm_configoverrides.find_unique(
|
||||
existing_record = await ConfigOverridesRepository(prisma_client).table.find_unique(
|
||||
where={"config_type": "hashicorp_vault"}
|
||||
)
|
||||
before_config: Optional[Dict[str, Any]] = None
|
||||
@@ -463,7 +464,7 @@ async def delete_hashicorp_vault_config(
|
||||
# Delete DB record if it exists — ignore if not found
|
||||
deleted = False
|
||||
try:
|
||||
await prisma_client.db.litellm_configoverrides.delete(
|
||||
await ConfigOverridesRepository(prisma_client).table.delete(
|
||||
where={"config_type": "hashicorp_vault"}
|
||||
)
|
||||
deleted = True
|
||||
|
||||
@@ -17,8 +17,8 @@ import fastapi
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_endpoints.common_daily_activity import get_daily_activity
|
||||
@@ -27,6 +27,8 @@ from litellm.proxy.management_helpers.object_permission_utils import (
|
||||
handle_update_object_permission_common,
|
||||
)
|
||||
from litellm.proxy.utils import handle_exception_on_proxy
|
||||
from litellm.repositories.budget_repository import BudgetRepository
|
||||
from litellm.repositories.table_repositories import EndUserRepository
|
||||
from litellm.types.proxy.management_endpoints.common_daily_activity import (
|
||||
SpendAnalyticsPaginatedResponse,
|
||||
)
|
||||
@@ -68,7 +70,7 @@ async def block_user(data: BlockUsers):
|
||||
records = []
|
||||
if prisma_client is not None:
|
||||
for id in data.user_ids:
|
||||
record = await prisma_client.db.litellm_endusertable.upsert(
|
||||
record = await EndUserRepository(prisma_client).table.upsert(
|
||||
where={"user_id": id}, # type: ignore
|
||||
data={
|
||||
"create": {"user_id": id, "blocked": True}, # type: ignore
|
||||
@@ -337,7 +339,7 @@ async def new_end_user(
|
||||
_new_budget = new_budget_request(data)
|
||||
if _new_budget is not None:
|
||||
try:
|
||||
budget_record = await prisma_client.db.litellm_budgettable.create(
|
||||
budget_record = await BudgetRepository(prisma_client).table.create(
|
||||
data={
|
||||
**_new_budget.model_dump(exclude_unset=True),
|
||||
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, # type: ignore
|
||||
@@ -373,7 +375,7 @@ async def new_end_user(
|
||||
new_end_user_obj.pop("object_permission", None)
|
||||
|
||||
## WRITE TO DB ##
|
||||
end_user_record = await prisma_client.db.litellm_endusertable.create(
|
||||
end_user_record = await EndUserRepository(prisma_client).table.create(
|
||||
data=new_end_user_obj, # type: ignore
|
||||
include={"litellm_budget_table": True, "object_permission": True},
|
||||
)
|
||||
@@ -446,7 +448,7 @@ async def end_user_info(
|
||||
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
||||
)
|
||||
|
||||
user_info = await prisma_client.db.litellm_endusertable.find_first(
|
||||
user_info = await EndUserRepository(prisma_client).table.find_first(
|
||||
where={"user_id": end_user_id},
|
||||
include={"litellm_budget_table": True, "object_permission": True},
|
||||
)
|
||||
@@ -569,7 +571,7 @@ async def update_end_user(
|
||||
non_default_values[k] = v
|
||||
|
||||
## Get end user table data ##
|
||||
end_user_table_data = await prisma_client.db.litellm_endusertable.find_first(
|
||||
end_user_table_data = await EndUserRepository(prisma_client).table.find_first(
|
||||
where={"user_id": data.user_id}, include={"litellm_budget_table": True}
|
||||
)
|
||||
|
||||
@@ -613,17 +615,17 @@ async def update_end_user(
|
||||
if budget_table_data:
|
||||
if end_user_budget_table is None:
|
||||
## Create new budget ##
|
||||
budget_table_data_record = (
|
||||
await prisma_client.db.litellm_budgettable.create(
|
||||
data={
|
||||
**budget_table_data,
|
||||
"created_by": user_api_key_dict.user_id
|
||||
or litellm_proxy_admin_name,
|
||||
"updated_by": user_api_key_dict.user_id
|
||||
or litellm_proxy_admin_name,
|
||||
},
|
||||
include={"end_users": True},
|
||||
)
|
||||
budget_table_data_record = await BudgetRepository(
|
||||
prisma_client
|
||||
).table.create(
|
||||
data={
|
||||
**budget_table_data,
|
||||
"created_by": user_api_key_dict.user_id
|
||||
or litellm_proxy_admin_name,
|
||||
"updated_by": user_api_key_dict.user_id
|
||||
or litellm_proxy_admin_name,
|
||||
},
|
||||
include={"end_users": True},
|
||||
)
|
||||
|
||||
update_end_user_table_data["budget_id"] = (
|
||||
@@ -631,11 +633,11 @@ async def update_end_user(
|
||||
)
|
||||
else:
|
||||
## Update existing budget ##
|
||||
budget_table_data_record = (
|
||||
await prisma_client.db.litellm_budgettable.update(
|
||||
where={"budget_id": end_user_budget_table.budget_id},
|
||||
data=budget_table_data,
|
||||
)
|
||||
budget_table_data_record = await BudgetRepository(
|
||||
prisma_client
|
||||
).table.update(
|
||||
where={"budget_id": end_user_budget_table.budget_id},
|
||||
data=budget_table_data,
|
||||
)
|
||||
|
||||
## Update user table, with update params + new budget id (if set) ##
|
||||
@@ -652,7 +654,7 @@ async def update_end_user(
|
||||
if data.user_id is not None and len(data.user_id) > 0:
|
||||
update_end_user_table_data["user_id"] = data.user_id # type: ignore
|
||||
verbose_proxy_logger.debug("In update customer, user_id condition block.")
|
||||
response = await prisma_client.db.litellm_endusertable.update(
|
||||
response = await EndUserRepository(prisma_client).table.update(
|
||||
where={"user_id": data.user_id}, data=update_end_user_table_data, include={"litellm_budget_table": True, "object_permission": True} # type: ignore
|
||||
)
|
||||
if response is None:
|
||||
@@ -737,7 +739,7 @@ async def delete_end_user(
|
||||
and len(data.user_ids) > 0
|
||||
):
|
||||
# First check if all users exist
|
||||
existing_users = await prisma_client.db.litellm_endusertable.find_many(
|
||||
existing_users = await EndUserRepository(prisma_client).table.find_many(
|
||||
where={"user_id": {"in": data.user_ids}}
|
||||
)
|
||||
existing_user_ids = {user.user_id for user in existing_users}
|
||||
@@ -756,7 +758,7 @@ async def delete_end_user(
|
||||
)
|
||||
|
||||
# All users exist, proceed with deletion
|
||||
response = await prisma_client.db.litellm_endusertable.delete_many(
|
||||
response = await EndUserRepository(prisma_client).table.delete_many(
|
||||
where={"user_id": {"in": data.user_ids}}
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
@@ -828,7 +830,7 @@ async def list_end_user(
|
||||
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
||||
)
|
||||
|
||||
response = await prisma_client.db.litellm_endusertable.find_many(
|
||||
response = await EndUserRepository(prisma_client).table.find_many(
|
||||
include={"litellm_budget_table": True, "object_permission": True}
|
||||
)
|
||||
|
||||
@@ -903,7 +905,7 @@ async def get_customer_daily_activity(
|
||||
where_condition = {}
|
||||
if end_user_ids_list:
|
||||
where_condition["user_id"] = {"in": list(end_user_ids_list)}
|
||||
end_user_aliases = await prisma_client.db.litellm_endusertable.find_many(
|
||||
end_user_aliases = await EndUserRepository(prisma_client).table.find_many(
|
||||
where=where_condition
|
||||
)
|
||||
end_user_alias_metadata = {e.user_id: {"alias": e.alias} for e in end_user_aliases}
|
||||
|
||||
@@ -27,6 +27,7 @@ else:
|
||||
# fastapi is only required for proxy, not for SDK usage
|
||||
pass
|
||||
|
||||
from litellm.repositories.config_repository import ConfigRepository
|
||||
from litellm.types.management_endpoints.router_settings_endpoints import (
|
||||
FallbackCreateRequest,
|
||||
FallbackDeleteResponse,
|
||||
@@ -157,7 +158,7 @@ async def create_fallback(
|
||||
|
||||
# Save to database - convert router_settings to JSON string
|
||||
router_settings_json = json.dumps(router_settings)
|
||||
await prisma_client.db.litellm_config.upsert(
|
||||
await ConfigRepository(prisma_client).table.upsert(
|
||||
where={"param_name": "router_settings"},
|
||||
data={
|
||||
"create": {
|
||||
@@ -336,7 +337,7 @@ async def delete_fallback(
|
||||
|
||||
# Save to database - convert router_settings to JSON string
|
||||
router_settings_json = json.dumps(router_settings)
|
||||
await prisma_client.db.litellm_config.upsert(
|
||||
await ConfigRepository(prisma_client).table.upsert(
|
||||
where={"param_name": "router_settings"},
|
||||
data={
|
||||
"create": {
|
||||
|
||||
@@ -43,6 +43,17 @@ from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
)
|
||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||
from litellm.proxy.utils import handle_exception_on_proxy, hash_password
|
||||
from litellm.repositories.organization_repository import OrganizationRepository
|
||||
from litellm.repositories.table_repositories import (
|
||||
InvitationLinkRepository,
|
||||
OrganizationMembershipRepository,
|
||||
TeamMembershipRepository,
|
||||
)
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.repositories.user_repository import UserRepository
|
||||
from litellm.repositories.verification_token_repository import (
|
||||
VerificationTokenRepository,
|
||||
)
|
||||
from litellm.types.proxy.management_endpoints.common_daily_activity import (
|
||||
SpendAnalyticsPaginatedResponse,
|
||||
)
|
||||
@@ -154,7 +165,7 @@ async def _check_duplicate_user_field(
|
||||
if case_insensitive:
|
||||
where_clause[field_name]["mode"] = "insensitive"
|
||||
|
||||
existing_user = await prisma_client.db.litellm_usertable.find_first(
|
||||
existing_user = await UserRepository(prisma_client).table.find_first(
|
||||
where=where_clause
|
||||
)
|
||||
|
||||
@@ -434,7 +445,7 @@ async def new_user(
|
||||
await _check_duplicate_user_email(data.user_email, prisma_client)
|
||||
|
||||
# Check if license is over limit
|
||||
total_users = await prisma_client.db.litellm_usertable.count()
|
||||
total_users = await UserRepository(prisma_client).table.count()
|
||||
if total_users and _license_check.is_over_limit(total_users=total_users):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
@@ -851,7 +862,7 @@ async def _check_user_info_v2_access(
|
||||
|
||||
# Helper: fetch the target user row (reused across branches)
|
||||
async def _fetch_target_user():
|
||||
return await prisma_client.db.litellm_usertable.find_unique(
|
||||
return await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": target_user_id}
|
||||
)
|
||||
|
||||
@@ -866,7 +877,7 @@ async def _check_user_info_v2_access(
|
||||
# Rule 3: Team admins can look up users in their teams
|
||||
if user_api_key_dict.user_id is not None:
|
||||
# Get caller's teams
|
||||
caller_user = await prisma_client.db.litellm_usertable.find_unique(
|
||||
caller_user = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": user_api_key_dict.user_id}
|
||||
)
|
||||
if caller_user is not None and caller_user.teams:
|
||||
@@ -876,7 +887,7 @@ async def _check_user_info_v2_access(
|
||||
return None
|
||||
|
||||
# Get all teams the caller belongs to
|
||||
teams = await prisma_client.db.litellm_teamtable.find_many(
|
||||
teams = await TeamRepository(prisma_client).table.find_many(
|
||||
where={"team_id": {"in": caller_user.teams}}
|
||||
)
|
||||
for team in teams:
|
||||
@@ -1165,7 +1176,7 @@ async def _schedule_user_update_audit_log(
|
||||
if prisma_client is None:
|
||||
return
|
||||
try:
|
||||
updated_user_row = await prisma_client.db.litellm_usertable.find_first(
|
||||
updated_user_row = await UserRepository(prisma_client).table.find_first(
|
||||
where={"user_id": response["user_id"]}
|
||||
)
|
||||
if updated_user_row:
|
||||
@@ -1255,11 +1266,11 @@ async def _update_single_user_helper(
|
||||
|
||||
existing_user_row: Optional[BaseModel] = None
|
||||
if user_request.user_id:
|
||||
existing_user_row = await prisma_client.db.litellm_usertable.find_first(
|
||||
existing_user_row = await UserRepository(prisma_client).table.find_first(
|
||||
where={"user_id": user_request.user_id}
|
||||
)
|
||||
elif user_request.user_email:
|
||||
existing_user_row = await prisma_client.db.litellm_usertable.find_first(
|
||||
existing_user_row = await UserRepository(prisma_client).table.find_first(
|
||||
where={"user_email": user_request.user_email}
|
||||
)
|
||||
|
||||
@@ -1640,7 +1651,7 @@ async def bulk_user_update(
|
||||
detail="Only proxy admins can update all users at once.",
|
||||
)
|
||||
# Optimized path for updating all users directly in database
|
||||
all_users_in_db = await prisma_client.db.litellm_usertable.find_many(
|
||||
all_users_in_db = await UserRepository(prisma_client).table.find_many(
|
||||
order={"created_at": "desc"}
|
||||
)
|
||||
|
||||
@@ -1676,7 +1687,7 @@ async def bulk_user_update(
|
||||
|
||||
try:
|
||||
# Perform bulk database update
|
||||
await prisma_client.db.litellm_usertable.update_many(
|
||||
await UserRepository(prisma_client).table.update_many(
|
||||
where={}, data=non_default_values # Update all users
|
||||
)
|
||||
|
||||
@@ -1783,7 +1794,7 @@ async def get_user_key_counts(
|
||||
|
||||
# Get count for each user_id individually
|
||||
for user_id in user_ids:
|
||||
count = await prisma_client.db.litellm_verificationtoken.count(
|
||||
count = await VerificationTokenRepository(prisma_client).table.count(
|
||||
where={
|
||||
"user_id": user_id,
|
||||
"OR": [
|
||||
@@ -2056,7 +2067,7 @@ async def get_users(
|
||||
else None
|
||||
)
|
||||
|
||||
users = await prisma_client.db.litellm_usertable.find_many(
|
||||
users = await UserRepository(prisma_client).table.find_many(
|
||||
where=where_conditions,
|
||||
skip=skip,
|
||||
take=page_size,
|
||||
@@ -2066,7 +2077,9 @@ async def get_users(
|
||||
)
|
||||
|
||||
# Get total count of user rows
|
||||
total_count = await prisma_client.db.litellm_usertable.count(where=where_conditions)
|
||||
total_count = await UserRepository(prisma_client).table.count(
|
||||
where=where_conditions
|
||||
)
|
||||
|
||||
# Get key count for each user
|
||||
if users is not None:
|
||||
@@ -2137,14 +2150,14 @@ async def delete_user(
|
||||
from litellm.proxy.management_endpoints.team_endpoints import (
|
||||
_cleanup_members_with_roles,
|
||||
)
|
||||
from litellm.proxy.management_helpers.audit_logs import (
|
||||
get_audit_log_changed_by,
|
||||
)
|
||||
from litellm.proxy.proxy_server import (
|
||||
create_audit_log_for_update,
|
||||
litellm_proxy_admin_name,
|
||||
prisma_client,
|
||||
)
|
||||
from litellm.proxy.management_helpers.audit_logs import (
|
||||
get_audit_log_changed_by,
|
||||
)
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail={"error": "No db connected"})
|
||||
@@ -2164,7 +2177,7 @@ async def delete_user(
|
||||
caller_admin_org_ids: set = set()
|
||||
if not caller_is_proxy_admin:
|
||||
caller_memberships = (
|
||||
await prisma_client.db.litellm_organizationmembership.find_many(
|
||||
await OrganizationMembershipRepository(prisma_client).table.find_many(
|
||||
where={
|
||||
"user_id": user_api_key_dict.user_id,
|
||||
"user_role": LitellmUserRoles.ORG_ADMIN.value,
|
||||
@@ -2188,11 +2201,9 @@ async def delete_user(
|
||||
# an N+1 DB call when delete_user is called with a large user_ids list.
|
||||
target_org_ids_by_user: Dict[str, set] = {}
|
||||
if not caller_is_proxy_admin:
|
||||
all_target_memberships = (
|
||||
await prisma_client.db.litellm_organizationmembership.find_many(
|
||||
where={"user_id": {"in": data.user_ids}}
|
||||
)
|
||||
)
|
||||
all_target_memberships = await OrganizationMembershipRepository(
|
||||
prisma_client
|
||||
).table.find_many(where={"user_id": {"in": data.user_ids}})
|
||||
for m in all_target_memberships:
|
||||
if not m.organization_id:
|
||||
continue
|
||||
@@ -2200,7 +2211,7 @@ async def delete_user(
|
||||
|
||||
# check that all teams passed exist
|
||||
for user_id in data.user_ids:
|
||||
user_row = await prisma_client.db.litellm_usertable.find_unique(
|
||||
user_row = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": user_id}
|
||||
)
|
||||
|
||||
@@ -2254,7 +2265,7 @@ async def delete_user(
|
||||
)
|
||||
|
||||
## CLEANUP MEMBERS_WITH_ROLES
|
||||
fetch_all_teams = await prisma_client.db.litellm_teamtable.find_many(
|
||||
fetch_all_teams = await TeamRepository(prisma_client).table.find_many(
|
||||
where={"team_id": {"in": user_row.teams}}
|
||||
)
|
||||
teams_to_update = []
|
||||
@@ -2277,19 +2288,19 @@ async def delete_user(
|
||||
## update teams
|
||||
|
||||
for team in teams_to_update:
|
||||
await prisma_client.db.litellm_teamtable.update(
|
||||
await TeamRepository(prisma_client).table.update(
|
||||
where={"team_id": team.team_id},
|
||||
data={"members_with_roles": team.members_with_roles},
|
||||
)
|
||||
# End of Audit logging
|
||||
|
||||
## DELETE ASSOCIATED KEYS
|
||||
await prisma_client.db.litellm_verificationtoken.delete_many(
|
||||
await VerificationTokenRepository(prisma_client).table.delete_many(
|
||||
where={"user_id": {"in": data.user_ids}}
|
||||
)
|
||||
|
||||
## DELETE ASSOCIATED INVITATION LINKS
|
||||
await prisma_client.db.litellm_invitationlink.delete_many(
|
||||
await InvitationLinkRepository(prisma_client).table.delete_many(
|
||||
where={
|
||||
"OR": [
|
||||
{"user_id": {"in": data.user_ids}},
|
||||
@@ -2300,17 +2311,17 @@ async def delete_user(
|
||||
)
|
||||
|
||||
## DELETE ASSOCIATED ORGANIZATION MEMBERSHIPS
|
||||
await prisma_client.db.litellm_organizationmembership.delete_many(
|
||||
await OrganizationMembershipRepository(prisma_client).table.delete_many(
|
||||
where={"user_id": {"in": data.user_ids}}
|
||||
)
|
||||
|
||||
## DELETE ASSOCIATED TEAM MEMBERSHIPS
|
||||
await prisma_client.db.litellm_teammembership.delete_many(
|
||||
await TeamMembershipRepository(prisma_client).table.delete_many(
|
||||
where={"user_id": {"in": data.user_ids}}
|
||||
)
|
||||
|
||||
## DELETE USERS
|
||||
deleted_users = await prisma_client.db.litellm_usertable.delete_many(
|
||||
deleted_users = await UserRepository(prisma_client).table.delete_many(
|
||||
where={"user_id": {"in": data.user_ids}}
|
||||
)
|
||||
|
||||
@@ -2340,16 +2351,18 @@ async def add_internal_user_to_organization(
|
||||
|
||||
try:
|
||||
# Check if organization_id exists
|
||||
organization_row = await prisma_client.db.litellm_organizationtable.find_unique(
|
||||
where={"organization_id": organization_id}
|
||||
)
|
||||
organization_row = await OrganizationRepository(
|
||||
prisma_client
|
||||
).table.find_unique(where={"organization_id": organization_id})
|
||||
if organization_row is None:
|
||||
raise Exception(
|
||||
f"Organization not found, passed organization_id={organization_id}"
|
||||
)
|
||||
|
||||
# Create a new organization membership entry
|
||||
new_membership = await prisma_client.db.litellm_organizationmembership.create(
|
||||
new_membership = await OrganizationMembershipRepository(
|
||||
prisma_client
|
||||
).table.create(
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"organization_id": organization_id,
|
||||
@@ -2559,13 +2572,13 @@ async def ui_view_users(
|
||||
}
|
||||
|
||||
# Query users with pagination and filters
|
||||
users: Optional[List[BaseModel]] = (
|
||||
await prisma_client.db.litellm_usertable.find_many(
|
||||
where=where_conditions,
|
||||
skip=skip,
|
||||
take=page_size,
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
users: Optional[List[BaseModel]] = await UserRepository(
|
||||
prisma_client
|
||||
).table.find_many(
|
||||
where=where_conditions,
|
||||
skip=skip,
|
||||
take=page_size,
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
|
||||
if not users:
|
||||
|
||||
@@ -11,6 +11,7 @@ from litellm.proxy._types import (
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
|
||||
from litellm.repositories.table_repositories import JWTKeyMappingRepository
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -61,7 +62,7 @@ async def create_jwt_key_mapping(
|
||||
if data.description is not None:
|
||||
create_data["description"] = data.description
|
||||
|
||||
new_mapping = await prisma_client.db.litellm_jwtkeymapping.create(
|
||||
new_mapping = await JWTKeyMappingRepository(prisma_client).table.create(
|
||||
data=create_data
|
||||
)
|
||||
|
||||
@@ -113,7 +114,7 @@ async def update_jwt_key_mapping(
|
||||
|
||||
try:
|
||||
# Get old mapping for cache invalidation
|
||||
old_mapping = await prisma_client.db.litellm_jwtkeymapping.find_unique(
|
||||
old_mapping = await JWTKeyMappingRepository(prisma_client).table.find_unique(
|
||||
where={"id": data.id}
|
||||
)
|
||||
|
||||
@@ -123,7 +124,7 @@ async def update_jwt_key_mapping(
|
||||
cache_key = f"jwt_key_mapping:{old_mapping.jwt_claim_name}:{old_mapping.jwt_claim_value}"
|
||||
await user_api_key_cache.async_delete_cache(cache_key)
|
||||
|
||||
updated_mapping = await prisma_client.db.litellm_jwtkeymapping.update(
|
||||
updated_mapping = await JWTKeyMappingRepository(prisma_client).table.update(
|
||||
where={"id": data.id}, data=update_data
|
||||
)
|
||||
|
||||
@@ -166,7 +167,7 @@ async def delete_jwt_key_mapping(
|
||||
|
||||
try:
|
||||
# Get old mapping for cache invalidation
|
||||
old_mapping = await prisma_client.db.litellm_jwtkeymapping.find_unique(
|
||||
old_mapping = await JWTKeyMappingRepository(prisma_client).table.find_unique(
|
||||
where={"id": data.id}
|
||||
)
|
||||
|
||||
@@ -176,7 +177,7 @@ async def delete_jwt_key_mapping(
|
||||
cache_key = f"jwt_key_mapping:{old_mapping.jwt_claim_name}:{old_mapping.jwt_claim_value}"
|
||||
await user_api_key_cache.async_delete_cache(cache_key)
|
||||
|
||||
await prisma_client.db.litellm_jwtkeymapping.delete(where={"id": data.id})
|
||||
await JWTKeyMappingRepository(prisma_client).table.delete(where={"id": data.id})
|
||||
return {"status": "success"}
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -206,12 +207,12 @@ async def list_jwt_key_mappings(
|
||||
|
||||
try:
|
||||
skip = (page - 1) * size
|
||||
mappings = await prisma_client.db.litellm_jwtkeymapping.find_many(
|
||||
mappings = await JWTKeyMappingRepository(prisma_client).table.find_many(
|
||||
skip=skip,
|
||||
take=size,
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
total_count = await prisma_client.db.litellm_jwtkeymapping.count()
|
||||
total_count = await JWTKeyMappingRepository(prisma_client).table.count()
|
||||
return {
|
||||
"mappings": [_to_response(m) for m in mappings],
|
||||
"total_count": total_count,
|
||||
@@ -245,7 +246,7 @@ async def info_jwt_key_mapping(
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
mapping = await prisma_client.db.litellm_jwtkeymapping.find_unique(
|
||||
mapping = await JWTKeyMappingRepository(prisma_client).table.find_unique(
|
||||
where={"id": id}
|
||||
)
|
||||
if mapping is None:
|
||||
|
||||
@@ -28,7 +28,6 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, s
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
|
||||
from litellm.constants import (
|
||||
LENGTH_OF_LITELLM_GENERATED_KEY,
|
||||
LITELLM_PROXY_ADMIN_NAME,
|
||||
@@ -58,6 +57,7 @@ from litellm.proxy.common_utils.callback_utils import (
|
||||
)
|
||||
from litellm.proxy.common_utils.rbac_utils import check_org_admin_can_generate_keys
|
||||
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
|
||||
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
|
||||
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
|
||||
from litellm.proxy.management_endpoints.common_utils import (
|
||||
_check_passthrough_routes_caller_permission,
|
||||
@@ -91,6 +91,19 @@ from litellm.proxy.utils import (
|
||||
handle_exception_on_proxy,
|
||||
is_valid_api_key,
|
||||
)
|
||||
from litellm.repositories.budget_repository import BudgetRepository
|
||||
from litellm.repositories.config_repository import ConfigRepository
|
||||
from litellm.repositories.credentials_repository import CredentialsRepository
|
||||
from litellm.repositories.model_repository import ModelRepository
|
||||
from litellm.repositories.table_repositories import (
|
||||
DeletedVerificationTokenRepository,
|
||||
DeprecatedVerificationTokenRepository,
|
||||
)
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.repositories.user_repository import UserRepository
|
||||
from litellm.repositories.verification_token_repository import (
|
||||
VerificationTokenRepository,
|
||||
)
|
||||
from litellm.router import Router
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.proxy.management_endpoints.key_management_endpoints import (
|
||||
@@ -582,7 +595,7 @@ async def validate_team_id_used_in_service_account_request(
|
||||
)
|
||||
|
||||
# check if team_id exists in the database
|
||||
team = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
team = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": team_id},
|
||||
)
|
||||
if team is None:
|
||||
@@ -774,7 +787,7 @@ async def _common_key_generation_helper( # noqa: PLR0915
|
||||
)
|
||||
new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True))
|
||||
|
||||
_budget = await prisma_client.db.litellm_budgettable.create(
|
||||
_budget = await BudgetRepository(prisma_client).table.create(
|
||||
data={
|
||||
**new_budget, # type: ignore
|
||||
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
@@ -1144,7 +1157,7 @@ async def _check_team_key_limits(
|
||||
# calculate allocated tpm/rpm limit
|
||||
# check if specified tpm/rpm limit is greater than allocated tpm/rpm limit
|
||||
|
||||
keys = await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
keys = await VerificationTokenRepository(prisma_client).table.find_many(
|
||||
where={"team_id": team_table.team_id},
|
||||
)
|
||||
# Exclude the key being updated to avoid double-counting its limits.
|
||||
@@ -1294,7 +1307,7 @@ async def _validate_caller_can_assign_key_org(
|
||||
detail="Cannot assign a key to an organization without a user_id on the caller's token",
|
||||
)
|
||||
|
||||
user_row = await prisma_client.db.litellm_usertable.find_unique(
|
||||
user_row = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": user_api_key_dict.user_id},
|
||||
include={"organization_memberships": True},
|
||||
)
|
||||
@@ -1339,7 +1352,7 @@ async def _check_org_key_limits(
|
||||
# get all organization keys
|
||||
# calculate allocated tpm/rpm limit
|
||||
# check if specified tpm/rpm limit is greater than allocated tpm/rpm limit
|
||||
keys = await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
keys = await VerificationTokenRepository(prisma_client).table.find_many(
|
||||
where={"organization_id": org_table.organization_id},
|
||||
)
|
||||
# Exclude the key being updated to avoid double-counting its limits.
|
||||
@@ -1968,9 +1981,9 @@ async def _get_and_validate_existing_key(
|
||||
|
||||
hashed_token = _hash_token_if_needed(token=token)
|
||||
|
||||
existing_key_row = await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||
where={"token": hashed_token}
|
||||
)
|
||||
existing_key_row = await VerificationTokenRepository(
|
||||
prisma_client
|
||||
).table.find_unique(where={"token": hashed_token})
|
||||
|
||||
if existing_key_row is None:
|
||||
raise ProxyException(
|
||||
@@ -2869,7 +2882,9 @@ async def bulk_update_team_keys(
|
||||
# `blocked` is Boolean? with no default; `/key/generate` writes NULL. Prisma's `NOT`
|
||||
# excludes NULLs, so explicitly OR `false` with `null` to include them.
|
||||
now = datetime.now(timezone.utc)
|
||||
existing_keys = await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
existing_keys = await VerificationTokenRepository(
|
||||
prisma_client
|
||||
).table.find_many(
|
||||
where={
|
||||
"team_id": data.team_id,
|
||||
"AND": [
|
||||
@@ -2907,7 +2922,9 @@ async def bulk_update_team_keys(
|
||||
seen_hashes.add(h)
|
||||
requested_tokens.append(k)
|
||||
hashed_key_ids.append(h)
|
||||
existing_keys = await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
existing_keys = await VerificationTokenRepository(
|
||||
prisma_client
|
||||
).table.find_many(
|
||||
where={"team_id": data.team_id, "token": {"in": hashed_key_ids}}
|
||||
)
|
||||
|
||||
@@ -3232,7 +3249,9 @@ async def info_key_fn_v2(
|
||||
# Resolve key_aliases to tokens so we never pass token=None (unbounded query)
|
||||
tokens_to_query = list(data.keys) if data.keys else []
|
||||
if data.key_aliases:
|
||||
alias_rows = await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
alias_rows = await VerificationTokenRepository(
|
||||
prisma_client
|
||||
).table.find_many(
|
||||
where={"key_alias": {"in": data.key_aliases}},
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
@@ -3311,7 +3330,7 @@ async def info_key_fn(
|
||||
hashed_key: Optional[str] = key
|
||||
if key is not None:
|
||||
hashed_key = _hash_token_if_needed(token=key)
|
||||
key_info = await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||
key_info = await VerificationTokenRepository(prisma_client).table.find_unique(
|
||||
where={"token": hashed_key}, # type: ignore
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
@@ -3851,7 +3870,7 @@ async def delete_verification_tokens(
|
||||
if prisma_client:
|
||||
tokens = [_hash_token_if_needed(token=key) for key in tokens]
|
||||
_keys_being_deleted: List[LiteLLM_VerificationToken] = (
|
||||
await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
await VerificationTokenRepository(prisma_client).table.find_many(
|
||||
where={"token": {"in": tokens}}
|
||||
)
|
||||
)
|
||||
@@ -3989,7 +4008,9 @@ async def _save_deleted_verification_token_records(
|
||||
"""Save deleted verification token records to the database."""
|
||||
if not records:
|
||||
return
|
||||
await prisma_client.db.litellm_deletedverificationtoken.create_many(data=records)
|
||||
await DeletedVerificationTokenRepository(prisma_client).table.create_many(
|
||||
data=records
|
||||
)
|
||||
|
||||
|
||||
async def _persist_deleted_verification_tokens(
|
||||
@@ -4017,9 +4038,9 @@ async def delete_key_aliases(
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
litellm_changed_by: Optional[str] = None,
|
||||
) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
|
||||
_keys_being_deleted = await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
where={"key_alias": {"in": key_aliases}}
|
||||
)
|
||||
_keys_being_deleted = await VerificationTokenRepository(
|
||||
prisma_client
|
||||
).table.find_many(where={"key_alias": {"in": key_aliases}})
|
||||
|
||||
tokens = [key.token for key in _keys_being_deleted]
|
||||
return await delete_verification_tokens(
|
||||
@@ -4054,9 +4075,7 @@ async def _rotate_master_key( # noqa: PLR0915
|
||||
from litellm.proxy.proxy_server import proxy_config
|
||||
|
||||
try:
|
||||
models: Optional[List] = (
|
||||
await prisma_client.db.litellm_proxymodeltable.find_many()
|
||||
)
|
||||
models: Optional[List] = await ModelRepository(prisma_client).table.find_many()
|
||||
except Exception:
|
||||
models = None
|
||||
# 2. process model table
|
||||
@@ -4088,7 +4107,7 @@ async def _rotate_master_key( # noqa: PLR0915
|
||||
)
|
||||
# 3. process config table
|
||||
try:
|
||||
config = await prisma_client.db.litellm_config.find_many()
|
||||
config = await ConfigRepository(prisma_client).table.find_many()
|
||||
except Exception:
|
||||
config = None
|
||||
|
||||
@@ -4109,7 +4128,7 @@ async def _rotate_master_key( # noqa: PLR0915
|
||||
)
|
||||
|
||||
if encrypted_env_vars:
|
||||
await prisma_client.db.litellm_config.update(
|
||||
await ConfigRepository(prisma_client).table.update(
|
||||
where={"param_name": "environment_variables"},
|
||||
data={"param_value": prisma.Json(encrypted_env_vars)}, # type: ignore[attr-defined]
|
||||
)
|
||||
@@ -4148,7 +4167,7 @@ async def _rotate_master_key( # noqa: PLR0915
|
||||
|
||||
# 5. process credentials table
|
||||
try:
|
||||
credentials = await prisma_client.db.litellm_credentialstable.find_many()
|
||||
credentials = await CredentialsRepository(prisma_client).table.find_many()
|
||||
except Exception:
|
||||
credentials = None
|
||||
if credentials:
|
||||
@@ -4171,7 +4190,7 @@ async def _rotate_master_key( # noqa: PLR0915
|
||||
_cred_data["credential_info"] = prisma.Json( # type: ignore[attr-defined]
|
||||
_cred_data["credential_info"]
|
||||
)
|
||||
await prisma_client.db.litellm_credentialstable.update(
|
||||
await CredentialsRepository(prisma_client).table.update(
|
||||
where={"credential_name": cred.credential_name},
|
||||
data={
|
||||
**_cred_data,
|
||||
@@ -4243,7 +4262,7 @@ async def _insert_deprecated_key(
|
||||
|
||||
try:
|
||||
revoke_at = datetime.now(timezone.utc) + timedelta(seconds=grace_seconds)
|
||||
await prisma_client.db.litellm_deprecatedverificationtoken.upsert(
|
||||
await DeprecatedVerificationTokenRepository(prisma_client).table.upsert(
|
||||
where={"token": old_token_hash},
|
||||
data={
|
||||
"create": {
|
||||
@@ -4335,7 +4354,7 @@ async def _execute_virtual_key_regeneration(
|
||||
grace_period=data.grace_period if data else None,
|
||||
)
|
||||
|
||||
updated_token = await prisma_client.db.litellm_verificationtoken.update(
|
||||
updated_token = await VerificationTokenRepository(prisma_client).table.update(
|
||||
where={"token": hashed_api_key},
|
||||
data=update_data, # type: ignore
|
||||
)
|
||||
@@ -4530,7 +4549,7 @@ async def regenerate_key_fn( # noqa: PLR0915
|
||||
else:
|
||||
hashed_api_key = hash_token(key)
|
||||
|
||||
_key_in_db = await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||
_key_in_db = await VerificationTokenRepository(prisma_client).table.find_unique(
|
||||
where={"token": hashed_api_key},
|
||||
)
|
||||
if _key_in_db is None:
|
||||
@@ -4719,7 +4738,7 @@ async def reset_key_spend_fn(
|
||||
else:
|
||||
hashed_api_key = hash_token(key)
|
||||
|
||||
_key_in_db = await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||
_key_in_db = await VerificationTokenRepository(prisma_client).table.find_unique(
|
||||
where={"token": hashed_api_key},
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
@@ -4739,7 +4758,7 @@ async def reset_key_spend_fn(
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
|
||||
updated_key = await prisma_client.db.litellm_verificationtoken.update(
|
||||
updated_key = await VerificationTokenRepository(prisma_client).table.update(
|
||||
where={"token": hashed_api_key},
|
||||
data={"spend": reset_to},
|
||||
)
|
||||
@@ -4792,11 +4811,11 @@ async def validate_key_list_check(
|
||||
param="user_id",
|
||||
code=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
complete_user_info_db_obj: Optional[BaseModel] = (
|
||||
await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_api_key_dict.user_id},
|
||||
include={"organization_memberships": True},
|
||||
)
|
||||
complete_user_info_db_obj: Optional[BaseModel] = await UserRepository(
|
||||
prisma_client
|
||||
).table.find_unique(
|
||||
where={"user_id": user_api_key_dict.user_id},
|
||||
include={"organization_memberships": True},
|
||||
)
|
||||
|
||||
if complete_user_info_db_obj is None:
|
||||
@@ -4846,7 +4865,9 @@ async def validate_key_list_check(
|
||||
|
||||
if key_hash:
|
||||
try:
|
||||
key_info = await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||
key_info = await VerificationTokenRepository(
|
||||
prisma_client
|
||||
).table.find_unique(
|
||||
where={"token": key_hash},
|
||||
)
|
||||
except Exception:
|
||||
@@ -4879,11 +4900,9 @@ async def _fetch_user_team_objects(
|
||||
if complete_user_info is None or not complete_user_info.teams:
|
||||
return []
|
||||
|
||||
teams: Optional[List[BaseModel]] = (
|
||||
await prisma_client.db.litellm_teamtable.find_many(
|
||||
where={"team_id": {"in": complete_user_info.teams}}
|
||||
)
|
||||
)
|
||||
teams: Optional[List[BaseModel]] = await TeamRepository(
|
||||
prisma_client
|
||||
).table.find_many(where={"team_id": {"in": complete_user_info.teams}})
|
||||
if teams is None:
|
||||
return []
|
||||
|
||||
@@ -5160,7 +5179,7 @@ async def _apply_non_admin_alias_scope(
|
||||
# Look up the user's teams from the user table
|
||||
user_teams: List[str] = []
|
||||
if user_api_key_dict.user_id:
|
||||
user_row = await prisma_client.db.litellm_usertable.find_unique(
|
||||
user_row = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": user_api_key_dict.user_id}
|
||||
)
|
||||
if user_row is not None:
|
||||
@@ -5548,7 +5567,7 @@ async def _list_key_helper(
|
||||
|
||||
# Fetch keys with pagination
|
||||
if use_deleted_table:
|
||||
keys = await prisma_client.db.litellm_deletedverificationtoken.find_many(
|
||||
keys = await DeletedVerificationTokenRepository(prisma_client).table.find_many(
|
||||
where=where, # type: ignore
|
||||
skip=skip, # type: ignore
|
||||
take=size, # type: ignore
|
||||
@@ -5562,7 +5581,7 @@ async def _list_key_helper(
|
||||
),
|
||||
)
|
||||
else:
|
||||
keys = await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
keys = await VerificationTokenRepository(prisma_client).table.find_many(
|
||||
where=where, # type: ignore
|
||||
skip=skip, # type: ignore
|
||||
take=size, # type: ignore
|
||||
@@ -5581,11 +5600,13 @@ async def _list_key_helper(
|
||||
|
||||
# Get total count of keys
|
||||
if use_deleted_table:
|
||||
total_count = await prisma_client.db.litellm_deletedverificationtoken.count(
|
||||
total_count = await DeletedVerificationTokenRepository(
|
||||
prisma_client
|
||||
).table.count(
|
||||
where=where # type: ignore
|
||||
)
|
||||
else:
|
||||
total_count = await prisma_client.db.litellm_verificationtoken.count(
|
||||
total_count = await VerificationTokenRepository(prisma_client).table.count(
|
||||
where=where # type: ignore
|
||||
)
|
||||
|
||||
@@ -5601,7 +5622,7 @@ async def _list_key_helper(
|
||||
created_by_ids = [key.created_by for key in keys if key.created_by]
|
||||
all_ids = list(set(user_ids + created_by_ids)) # Remove duplicates
|
||||
if all_ids:
|
||||
users = await prisma_client.db.litellm_usertable.find_many(
|
||||
users = await UserRepository(prisma_client).table.find_many(
|
||||
where={"user_id": {"in": all_ids}}
|
||||
)
|
||||
user_map = {user.user_id: user for user in users}
|
||||
@@ -5688,7 +5709,7 @@ async def _check_key_admin_access(
|
||||
return
|
||||
|
||||
# Look up the target key to find its team
|
||||
target_key_row = await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||
target_key_row = await VerificationTokenRepository(prisma_client).table.find_unique(
|
||||
where={"token": hashed_token}
|
||||
)
|
||||
if target_key_row is None:
|
||||
@@ -5755,6 +5776,9 @@ async def block_key(
|
||||
|
||||
Note: This is an admin-only endpoint. Only proxy admins, team admins, or org admins can block keys.
|
||||
"""
|
||||
from litellm.proxy.management_helpers.audit_logs import (
|
||||
get_audit_log_changed_by,
|
||||
)
|
||||
from litellm.proxy.proxy_server import (
|
||||
create_audit_log_for_update,
|
||||
hash_token,
|
||||
@@ -5763,9 +5787,6 @@ async def block_key(
|
||||
proxy_logging_obj,
|
||||
user_api_key_cache,
|
||||
)
|
||||
from litellm.proxy.management_helpers.audit_logs import (
|
||||
get_audit_log_changed_by,
|
||||
)
|
||||
|
||||
if prisma_client is None:
|
||||
raise Exception("{}".format(CommonProxyErrors.db_not_connected_error.value))
|
||||
@@ -5792,9 +5813,9 @@ async def block_key(
|
||||
)
|
||||
|
||||
# Check if the key exists before trying to block it
|
||||
existing_record = await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||
where={"token": hashed_token}
|
||||
)
|
||||
existing_record = await VerificationTokenRepository(
|
||||
prisma_client
|
||||
).table.find_unique(where={"token": hashed_token})
|
||||
if existing_record is None:
|
||||
raise ProxyException(
|
||||
message="Key not found.",
|
||||
@@ -5824,7 +5845,7 @@ async def block_key(
|
||||
)
|
||||
)
|
||||
|
||||
record = await prisma_client.db.litellm_verificationtoken.update(
|
||||
record = await VerificationTokenRepository(prisma_client).table.update(
|
||||
where={"token": hashed_token}, data={"blocked": True} # type: ignore
|
||||
)
|
||||
|
||||
@@ -5869,6 +5890,9 @@ async def unblock_key(
|
||||
|
||||
Note: This is an admin-only endpoint. Only proxy admins, team admins, or org admins can unblock keys.
|
||||
"""
|
||||
from litellm.proxy.management_helpers.audit_logs import (
|
||||
get_audit_log_changed_by,
|
||||
)
|
||||
from litellm.proxy.proxy_server import (
|
||||
create_audit_log_for_update,
|
||||
hash_token,
|
||||
@@ -5877,9 +5901,6 @@ async def unblock_key(
|
||||
proxy_logging_obj,
|
||||
user_api_key_cache,
|
||||
)
|
||||
from litellm.proxy.management_helpers.audit_logs import (
|
||||
get_audit_log_changed_by,
|
||||
)
|
||||
|
||||
if prisma_client is None:
|
||||
raise Exception("{}".format(CommonProxyErrors.db_not_connected_error.value))
|
||||
@@ -5906,9 +5927,9 @@ async def unblock_key(
|
||||
)
|
||||
|
||||
# Check if the key exists before trying to unblock it
|
||||
existing_record = await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||
where={"token": hashed_token}
|
||||
)
|
||||
existing_record = await VerificationTokenRepository(
|
||||
prisma_client
|
||||
).table.find_unique(where={"token": hashed_token})
|
||||
if existing_record is None:
|
||||
raise ProxyException(
|
||||
message="Key not found.",
|
||||
@@ -5938,7 +5959,7 @@ async def unblock_key(
|
||||
)
|
||||
)
|
||||
|
||||
record = await prisma_client.db.litellm_verificationtoken.update(
|
||||
record = await VerificationTokenRepository(prisma_client).table.update(
|
||||
where={"token": hashed_token}, data={"blocked": False} # type: ignore
|
||||
)
|
||||
|
||||
@@ -6202,9 +6223,9 @@ async def _enforce_unique_key_alias(
|
||||
# Exclude the current key from the uniqueness check
|
||||
where_clause["NOT"] = {"token": existing_key_token}
|
||||
|
||||
existing_key = await prisma_client.db.litellm_verificationtoken.find_first(
|
||||
where=where_clause
|
||||
)
|
||||
existing_key = await VerificationTokenRepository(
|
||||
prisma_client
|
||||
).table.find_first(where=where_clause)
|
||||
if existing_key is not None:
|
||||
raise ProxyException(
|
||||
message=f"Key with alias '{key_alias}' already exists. Unique key aliases across all keys are required.",
|
||||
|
||||
@@ -60,6 +60,10 @@ from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||
encrypt_value_helper,
|
||||
)
|
||||
from litellm.proxy.management_helpers.audit_logs import get_audit_log_changed_by
|
||||
from litellm.repositories.table_repositories import (
|
||||
MCPServerRepository,
|
||||
MCPUserCredentialsRepository,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/v1/mcp", tags=["mcp"])
|
||||
|
||||
@@ -761,7 +765,7 @@ if MCP_AVAILABLE:
|
||||
# Get from DB
|
||||
if prisma_client is not None:
|
||||
try:
|
||||
mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many()
|
||||
mcp_servers = await MCPServerRepository(prisma_client).table.find_many()
|
||||
for server in mcp_servers:
|
||||
if (
|
||||
hasattr(server, "mcp_access_groups")
|
||||
@@ -998,10 +1002,10 @@ if MCP_AVAILABLE:
|
||||
if getattr(s, "is_byok", False)
|
||||
]
|
||||
if byok_server_ids:
|
||||
cred_rows = (
|
||||
await _byok_prisma_client.db.litellm_mcpusercredentials.find_many(
|
||||
where={"user_id": user_id, "server_id": {"in": byok_server_ids}}
|
||||
)
|
||||
cred_rows = await MCPUserCredentialsRepository(
|
||||
_byok_prisma_client
|
||||
).table.find_many(
|
||||
where={"user_id": user_id, "server_id": {"in": byok_server_ids}}
|
||||
)
|
||||
cred_set = {r.server_id for r in cred_rows}
|
||||
for server in redacted_mcp_servers:
|
||||
|
||||
@@ -19,6 +19,7 @@ from litellm.proxy.management_endpoints.model_management_endpoints import (
|
||||
clear_cache,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.repositories.model_repository import ModelRepository
|
||||
from litellm.types.proxy.management_endpoints.model_management_endpoints import (
|
||||
AccessGroupInfo,
|
||||
DeleteModelGroupResponse,
|
||||
@@ -95,7 +96,7 @@ async def update_deployments_with_access_group(
|
||||
verbose_proxy_logger.debug(f"Updating deployments for model_name: {model_name}")
|
||||
|
||||
# Get all deployments with this model_name
|
||||
deployments = await prisma_client.db.litellm_proxymodeltable.find_many(
|
||||
deployments = await ModelRepository(prisma_client).table.find_many(
|
||||
where={"model_name": model_name}
|
||||
)
|
||||
|
||||
@@ -124,7 +125,7 @@ async def update_deployments_with_access_group(
|
||||
|
||||
# Only update in DB if modified
|
||||
if was_modified:
|
||||
await prisma_client.db.litellm_proxymodeltable.update(
|
||||
await ModelRepository(prisma_client).table.update(
|
||||
where={"model_id": deployment.model_id},
|
||||
data={"model_info": json.dumps(updated_model_info)},
|
||||
)
|
||||
@@ -152,7 +153,7 @@ async def update_specific_deployments_with_access_group(
|
||||
models_updated = 0
|
||||
for model_id in model_ids:
|
||||
verbose_proxy_logger.debug(f"Updating specific deployment model_id: {model_id}")
|
||||
deployment = await prisma_client.db.litellm_proxymodeltable.find_unique(
|
||||
deployment = await ModelRepository(prisma_client).table.find_unique(
|
||||
where={"model_id": model_id}
|
||||
)
|
||||
if deployment is None:
|
||||
@@ -168,7 +169,7 @@ async def update_specific_deployments_with_access_group(
|
||||
access_group=access_group,
|
||||
)
|
||||
if was_modified:
|
||||
await prisma_client.db.litellm_proxymodeltable.update(
|
||||
await ModelRepository(prisma_client).table.update(
|
||||
where={"model_id": model_id},
|
||||
data={"model_info": json.dumps(updated_model_info)},
|
||||
)
|
||||
@@ -215,7 +216,7 @@ async def get_all_access_groups_from_db(
|
||||
Dict[str, AccessGroupInfo]: Dictionary mapping access_group name to info
|
||||
"""
|
||||
# Get all deployments
|
||||
deployments = await prisma_client.db.litellm_proxymodeltable.find_many()
|
||||
deployments = await ModelRepository(prisma_client).table.find_many()
|
||||
|
||||
# Build access group map
|
||||
access_group_map: Dict[str, Dict[str, Any]] = {}
|
||||
@@ -604,7 +605,7 @@ async def update_access_group(
|
||||
|
||||
try:
|
||||
# Step 1: Remove access group from ALL DB deployments (skip config models)
|
||||
all_deployments = await prisma_client.db.litellm_proxymodeltable.find_many()
|
||||
all_deployments = await ModelRepository(prisma_client).table.find_many()
|
||||
|
||||
for deployment in all_deployments:
|
||||
model_info = deployment.model_info or {}
|
||||
@@ -615,7 +616,7 @@ async def update_access_group(
|
||||
)
|
||||
|
||||
if was_modified:
|
||||
await prisma_client.db.litellm_proxymodeltable.update(
|
||||
await ModelRepository(prisma_client).table.update(
|
||||
where={"model_id": deployment.model_id},
|
||||
data={"model_info": json.dumps(updated_model_info)},
|
||||
)
|
||||
@@ -722,7 +723,7 @@ async def delete_access_group(
|
||||
|
||||
try:
|
||||
# Remove access group from all DB deployments (skip config models)
|
||||
all_deployments = await prisma_client.db.litellm_proxymodeltable.find_many()
|
||||
all_deployments = await ModelRepository(prisma_client).table.find_many()
|
||||
models_updated = 0
|
||||
|
||||
for deployment in all_deployments:
|
||||
@@ -734,7 +735,7 @@ async def delete_access_group(
|
||||
)
|
||||
|
||||
if was_modified:
|
||||
await prisma_client.db.litellm_proxymodeltable.update(
|
||||
await ModelRepository(prisma_client).table.update(
|
||||
where={"model_id": deployment.model_id},
|
||||
data={"model_info": json.dumps(updated_model_info)},
|
||||
)
|
||||
|
||||
@@ -48,6 +48,9 @@ from litellm.proxy.management_endpoints.team_endpoints import (
|
||||
)
|
||||
from litellm.proxy.management_helpers.audit_logs import create_object_audit_log
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.repositories.model_repository import ModelRepository
|
||||
from litellm.repositories.table_repositories import ModelTableRepository
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.types.proxy.management_endpoints.model_management_endpoints import (
|
||||
UpdateUsefulLinksRequest,
|
||||
)
|
||||
@@ -86,7 +89,7 @@ async def get_db_model(
|
||||
) -> Optional[Deployment]:
|
||||
db_model = cast(
|
||||
Optional[BaseModel],
|
||||
await prisma_client.db.litellm_proxymodeltable.find_unique(
|
||||
await ModelRepository(prisma_client).table.find_unique(
|
||||
where={"model_id": model_id}
|
||||
),
|
||||
)
|
||||
@@ -290,7 +293,7 @@ async def patch_model(
|
||||
update_data["updated_at"] = cast(str, get_utc_datetime())
|
||||
|
||||
# Perform partial update
|
||||
updated_model = await prisma_client.db.litellm_proxymodeltable.update(
|
||||
updated_model = await ModelRepository(prisma_client).table.update(
|
||||
where={"model_id": model_id},
|
||||
data=update_data,
|
||||
)
|
||||
@@ -362,7 +365,7 @@ async def _add_model_to_db(
|
||||
if model_params.model_info.id is not None:
|
||||
_data["model_id"] = model_params.model_info.id
|
||||
if should_create_model_in_db:
|
||||
model_response = await prisma_client.db.litellm_proxymodeltable.create(
|
||||
model_response = await ModelRepository(prisma_client).table.create(
|
||||
data=_data # type: ignore
|
||||
)
|
||||
else:
|
||||
@@ -571,7 +574,7 @@ async def _get_team_deployments(
|
||||
team_id in model_info with Python-side filtering.
|
||||
"""
|
||||
prefix = f"model_name_{team_id}_"
|
||||
response = await prisma_client.db.litellm_proxymodeltable.find_many(
|
||||
response = await ModelRepository(prisma_client).table.find_many(
|
||||
where={
|
||||
"model_name": {"startswith": prefix},
|
||||
}
|
||||
@@ -828,7 +831,7 @@ class ModelManagementAuthChecks:
|
||||
detail={"error": CommonProxyErrors.not_premium_user.value},
|
||||
)
|
||||
|
||||
_existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
_existing_team_row = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": model_params.model_info.team_id}
|
||||
)
|
||||
|
||||
@@ -863,7 +866,7 @@ class ModelManagementAuthChecks:
|
||||
model_params.model_info is not None
|
||||
and model_params.model_info.team_id is not None
|
||||
):
|
||||
team_obj_row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
team_obj_row = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": model_params.model_info.team_id}
|
||||
)
|
||||
if team_obj_row is None:
|
||||
@@ -937,7 +940,7 @@ async def delete_model(
|
||||
},
|
||||
)
|
||||
|
||||
model_in_db = await prisma_client.db.litellm_proxymodeltable.find_unique(
|
||||
model_in_db = await ModelRepository(prisma_client).table.find_unique(
|
||||
where={"model_id": model_info.id}
|
||||
)
|
||||
if model_in_db is None:
|
||||
@@ -961,7 +964,7 @@ async def delete_model(
|
||||
- store keys separately
|
||||
"""
|
||||
# encrypt litellm params #
|
||||
result = await prisma_client.db.litellm_proxymodeltable.delete(
|
||||
result = await ModelRepository(prisma_client).table.delete(
|
||||
where={"model_id": model_info.id}
|
||||
)
|
||||
|
||||
@@ -1039,7 +1042,7 @@ async def delete_team_model_alias(
|
||||
Returns:
|
||||
- List of team id + model alias pairs that were removed
|
||||
"""
|
||||
team_model_aliases = await prisma_client.db.litellm_modeltable.find_many(
|
||||
team_model_aliases = await ModelTableRepository(prisma_client).table.find_many(
|
||||
include={"team": True}
|
||||
)
|
||||
tasks = []
|
||||
@@ -1056,7 +1059,7 @@ async def delete_team_model_alias(
|
||||
removed_model_aliases.append((team_model_alias.team.team_id, key))
|
||||
del model_aliases[key]
|
||||
tasks.append(
|
||||
prisma_client.db.litellm_modeltable.update(
|
||||
ModelTableRepository(prisma_client).table.update(
|
||||
where={"id": id},
|
||||
data={"model_aliases": json.dumps(model_aliases)},
|
||||
)
|
||||
@@ -1275,11 +1278,9 @@ async def update_model(
|
||||
if _model_id is None:
|
||||
raise Exception("model_info.id not provided")
|
||||
|
||||
_existing_litellm_params = (
|
||||
await prisma_client.db.litellm_proxymodeltable.find_unique(
|
||||
where={"model_id": _model_id}
|
||||
)
|
||||
)
|
||||
_existing_litellm_params = await ModelRepository(
|
||||
prisma_client
|
||||
).table.find_unique(where={"model_id": _model_id})
|
||||
|
||||
if _existing_litellm_params is None:
|
||||
if (
|
||||
@@ -1340,7 +1341,7 @@ async def update_model(
|
||||
"litellm_params": json.dumps(merged_dictionary), # type: ignore
|
||||
"updated_by": user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME,
|
||||
}
|
||||
model_response = await prisma_client.db.litellm_proxymodeltable.update(
|
||||
model_response = await ModelRepository(prisma_client).table.update(
|
||||
where={"model_id": _model_id},
|
||||
data=_data, # type: ignore
|
||||
)
|
||||
|
||||
@@ -40,6 +40,15 @@ from litellm.proxy.management_helpers.utils import (
|
||||
management_endpoint_wrapper,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.repositories.budget_repository import BudgetRepository
|
||||
from litellm.repositories.object_permission_repository import ObjectPermissionRepository
|
||||
from litellm.repositories.organization_repository import OrganizationRepository
|
||||
from litellm.repositories.table_repositories import OrganizationMembershipRepository
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.repositories.user_repository import UserRepository
|
||||
from litellm.repositories.verification_token_repository import (
|
||||
VerificationTokenRepository,
|
||||
)
|
||||
from litellm.types.proxy.management_endpoints.common_daily_activity import (
|
||||
SpendAnalyticsPaginatedResponse,
|
||||
)
|
||||
@@ -245,7 +254,7 @@ async def new_organization(
|
||||
|
||||
if user_api_key_dict.user_id is not None:
|
||||
try:
|
||||
user_object = await prisma_client.db.litellm_usertable.find_unique(
|
||||
user_object = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": user_api_key_dict.user_id}
|
||||
)
|
||||
user_object_correct_type = LiteLLM_UserTable(**user_object.model_dump())
|
||||
@@ -267,7 +276,7 @@ async def new_organization(
|
||||
|
||||
new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True))
|
||||
|
||||
_budget = await prisma_client.db.litellm_budgettable.create(
|
||||
_budget = await BudgetRepository(prisma_client).table.create(
|
||||
data={
|
||||
**new_budget, # type: ignore
|
||||
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
@@ -323,7 +332,7 @@ async def new_organization(
|
||||
verbose_proxy_logger.info(
|
||||
f"new_organization_row: {json.dumps(new_organization_row, indent=2)}"
|
||||
)
|
||||
response = await prisma_client.db.litellm_organizationtable.create(
|
||||
response = await OrganizationRepository(prisma_client).table.create(
|
||||
data={
|
||||
**new_organization_row, # type: ignore
|
||||
},
|
||||
@@ -372,9 +381,9 @@ async def get_organization_daily_activity(
|
||||
|
||||
# Restrict non-proxy-admins to only organizations where they are org_admin
|
||||
if not _user_has_admin_view(user_api_key_dict):
|
||||
memberships = await prisma_client.db.litellm_organizationmembership.find_many(
|
||||
where={"user_id": user_api_key_dict.user_id}
|
||||
)
|
||||
memberships = await OrganizationMembershipRepository(
|
||||
prisma_client
|
||||
).table.find_many(where={"user_id": user_api_key_dict.user_id})
|
||||
admin_org_ids = [
|
||||
m.organization_id
|
||||
for m in memberships
|
||||
@@ -400,7 +409,7 @@ async def get_organization_daily_activity(
|
||||
where_condition = {}
|
||||
if org_ids_list:
|
||||
where_condition["organization_id"] = {"in": list(org_ids_list)}
|
||||
org_aliases = await prisma_client.db.litellm_organizationtable.find_many(
|
||||
org_aliases = await OrganizationRepository(prisma_client).table.find_many(
|
||||
where=where_condition
|
||||
)
|
||||
org_alias_metadata = {
|
||||
@@ -439,10 +448,10 @@ async def _set_object_permission(
|
||||
return None
|
||||
|
||||
if data.object_permission is not None:
|
||||
created_object_permission = (
|
||||
await prisma_client.db.litellm_objectpermissiontable.create(
|
||||
data=data.object_permission.model_dump(exclude_none=True),
|
||||
)
|
||||
created_object_permission = await ObjectPermissionRepository(
|
||||
prisma_client
|
||||
).table.create(
|
||||
data=data.object_permission.model_dump(exclude_none=True),
|
||||
)
|
||||
del data.object_permission
|
||||
return created_object_permission.object_permission_id
|
||||
@@ -525,10 +534,10 @@ async def update_organization(
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
existing_organization_row = (
|
||||
await prisma_client.db.litellm_organizationtable.find_unique(
|
||||
where={"organization_id": data.organization_id},
|
||||
)
|
||||
existing_organization_row = await OrganizationRepository(
|
||||
prisma_client
|
||||
).table.find_unique(
|
||||
where={"organization_id": data.organization_id},
|
||||
)
|
||||
|
||||
if existing_organization_row is None:
|
||||
@@ -574,7 +583,7 @@ async def update_organization(
|
||||
for field in LiteLLM_BudgetTable.model_fields.keys():
|
||||
updated_organization_row.pop(field, None)
|
||||
|
||||
response = await prisma_client.db.litellm_organizationtable.update(
|
||||
response = await OrganizationRepository(prisma_client).table.update(
|
||||
where={"organization_id": data.organization_id},
|
||||
data=updated_organization_row,
|
||||
include={"members": True, "teams": True, "litellm_budget_table": True},
|
||||
@@ -644,19 +653,19 @@ async def delete_organization(
|
||||
deleted_orgs = []
|
||||
for organization_id in data.organization_ids:
|
||||
# delete all teams in the organization
|
||||
await prisma_client.db.litellm_teamtable.delete_many(
|
||||
await TeamRepository(prisma_client).table.delete_many(
|
||||
where={"organization_id": organization_id}
|
||||
)
|
||||
# delete all members in the organization
|
||||
await prisma_client.db.litellm_organizationmembership.delete_many(
|
||||
await OrganizationMembershipRepository(prisma_client).table.delete_many(
|
||||
where={"organization_id": organization_id}
|
||||
)
|
||||
# delete all keys in the organization
|
||||
await prisma_client.db.litellm_verificationtoken.delete_many(
|
||||
await VerificationTokenRepository(prisma_client).table.delete_many(
|
||||
where={"organization_id": organization_id}
|
||||
)
|
||||
# delete the organization
|
||||
deleted_org = await prisma_client.db.litellm_organizationtable.delete(
|
||||
deleted_org = await OrganizationRepository(prisma_client).table.delete(
|
||||
where={"organization_id": organization_id},
|
||||
include={"members": True, "teams": True, "litellm_budget_table": True},
|
||||
)
|
||||
@@ -732,17 +741,15 @@ async def list_organization(
|
||||
|
||||
# if proxy admin or admin viewer - get all orgs (with optional filters)
|
||||
if _user_has_admin_view(user_api_key_dict):
|
||||
response = await prisma_client.db.litellm_organizationtable.find_many(
|
||||
response = await OrganizationRepository(prisma_client).table.find_many(
|
||||
where=where_conditions if where_conditions else None,
|
||||
include={"litellm_budget_table": True, "members": True, "teams": True},
|
||||
)
|
||||
# if internal user - get orgs they are a member of (with optional filters)
|
||||
else:
|
||||
org_memberships = (
|
||||
await prisma_client.db.litellm_organizationmembership.find_many(
|
||||
where={"user_id": user_api_key_dict.user_id}
|
||||
)
|
||||
)
|
||||
org_memberships = await OrganizationMembershipRepository(
|
||||
prisma_client
|
||||
).table.find_many(where={"user_id": user_api_key_dict.user_id})
|
||||
membership_org_ids = [
|
||||
membership.organization_id for membership in org_memberships
|
||||
]
|
||||
@@ -756,20 +763,20 @@ async def list_organization(
|
||||
response = []
|
||||
else:
|
||||
where_conditions["organization_id"] = org_id
|
||||
response = (
|
||||
await prisma_client.db.litellm_organizationtable.find_many(
|
||||
where=where_conditions,
|
||||
include={
|
||||
"litellm_budget_table": True,
|
||||
"members": True,
|
||||
"teams": True,
|
||||
},
|
||||
)
|
||||
response = await OrganizationRepository(
|
||||
prisma_client
|
||||
).table.find_many(
|
||||
where=where_conditions,
|
||||
include={
|
||||
"litellm_budget_table": True,
|
||||
"members": True,
|
||||
"teams": True,
|
||||
},
|
||||
)
|
||||
else:
|
||||
# Filter by membership and any additional filters
|
||||
where_conditions["organization_id"] = {"in": membership_org_ids}
|
||||
response = await prisma_client.db.litellm_organizationtable.find_many(
|
||||
response = await OrganizationRepository(prisma_client).table.find_many(
|
||||
where=where_conditions,
|
||||
include={
|
||||
"litellm_budget_table": True,
|
||||
@@ -809,20 +816,20 @@ async def info_organization(
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
response: Optional[LiteLLM_OrganizationTableWithMembers] = (
|
||||
await prisma_client.db.litellm_organizationtable.find_unique(
|
||||
where={"organization_id": organization_id},
|
||||
include={
|
||||
"litellm_budget_table": True,
|
||||
"members": {
|
||||
"include": {
|
||||
"user": True,
|
||||
}
|
||||
},
|
||||
"teams": True,
|
||||
"object_permission": True,
|
||||
response: Optional[
|
||||
LiteLLM_OrganizationTableWithMembers
|
||||
] = await OrganizationRepository(prisma_client).table.find_unique(
|
||||
where={"organization_id": organization_id},
|
||||
include={
|
||||
"litellm_budget_table": True,
|
||||
"members": {
|
||||
"include": {
|
||||
"user": True,
|
||||
}
|
||||
},
|
||||
)
|
||||
"teams": True,
|
||||
"object_permission": True,
|
||||
},
|
||||
)
|
||||
|
||||
if response is None:
|
||||
@@ -868,7 +875,7 @@ async def deprecated_info_organization(
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
response = await prisma_client.db.litellm_organizationtable.find_many(
|
||||
response = await OrganizationRepository(prisma_client).table.find_many(
|
||||
where={"organization_id": {"in": data.organizations}},
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
@@ -945,11 +952,9 @@ async def organization_member_add(
|
||||
)
|
||||
|
||||
# Check if organization exists
|
||||
existing_organization_row = (
|
||||
await prisma_client.db.litellm_organizationtable.find_unique(
|
||||
where={"organization_id": data.organization_id}
|
||||
)
|
||||
)
|
||||
existing_organization_row = await OrganizationRepository(
|
||||
prisma_client
|
||||
).table.find_unique(where={"organization_id": data.organization_id})
|
||||
if existing_organization_row is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
@@ -1012,11 +1017,9 @@ async def find_member_if_email(
|
||||
"""
|
||||
|
||||
try:
|
||||
existing_user_email_row: BaseModel = (
|
||||
await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_email": user_email}
|
||||
)
|
||||
)
|
||||
existing_user_email_row: BaseModel = await UserRepository(
|
||||
prisma_client
|
||||
).table.find_unique(where={"user_email": user_email})
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@@ -1064,11 +1067,9 @@ async def organization_member_update(
|
||||
)
|
||||
|
||||
# Check if organization exists
|
||||
existing_organization_row = (
|
||||
await prisma_client.db.litellm_organizationtable.find_unique(
|
||||
where={"organization_id": data.organization_id}
|
||||
)
|
||||
)
|
||||
existing_organization_row = await OrganizationRepository(
|
||||
prisma_client
|
||||
).table.find_unique(where={"organization_id": data.organization_id})
|
||||
if existing_organization_row is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@@ -1085,15 +1086,15 @@ async def organization_member_update(
|
||||
data.user_id = existing_user_email_row.user_id
|
||||
|
||||
try:
|
||||
existing_organization_membership = (
|
||||
await prisma_client.db.litellm_organizationmembership.find_unique(
|
||||
where={
|
||||
"user_id_organization_id": {
|
||||
"user_id": data.user_id,
|
||||
"organization_id": data.organization_id,
|
||||
}
|
||||
existing_organization_membership = await OrganizationMembershipRepository(
|
||||
prisma_client
|
||||
).table.find_unique(
|
||||
where={
|
||||
"user_id_organization_id": {
|
||||
"user_id": data.user_id,
|
||||
"organization_id": data.organization_id,
|
||||
}
|
||||
)
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
@@ -1114,7 +1115,7 @@ async def organization_member_update(
|
||||
# org-scoped operations. An org-admin of any org could otherwise
|
||||
# alter a PROXY_ADMIN user's per-org role, which has downstream
|
||||
# effects on admin UI filtering and scope derivation.
|
||||
target_user_row = await prisma_client.db.litellm_usertable.find_unique(
|
||||
target_user_row = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": data.user_id}
|
||||
)
|
||||
if target_user_row is not None and getattr(
|
||||
@@ -1136,7 +1137,7 @@ async def organization_member_update(
|
||||
|
||||
# Update member role
|
||||
if data.role is not None:
|
||||
await prisma_client.db.litellm_organizationmembership.update(
|
||||
await OrganizationMembershipRepository(prisma_client).table.update(
|
||||
where={
|
||||
"user_id_organization_id": {
|
||||
"user_id": data.user_id,
|
||||
@@ -1165,7 +1166,7 @@ async def organization_member_update(
|
||||
)
|
||||
|
||||
# update organization membership with new budget_id
|
||||
await prisma_client.db.litellm_organizationmembership.update(
|
||||
await OrganizationMembershipRepository(prisma_client).table.update(
|
||||
where={
|
||||
"user_id_organization_id": {
|
||||
"user_id": data.user_id,
|
||||
@@ -1174,16 +1175,16 @@ async def organization_member_update(
|
||||
},
|
||||
data={"budget_id": budget_id},
|
||||
)
|
||||
final_organization_membership: Optional[BaseModel] = (
|
||||
await prisma_client.db.litellm_organizationmembership.find_unique(
|
||||
where={
|
||||
"user_id_organization_id": {
|
||||
"user_id": data.user_id,
|
||||
"organization_id": data.organization_id,
|
||||
}
|
||||
},
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
final_organization_membership: Optional[
|
||||
BaseModel
|
||||
] = await OrganizationMembershipRepository(prisma_client).table.find_unique(
|
||||
where={
|
||||
"user_id_organization_id": {
|
||||
"user_id": data.user_id,
|
||||
"organization_id": data.organization_id,
|
||||
}
|
||||
},
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
|
||||
if final_organization_membership is None:
|
||||
@@ -1239,7 +1240,9 @@ async def organization_member_delete(
|
||||
)
|
||||
data.user_id = existing_user_email_row.user_id
|
||||
|
||||
member_to_delete = await prisma_client.db.litellm_organizationmembership.delete(
|
||||
member_to_delete = await OrganizationMembershipRepository(
|
||||
prisma_client
|
||||
).table.delete(
|
||||
where={
|
||||
"user_id_organization_id": {
|
||||
"user_id": data.user_id,
|
||||
@@ -1273,17 +1276,15 @@ async def add_member_to_organization(
|
||||
existing_user_email_row = None
|
||||
## Check if user exists in LiteLLM_UserTable - user exists - either the user_id or user_email is in LiteLLM_UserTable
|
||||
if member.user_id is not None:
|
||||
existing_user_id_row = await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": member.user_id}
|
||||
)
|
||||
existing_user_id_row = await UserRepository(
|
||||
prisma_client
|
||||
).table.find_unique(where={"user_id": member.user_id})
|
||||
|
||||
if existing_user_id_row is None and member.user_email is not None:
|
||||
try:
|
||||
existing_user_email_row = (
|
||||
await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_email": member.user_email}
|
||||
)
|
||||
)
|
||||
existing_user_email_row = await UserRepository(
|
||||
prisma_client
|
||||
).table.find_unique(where={"user_email": member.user_email})
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Potential NON-Existent or Duplicate user email in DB: Error finding a unique instance of user_email={member.user_email} in LiteLLM_UserTable.: {e}"
|
||||
@@ -1326,14 +1327,14 @@ async def add_member_to_organization(
|
||||
)
|
||||
|
||||
# Add user to organization
|
||||
_organization_membership = (
|
||||
await prisma_client.db.litellm_organizationmembership.create(
|
||||
data={
|
||||
"organization_id": organization_id,
|
||||
"user_id": user_object.user_id,
|
||||
"user_role": member.role,
|
||||
}
|
||||
)
|
||||
_organization_membership = await OrganizationMembershipRepository(
|
||||
prisma_client
|
||||
).table.create(
|
||||
data={
|
||||
"organization_id": organization_id,
|
||||
"user_id": user_object.user_id,
|
||||
"user_role": member.role,
|
||||
}
|
||||
)
|
||||
organization_membership = LiteLLM_OrganizationMembershipTable(
|
||||
**_organization_membership.model_dump()
|
||||
|
||||
@@ -6,6 +6,7 @@ from litellm.proxy._types import (
|
||||
Member,
|
||||
NewUserResponse,
|
||||
)
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.types.proxy.management_endpoints.scim_v2 import *
|
||||
|
||||
|
||||
@@ -29,7 +30,7 @@ class ScimTransformations:
|
||||
# Get user's teams/groups
|
||||
groups = []
|
||||
for team_id in user.teams or []:
|
||||
team = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
team = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
if team:
|
||||
|
||||
@@ -22,7 +22,6 @@ from typing_extensions import TypedDict
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _safe_get_request_headers
|
||||
from litellm._uuid import uuid
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy._types import (
|
||||
@@ -41,6 +40,7 @@ from litellm.proxy._types import (
|
||||
)
|
||||
from litellm.proxy.auth.auth_checks import _delete_cache_key_object
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _safe_get_request_headers
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
|
||||
from litellm.proxy.management_endpoints.scim.scim_transformations import (
|
||||
ScimTransformations,
|
||||
@@ -51,6 +51,16 @@ from litellm.proxy.management_endpoints.team_endpoints import (
|
||||
team_member_delete,
|
||||
)
|
||||
from litellm.proxy.utils import _premium_user_check, handle_exception_on_proxy
|
||||
from litellm.repositories.table_repositories import (
|
||||
InvitationLinkRepository,
|
||||
OrganizationMembershipRepository,
|
||||
TeamMembershipRepository,
|
||||
)
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.repositories.user_repository import UserRepository
|
||||
from litellm.repositories.verification_token_repository import (
|
||||
VerificationTokenRepository,
|
||||
)
|
||||
from litellm.types.proxy.management_endpoints.scim_v2 import *
|
||||
|
||||
|
||||
@@ -74,7 +84,7 @@ class UserProvisionerHelpers:
|
||||
if not new_user_request.user_email:
|
||||
return None
|
||||
|
||||
existing_user = await prisma_client.db.litellm_usertable.find_first(
|
||||
existing_user = await UserRepository(prisma_client).table.find_first(
|
||||
where={"user_email": new_user_request.user_email}
|
||||
)
|
||||
|
||||
@@ -82,7 +92,7 @@ class UserProvisionerHelpers:
|
||||
return None
|
||||
|
||||
# Update the user
|
||||
updated_user = await prisma_client.db.litellm_usertable.update(
|
||||
updated_user = await UserRepository(prisma_client).table.update(
|
||||
where={"user_id": existing_user.user_id},
|
||||
data={
|
||||
"user_id": new_user_request.user_id,
|
||||
@@ -139,7 +149,7 @@ async def _check_user_exists(user_id: str):
|
||||
"""Check if user exists and return user, raise 404 if not found."""
|
||||
prisma_client = await _get_prisma_client_or_raise_exception()
|
||||
|
||||
user = await prisma_client.db.litellm_usertable.find_unique(
|
||||
user = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": user_id}
|
||||
)
|
||||
|
||||
@@ -155,7 +165,7 @@ async def _check_team_exists(team_id: str):
|
||||
"""Check if team exists and return team, raise 404 if not found."""
|
||||
prisma_client = await _get_prisma_client_or_raise_exception()
|
||||
|
||||
team = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
team = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
|
||||
@@ -268,7 +278,7 @@ async def _extract_group_member_ids(group: SCIMGroup) -> GroupMemberExtractionRe
|
||||
)
|
||||
|
||||
# Check if user exists
|
||||
user = await prisma_client.db.litellm_usertable.find_unique(
|
||||
user = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": user_id}
|
||||
)
|
||||
|
||||
@@ -310,7 +320,7 @@ async def _get_team_members_display(member_ids: List[str]) -> List[SCIMMember]:
|
||||
members: List[SCIMMember] = []
|
||||
|
||||
for member_id in member_ids:
|
||||
user = await prisma_client.db.litellm_usertable.find_unique(
|
||||
user = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": member_id}
|
||||
)
|
||||
if user:
|
||||
@@ -367,7 +377,7 @@ async def _set_user_keys_blocked(user_id: str, blocked: bool) -> int:
|
||||
# `blocked` is a nullable column with no default, so existing rows
|
||||
# typically hold NULL; treat NULL as "not blocked" since SQL equality
|
||||
# on NULL would otherwise silently skip them.
|
||||
candidates = await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
candidates = await VerificationTokenRepository(prisma_client).table.find_many(
|
||||
where={
|
||||
"user_id": user_id,
|
||||
"OR": [{"blocked": False}, {"blocked": None}],
|
||||
@@ -375,7 +385,7 @@ async def _set_user_keys_blocked(user_id: str, blocked: bool) -> int:
|
||||
)
|
||||
affected_keys = candidates
|
||||
else:
|
||||
candidates = await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
candidates = await VerificationTokenRepository(prisma_client).table.find_many(
|
||||
where={"user_id": user_id, "blocked": True},
|
||||
)
|
||||
affected_keys = [k for k in candidates if _key_was_scim_blocked(k.metadata)]
|
||||
@@ -395,7 +405,7 @@ async def _set_user_keys_blocked(user_id: str, blocked: bool) -> int:
|
||||
for k, v in current_metadata.items()
|
||||
if k != SCIM_BLOCKED_METADATA_KEY
|
||||
}
|
||||
await prisma_client.db.litellm_verificationtoken.update(
|
||||
await VerificationTokenRepository(prisma_client).table.update(
|
||||
where={"token": key_row.token},
|
||||
data={"blocked": blocked, "metadata": safe_dumps(new_metadata)},
|
||||
)
|
||||
@@ -423,7 +433,7 @@ async def _delete_rows_referencing_user(prisma_client: Any, *, user_id: str) ->
|
||||
the user delete with an FK constraint violation (e.g.
|
||||
``LiteLLM_InvitationLink_user_id_fkey``).
|
||||
"""
|
||||
await prisma_client.db.litellm_invitationlink.delete_many(
|
||||
await InvitationLinkRepository(prisma_client).table.delete_many(
|
||||
where={
|
||||
"OR": [
|
||||
{"user_id": user_id},
|
||||
@@ -432,10 +442,10 @@ async def _delete_rows_referencing_user(prisma_client: Any, *, user_id: str) ->
|
||||
]
|
||||
}
|
||||
)
|
||||
await prisma_client.db.litellm_organizationmembership.delete_many(
|
||||
await OrganizationMembershipRepository(prisma_client).table.delete_many(
|
||||
where={"user_id": user_id}
|
||||
)
|
||||
await prisma_client.db.litellm_teammembership.delete_many(
|
||||
await TeamMembershipRepository(prisma_client).table.delete_many(
|
||||
where={"user_id": user_id}
|
||||
)
|
||||
|
||||
@@ -897,17 +907,17 @@ async def get_users(
|
||||
where_conditions["user_email"] = filter_value
|
||||
|
||||
# Get users from database
|
||||
users: List[LiteLLM_UserTable] = (
|
||||
await prisma_client.db.litellm_usertable.find_many(
|
||||
where=where_conditions,
|
||||
skip=(startIndex - 1),
|
||||
take=count,
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
users: List[LiteLLM_UserTable] = await UserRepository(
|
||||
prisma_client
|
||||
).table.find_many(
|
||||
where=where_conditions,
|
||||
skip=(startIndex - 1),
|
||||
take=count,
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
|
||||
# Get total count for pagination
|
||||
total_count = await prisma_client.db.litellm_usertable.count(
|
||||
total_count = await UserRepository(prisma_client).table.count(
|
||||
where=where_conditions
|
||||
)
|
||||
|
||||
@@ -975,7 +985,7 @@ async def create_user(
|
||||
|
||||
# Check if user already exists
|
||||
if user.userName:
|
||||
existing_user = await prisma_client.db.litellm_usertable.find_unique(
|
||||
existing_user = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": user.userName}
|
||||
)
|
||||
if existing_user:
|
||||
@@ -1094,7 +1104,7 @@ async def update_user(
|
||||
"metadata": safe_dumps(metadata),
|
||||
}
|
||||
|
||||
updated_user = await prisma_client.db.litellm_usertable.update(
|
||||
updated_user = await UserRepository(prisma_client).table.update(
|
||||
where={"user_id": user_id},
|
||||
data=update_data,
|
||||
)
|
||||
@@ -1137,7 +1147,7 @@ async def delete_user(
|
||||
teams = []
|
||||
if existing_user.teams:
|
||||
for team_id in existing_user.teams:
|
||||
team = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
team = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
if team:
|
||||
@@ -1148,7 +1158,7 @@ async def delete_user(
|
||||
current_members = team.members or []
|
||||
if user_id in current_members:
|
||||
new_members = [m for m in current_members if m != user_id]
|
||||
await prisma_client.db.litellm_teamtable.update(
|
||||
await TeamRepository(prisma_client).table.update(
|
||||
where={"team_id": team.team_id}, data={"members": new_members}
|
||||
)
|
||||
|
||||
@@ -1157,7 +1167,7 @@ async def delete_user(
|
||||
await _delete_rows_referencing_user(prisma_client, user_id=user_id)
|
||||
|
||||
# Delete user
|
||||
await prisma_client.db.litellm_usertable.delete(where={"user_id": user_id})
|
||||
await UserRepository(prisma_client).table.delete(where={"user_id": user_id})
|
||||
|
||||
return Response(status_code=204)
|
||||
except Exception as e:
|
||||
@@ -1413,7 +1423,7 @@ async def patch_user(
|
||||
|
||||
update_data["metadata"] = safe_dumps(update_data["metadata"])
|
||||
|
||||
updated_user = await prisma_client.db.litellm_usertable.update(
|
||||
updated_user = await UserRepository(prisma_client).table.update(
|
||||
where={"user_id": user_id},
|
||||
data=update_data,
|
||||
)
|
||||
@@ -1465,7 +1475,7 @@ async def get_groups(
|
||||
where_conditions["team_alias"] = team_alias
|
||||
|
||||
# Get teams from database
|
||||
teams = await prisma_client.db.litellm_teamtable.find_many(
|
||||
teams = await TeamRepository(prisma_client).table.find_many(
|
||||
where=where_conditions,
|
||||
skip=(startIndex - 1),
|
||||
take=count,
|
||||
@@ -1473,7 +1483,7 @@ async def get_groups(
|
||||
)
|
||||
|
||||
# Get total count for pagination
|
||||
total_count = await prisma_client.db.litellm_teamtable.count(
|
||||
total_count = await TeamRepository(prisma_client).table.count(
|
||||
where=where_conditions
|
||||
)
|
||||
|
||||
@@ -1561,7 +1571,7 @@ async def create_group(
|
||||
team_id = group.id or group.externalId or str(uuid.uuid4())
|
||||
|
||||
# Check if team already exists
|
||||
existing_team = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
existing_team = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
|
||||
@@ -1638,7 +1648,7 @@ async def update_group(
|
||||
}
|
||||
|
||||
# Update team in database
|
||||
updated_team = await prisma_client.db.litellm_teamtable.update(
|
||||
updated_team = await TeamRepository(prisma_client).table.update(
|
||||
where={"team_id": group_id},
|
||||
data=update_data,
|
||||
)
|
||||
@@ -1683,19 +1693,19 @@ async def delete_group(
|
||||
|
||||
# For each member, remove this team from their teams list
|
||||
for member_id in existing_team.members or []:
|
||||
user = await prisma_client.db.litellm_usertable.find_unique(
|
||||
user = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": member_id}
|
||||
)
|
||||
if user:
|
||||
current_teams = user.teams or []
|
||||
if group_id in current_teams:
|
||||
new_teams = [t for t in current_teams if t != group_id]
|
||||
await prisma_client.db.litellm_usertable.update(
|
||||
await UserRepository(prisma_client).table.update(
|
||||
where={"user_id": member_id}, data={"teams": new_teams}
|
||||
)
|
||||
|
||||
# Delete team
|
||||
await prisma_client.db.litellm_teamtable.delete(where={"team_id": group_id})
|
||||
await TeamRepository(prisma_client).table.delete(where={"team_id": group_id})
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
@@ -1748,7 +1758,7 @@ async def _process_group_patch_operations(
|
||||
detail={"error": "Invalid member: user ID cannot be empty."},
|
||||
)
|
||||
|
||||
user = await prisma_client.db.litellm_usertable.find_unique(
|
||||
user = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": member_id}
|
||||
)
|
||||
if user:
|
||||
@@ -1805,7 +1815,7 @@ async def _apply_group_patch_updates(
|
||||
update_data["members"] = list(final_members)
|
||||
|
||||
# Update team in database
|
||||
updated_team = await prisma_client.db.litellm_teamtable.update(
|
||||
updated_team = await TeamRepository(prisma_client).table.update(
|
||||
where={"team_id": group_id},
|
||||
data=update_data,
|
||||
)
|
||||
@@ -1877,7 +1887,7 @@ async def patch_group(
|
||||
|
||||
# Refresh team data from database to get the latest state after concurrent updates
|
||||
# This prevents race conditions when multiple PATCH requests come in simultaneously
|
||||
refreshed_team = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
refreshed_team = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": group_id}
|
||||
)
|
||||
if refreshed_team:
|
||||
@@ -1894,7 +1904,7 @@ async def patch_group(
|
||||
await _handle_group_membership_changes(group_id, current_members, final_members)
|
||||
|
||||
# Refresh team one more time to get final state after membership changes
|
||||
final_team = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
final_team = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": group_id}
|
||||
)
|
||||
if final_team:
|
||||
|
||||
@@ -25,6 +25,14 @@ from litellm.proxy.management_endpoints.common_daily_activity import (
|
||||
get_daily_activity,
|
||||
)
|
||||
from litellm.proxy.management_helpers.utils import handle_budget_for_entity
|
||||
from litellm.repositories.model_repository import ModelRepository
|
||||
from litellm.repositories.table_repositories import (
|
||||
DailyTagSpendRepository,
|
||||
TagRepository,
|
||||
)
|
||||
from litellm.repositories.verification_token_repository import (
|
||||
VerificationTokenRepository,
|
||||
)
|
||||
from litellm.types.tag_management import (
|
||||
TagConfig,
|
||||
TagDeleteRequest,
|
||||
@@ -56,7 +64,7 @@ async def _get_internal_user_api_keys(
|
||||
if user_id is None:
|
||||
return sorted(user_api_keys)
|
||||
|
||||
key_records = await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
key_records = await VerificationTokenRepository(prisma_client).table.find_many(
|
||||
where={"user_id": user_id},
|
||||
select={"token": True},
|
||||
)
|
||||
@@ -109,7 +117,7 @@ async def _get_tag_daily_activity_api_key_filter(
|
||||
async def _get_model_names(prisma_client, model_ids: list) -> Dict[str, str]:
|
||||
"""Helper function to get model names from model IDs"""
|
||||
try:
|
||||
models = await prisma_client.db.litellm_proxymodeltable.find_many(
|
||||
models = await ModelRepository(prisma_client).table.find_many(
|
||||
where={"model_id": {"in": model_ids}}
|
||||
)
|
||||
return {model.model_id: model.model_name for model in models}
|
||||
@@ -189,7 +197,7 @@ async def new_tag(
|
||||
)
|
||||
try:
|
||||
# Check if tag already exists
|
||||
existing_tag = await prisma_client.db.litellm_tagtable.find_unique(
|
||||
existing_tag = await TagRepository(prisma_client).table.find_unique(
|
||||
where={"tag_name": tag.name}
|
||||
)
|
||||
if existing_tag is not None:
|
||||
@@ -210,7 +218,7 @@ async def new_tag(
|
||||
model_info = await _get_model_names(prisma_client, tag.models or [])
|
||||
|
||||
# Create new tag in database
|
||||
new_tag_record = await prisma_client.db.litellm_tagtable.create(
|
||||
new_tag_record = await TagRepository(prisma_client).table.create(
|
||||
data={
|
||||
"tag_name": tag.name,
|
||||
"description": tag.description,
|
||||
@@ -267,7 +275,7 @@ async def _add_tag_to_deployment(deployment: "Deployment", tag: str):
|
||||
|
||||
try:
|
||||
# Get current model from database to preserve encrypted fields
|
||||
db_model = await prisma_client.db.litellm_proxymodeltable.find_unique(
|
||||
db_model = await ModelRepository(prisma_client).table.find_unique(
|
||||
where={"model_id": deployment.model_info.id}
|
||||
)
|
||||
|
||||
@@ -292,7 +300,7 @@ async def _add_tag_to_deployment(deployment: "Deployment", tag: str):
|
||||
existing_params["tags"].append(tag)
|
||||
|
||||
# Update database with modified params (keeps encrypted fields encrypted)
|
||||
await prisma_client.db.litellm_proxymodeltable.update(
|
||||
await ModelRepository(prisma_client).table.update(
|
||||
where={"model_id": deployment.model_info.id},
|
||||
data={"litellm_params": json.dumps(existing_params)},
|
||||
)
|
||||
@@ -335,7 +343,7 @@ async def update_tag(
|
||||
|
||||
try:
|
||||
# Check if tag exists
|
||||
existing_tag = await prisma_client.db.litellm_tagtable.find_unique(
|
||||
existing_tag = await TagRepository(prisma_client).table.find_unique(
|
||||
where={"tag_name": tag.name}
|
||||
)
|
||||
if existing_tag is None:
|
||||
@@ -367,7 +375,7 @@ async def update_tag(
|
||||
update_data["budget_id"] = budget_id
|
||||
|
||||
# Update tag in database
|
||||
updated_tag_record = await prisma_client.db.litellm_tagtable.update(
|
||||
updated_tag_record = await TagRepository(prisma_client).table.update(
|
||||
where={"tag_name": tag.name},
|
||||
data=update_data,
|
||||
)
|
||||
@@ -414,7 +422,7 @@ async def info_tag(
|
||||
|
||||
try:
|
||||
# Query tags from database with budget info
|
||||
tag_records = await prisma_client.db.litellm_tagtable.find_many(
|
||||
tag_records = await TagRepository(prisma_client).table.find_many(
|
||||
where={"tag_name": {"in": data.names}},
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
@@ -535,7 +543,7 @@ async def list_tags(
|
||||
if start_date is not None and end_date is not None:
|
||||
dynamic_tag_where["date"] = {"gte": start_date, "lte": end_date}
|
||||
|
||||
dynamic_tag_rows = await prisma_client.db.litellm_dailytagspend.group_by(
|
||||
dynamic_tag_rows = await DailyTagSpendRepository(prisma_client).table.group_by(
|
||||
by=["tag"],
|
||||
where=dynamic_tag_where,
|
||||
min={"created_at": True},
|
||||
@@ -551,7 +559,7 @@ async def list_tags(
|
||||
)
|
||||
|
||||
## QUERY STORED TAGS ##
|
||||
tag_records = await prisma_client.db.litellm_tagtable.find_many(
|
||||
tag_records = await TagRepository(prisma_client).table.find_many(
|
||||
where=stored_tag_where,
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
@@ -626,14 +634,14 @@ async def delete_tag(
|
||||
|
||||
try:
|
||||
# Check if tag exists
|
||||
existing_tag = await prisma_client.db.litellm_tagtable.find_unique(
|
||||
existing_tag = await TagRepository(prisma_client).table.find_unique(
|
||||
where={"tag_name": data.name}
|
||||
)
|
||||
if existing_tag is None:
|
||||
raise HTTPException(status_code=404, detail=f"Tag {data.name} not found")
|
||||
|
||||
# Delete tag from database
|
||||
await prisma_client.db.litellm_tagtable.delete(where={"tag_name": data.name})
|
||||
await TagRepository(prisma_client).table.delete(where={"tag_name": data.name})
|
||||
|
||||
return {"message": f"Tag {data.name} deleted successfully"}
|
||||
except Exception as e:
|
||||
|
||||
@@ -30,6 +30,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.callback_utils import encrypt_callback_vars
|
||||
from litellm.proxy.management_endpoints.team_endpoints import _verify_team_access
|
||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -249,7 +250,7 @@ async def add_team_callbacks(
|
||||
team_metadata = encrypt_callback_vars(team_metadata)
|
||||
team_metadata_json = json.dumps(team_metadata) # update team_metadata
|
||||
|
||||
new_team_row = await prisma_client.db.litellm_teamtable.update(
|
||||
new_team_row = await TeamRepository(prisma_client).table.update(
|
||||
where={"team_id": team_id}, data={"metadata": team_metadata_json} # type: ignore
|
||||
)
|
||||
|
||||
@@ -353,7 +354,7 @@ async def disable_team_logging(
|
||||
team_metadata_json = json.dumps(team_metadata)
|
||||
|
||||
# Update team in database
|
||||
updated_team = await prisma_client.db.litellm_teamtable.update(
|
||||
updated_team = await TeamRepository(prisma_client).table.update(
|
||||
where={"team_id": team_id}, data={"metadata": team_metadata_json} # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@@ -10,8 +10,8 @@ All /team management endpoints
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
import json
|
||||
import math
|
||||
import traceback
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
||||
@@ -102,6 +102,20 @@ from litellm.proxy.management_helpers.utils import (
|
||||
management_endpoint_wrapper,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient, handle_exception_on_proxy
|
||||
from litellm.repositories.budget_repository import BudgetRepository
|
||||
from litellm.repositories.organization_repository import OrganizationRepository
|
||||
from litellm.repositories.table_repositories import (
|
||||
AccessGroupRepository,
|
||||
DeletedTeamRepository,
|
||||
ModelTableRepository,
|
||||
OrganizationMembershipRepository,
|
||||
TeamMembershipRepository,
|
||||
)
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.repositories.user_repository import UserRepository
|
||||
from litellm.repositories.verification_token_repository import (
|
||||
VerificationTokenRepository,
|
||||
)
|
||||
from litellm.router import Router
|
||||
from litellm.types.proxy.management_endpoints.common_daily_activity import (
|
||||
SpendAnalyticsPaginatedResponse,
|
||||
@@ -365,9 +379,9 @@ class TeamMemberBudgetHandler:
|
||||
return
|
||||
|
||||
# Batch-fetch existing memberships for this team (avoids N+1 queries)
|
||||
existing_memberships = await prisma_client.db.litellm_teammembership.find_many(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
existing_memberships = await TeamMembershipRepository(
|
||||
prisma_client
|
||||
).table.find_many(where={"team_id": team_id})
|
||||
existing_user_ids = {m.user_id for m in existing_memberships}
|
||||
|
||||
# Identify members with no existing membership row.
|
||||
@@ -386,7 +400,7 @@ class TeamMemberBudgetHandler:
|
||||
)
|
||||
|
||||
if missing:
|
||||
await prisma_client.db.litellm_teammembership.create_many(
|
||||
await TeamMembershipRepository(prisma_client).table.create_many(
|
||||
data=missing,
|
||||
skip_duplicates=True, # safety net against concurrent races
|
||||
)
|
||||
@@ -400,7 +414,7 @@ class TeamMemberBudgetHandler:
|
||||
# Heal existing membership rows that predate the team_member_budget
|
||||
# configuration: populate budget_id where it is currently NULL.
|
||||
# Rows with an explicit budget_id (per-member override) are left alone.
|
||||
updated = await prisma_client.db.litellm_teammembership.update_many(
|
||||
updated = await TeamMembershipRepository(prisma_client).table.update_many(
|
||||
where={"team_id": team_id, "budget_id": None},
|
||||
data={"budget_id": team_member_budget_id},
|
||||
)
|
||||
@@ -456,7 +470,7 @@ async def get_all_team_memberships(
|
||||
# else:
|
||||
# where_obj = {"user_id": str(user_id), "team_id": {"in": team_id}}
|
||||
|
||||
team_memberships = await prisma_client.db.litellm_teammembership.find_many(
|
||||
team_memberships = await TeamMembershipRepository(prisma_client).table.find_many(
|
||||
where=where_obj,
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
@@ -739,7 +753,7 @@ async def _check_org_team_limits(
|
||||
# calculate allocated tpm/rpm limit
|
||||
# check if specified tpm/rpm limit is greater than allocated tpm/rpm limit
|
||||
|
||||
teams = await prisma_client.db.litellm_teamtable.find_many(
|
||||
teams = await TeamRepository(prisma_client).table.find_many(
|
||||
where={"organization_id": org_table.organization_id},
|
||||
)
|
||||
|
||||
@@ -931,6 +945,9 @@ async def new_team( # noqa: PLR0915
|
||||
```
|
||||
"""
|
||||
try:
|
||||
from litellm.proxy.management_helpers.audit_logs import (
|
||||
get_audit_log_changed_by,
|
||||
)
|
||||
from litellm.proxy.proxy_server import (
|
||||
_license_check,
|
||||
create_audit_log_for_update,
|
||||
@@ -938,9 +955,6 @@ async def new_team( # noqa: PLR0915
|
||||
prisma_client,
|
||||
user_api_key_cache,
|
||||
)
|
||||
from litellm.proxy.management_helpers.audit_logs import (
|
||||
get_audit_log_changed_by,
|
||||
)
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail={"error": "No db connected"})
|
||||
@@ -986,7 +1000,7 @@ async def new_team( # noqa: PLR0915
|
||||
)
|
||||
|
||||
# Check if license is over limit
|
||||
total_teams = await prisma_client.db.litellm_teamtable.count()
|
||||
total_teams = await TeamRepository(prisma_client).table.count()
|
||||
if total_teams and _license_check.is_team_count_over_limit(
|
||||
team_count=total_teams
|
||||
):
|
||||
@@ -1092,7 +1106,7 @@ async def new_team( # noqa: PLR0915
|
||||
created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
)
|
||||
model_dict = await prisma_client.db.litellm_modeltable.create(
|
||||
model_dict = await ModelTableRepository(prisma_client).table.create(
|
||||
{**litellm_modeltable.json(exclude_none=True)} # type: ignore
|
||||
) # type: ignore
|
||||
|
||||
@@ -1195,7 +1209,7 @@ async def new_team( # noqa: PLR0915
|
||||
db_data=complete_team_data_dict
|
||||
)
|
||||
|
||||
team_row: LiteLLM_TeamTable = await prisma_client.db.litellm_teamtable.create(
|
||||
team_row: LiteLLM_TeamTable = await TeamRepository(prisma_client).table.create(
|
||||
data=complete_team_data_dict,
|
||||
include={"litellm_model_table": True}, # type: ignore
|
||||
)
|
||||
@@ -1315,11 +1329,11 @@ async def _update_model_table(
|
||||
updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
)
|
||||
if model_id is None:
|
||||
model_dict = await prisma_client.db.litellm_modeltable.create(
|
||||
model_dict = await ModelTableRepository(prisma_client).table.create(
|
||||
data={**litellm_modeltable.json(exclude_none=True)} # type: ignore
|
||||
)
|
||||
else:
|
||||
model_dict = await prisma_client.db.litellm_modeltable.upsert(
|
||||
model_dict = await ModelTableRepository(prisma_client).table.upsert(
|
||||
where={"id": model_id},
|
||||
data={
|
||||
"update": {**litellm_modeltable.json(exclude_none=True)}, # type: ignore
|
||||
@@ -1400,7 +1414,7 @@ async def fetch_and_validate_organization(
|
||||
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
|
||||
)
|
||||
|
||||
organization_row = await prisma_client.db.litellm_organizationtable.find_unique(
|
||||
organization_row = await OrganizationRepository(prisma_client).table.find_unique(
|
||||
where={"organization_id": organization_id},
|
||||
include={"litellm_budget_table": True, "members": True, "teams": True},
|
||||
)
|
||||
@@ -1669,7 +1683,7 @@ async def update_team( # noqa: PLR0915
|
||||
},
|
||||
)
|
||||
|
||||
existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
existing_team_row = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": data.team_id}
|
||||
)
|
||||
|
||||
@@ -1738,7 +1752,9 @@ async def update_team( # noqa: PLR0915
|
||||
):
|
||||
# Is the caller org_admin of the destination org?
|
||||
caller_memberships = (
|
||||
await prisma_client.db.litellm_organizationmembership.find_many(
|
||||
await OrganizationMembershipRepository(
|
||||
prisma_client
|
||||
).table.find_many(
|
||||
where={
|
||||
"user_id": user_api_key_dict.user_id,
|
||||
"organization_id": data.organization_id,
|
||||
@@ -1885,18 +1901,18 @@ async def update_team( # noqa: PLR0915
|
||||
updated_kv["router_settings"] = safe_dumps(updated_kv["router_settings"])
|
||||
|
||||
updated_kv = prisma_client.jsonify_team_object(db_data=updated_kv)
|
||||
team_row: Optional[LiteLLM_TeamTable] = (
|
||||
await prisma_client.db.litellm_teamtable.update(
|
||||
where={"team_id": data.team_id},
|
||||
data=updated_kv,
|
||||
# `object_permission` is included so `_refresh_cached_team`
|
||||
# doesn't write a cached team with the relation nulled out —
|
||||
# see team_model_add for the full rationale.
|
||||
include={
|
||||
"litellm_model_table": True,
|
||||
"object_permission": True,
|
||||
}, # type: ignore
|
||||
)
|
||||
team_row: Optional[LiteLLM_TeamTable] = await TeamRepository(
|
||||
prisma_client
|
||||
).table.update(
|
||||
where={"team_id": data.team_id},
|
||||
data=updated_kv,
|
||||
# `object_permission` is included so `_refresh_cached_team`
|
||||
# doesn't write a cached team with the relation nulled out —
|
||||
# see team_model_add for the full rationale.
|
||||
include={
|
||||
"litellm_model_table": True,
|
||||
"object_permission": True,
|
||||
}, # type: ignore
|
||||
)
|
||||
|
||||
if team_row is None or team_row.team_id is None:
|
||||
@@ -2306,7 +2322,7 @@ async def _add_team_members_to_team(
|
||||
|
||||
# ADD MEMBER TO TEAM
|
||||
_db_team_members = [m.model_dump() for m in complete_team_data.members_with_roles]
|
||||
updated_team = await prisma_client.db.litellm_teamtable.update(
|
||||
updated_team = await TeamRepository(prisma_client).table.update(
|
||||
where={"team_id": data.team_id},
|
||||
data={"members_with_roles": json.dumps(_db_team_members)}, # type: ignore
|
||||
)
|
||||
@@ -2377,7 +2393,7 @@ async def _validate_and_populate_member_user_info(
|
||||
|
||||
# Case 2: Only user_email provided - populate user_id from DB
|
||||
if member.user_email is not None and member.user_id is None:
|
||||
user_by_email = await prisma_client.db.litellm_usertable.find_first(
|
||||
user_by_email = await UserRepository(prisma_client).table.find_first(
|
||||
where={"user_email": {"equals": member.user_email, "mode": "insensitive"}}
|
||||
)
|
||||
|
||||
@@ -2410,7 +2426,7 @@ async def _validate_and_populate_member_user_info(
|
||||
|
||||
# Case 3: Only user_id provided - populate user_email from DB if user exists
|
||||
if member.user_id is not None and member.user_email is None:
|
||||
user_by_id = await prisma_client.db.litellm_usertable.find_unique(
|
||||
user_by_id = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": member.user_id}
|
||||
)
|
||||
|
||||
@@ -2608,7 +2624,7 @@ async def team_member_delete(
|
||||
detail={"error": "Either user_id or user_email needs to be passed in"},
|
||||
)
|
||||
|
||||
_existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
_existing_team_row = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": data.team_id}
|
||||
)
|
||||
|
||||
@@ -2652,7 +2668,7 @@ async def team_member_delete(
|
||||
|
||||
_db_new_team_members: List[dict] = [m.model_dump() for m in new_team_members]
|
||||
|
||||
_ = await prisma_client.db.litellm_teamtable.update(
|
||||
_ = await TeamRepository(prisma_client).table.update(
|
||||
where={
|
||||
"team_id": data.team_id,
|
||||
},
|
||||
@@ -2666,7 +2682,7 @@ async def team_member_delete(
|
||||
key_val["user_id"] = data.user_id
|
||||
elif data.user_email is not None:
|
||||
key_val["user_email"] = data.user_email
|
||||
existing_user_rows = await prisma_client.db.litellm_usertable.find_many(
|
||||
existing_user_rows = await UserRepository(prisma_client).table.find_many(
|
||||
where=key_val # type: ignore
|
||||
)
|
||||
|
||||
@@ -2678,7 +2694,7 @@ async def team_member_delete(
|
||||
if data.team_id in existing_user.teams:
|
||||
team_list = existing_user.teams
|
||||
team_list.remove(data.team_id)
|
||||
await prisma_client.db.litellm_usertable.update(
|
||||
await UserRepository(prisma_client).table.update(
|
||||
where={
|
||||
"user_id": existing_user.user_id,
|
||||
},
|
||||
@@ -2695,7 +2711,7 @@ async def team_member_delete(
|
||||
user_ids_to_delete.add(existing_user.user_id)
|
||||
|
||||
for _uid in user_ids_to_delete:
|
||||
await prisma_client.db.litellm_teammembership.delete_many(
|
||||
await TeamMembershipRepository(prisma_client).table.delete_many(
|
||||
where={"team_id": data.team_id, "user_id": _uid}
|
||||
)
|
||||
|
||||
@@ -2706,13 +2722,13 @@ async def team_member_delete(
|
||||
)
|
||||
|
||||
# Fetch keys before deletion to persist them
|
||||
keys_to_delete: List[LiteLLM_VerificationToken] = (
|
||||
await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
where={
|
||||
"user_id": {"in": list(user_ids_to_delete)},
|
||||
"team_id": data.team_id,
|
||||
}
|
||||
)
|
||||
keys_to_delete: List[
|
||||
LiteLLM_VerificationToken
|
||||
] = await VerificationTokenRepository(prisma_client).table.find_many(
|
||||
where={
|
||||
"user_id": {"in": list(user_ids_to_delete)},
|
||||
"team_id": data.team_id,
|
||||
}
|
||||
)
|
||||
|
||||
if keys_to_delete:
|
||||
@@ -2723,7 +2739,7 @@ async def team_member_delete(
|
||||
litellm_changed_by=None,
|
||||
)
|
||||
|
||||
await prisma_client.db.litellm_verificationtoken.delete_many(
|
||||
await VerificationTokenRepository(prisma_client).table.delete_many(
|
||||
where={
|
||||
"user_id": {"in": list(user_ids_to_delete)},
|
||||
"team_id": data.team_id,
|
||||
@@ -2818,7 +2834,7 @@ async def team_member_update(
|
||||
|
||||
_validate_budget_duration(data.budget_duration)
|
||||
|
||||
_existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
_existing_team_row = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": data.team_id}
|
||||
)
|
||||
|
||||
@@ -2921,7 +2937,7 @@ async def team_member_update(
|
||||
team_table.members_with_roles = team_members
|
||||
|
||||
_db_team_members: List[dict] = [m.model_dump() for m in team_members]
|
||||
await prisma_client.db.litellm_teamtable.update(
|
||||
await TeamRepository(prisma_client).table.update(
|
||||
where={"team_id": data.team_id},
|
||||
data={"members_with_roles": json.dumps(_db_team_members)}, # type: ignore
|
||||
)
|
||||
@@ -3052,7 +3068,7 @@ async def bulk_team_member_add(
|
||||
},
|
||||
)
|
||||
# get all users from the database
|
||||
all_users_in_db = await prisma_client.db.litellm_usertable.find_many(
|
||||
all_users_in_db = await UserRepository(prisma_client).table.find_many(
|
||||
order={"created_at": "desc"}
|
||||
)
|
||||
data.members = [
|
||||
@@ -3153,14 +3169,14 @@ async def delete_team(
|
||||
}'
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.management_helpers.audit_logs import (
|
||||
get_audit_log_changed_by,
|
||||
)
|
||||
from litellm.proxy.proxy_server import (
|
||||
create_audit_log_for_update,
|
||||
litellm_proxy_admin_name,
|
||||
prisma_client,
|
||||
)
|
||||
from litellm.proxy.management_helpers.audit_logs import (
|
||||
get_audit_log_changed_by,
|
||||
)
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail={"error": "No db connected"})
|
||||
@@ -3172,11 +3188,9 @@ async def delete_team(
|
||||
team_rows: List[LiteLLM_TeamTable] = []
|
||||
for team_id in data.team_ids:
|
||||
try:
|
||||
team_row_base: Optional[BaseModel] = (
|
||||
await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
)
|
||||
team_row_base: Optional[BaseModel] = await TeamRepository(
|
||||
prisma_client
|
||||
).table.find_unique(where={"team_id": team_id})
|
||||
if team_row_base is None:
|
||||
raise Exception
|
||||
except Exception:
|
||||
@@ -3243,11 +3257,9 @@ async def delete_team(
|
||||
_persist_deleted_verification_tokens,
|
||||
)
|
||||
|
||||
keys_to_delete: List[LiteLLM_VerificationToken] = (
|
||||
await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
where={"team_id": {"in": data.team_ids}}
|
||||
)
|
||||
)
|
||||
keys_to_delete: List[LiteLLM_VerificationToken] = await VerificationTokenRepository(
|
||||
prisma_client
|
||||
).table.find_many(where={"team_id": {"in": data.team_ids}})
|
||||
|
||||
if keys_to_delete:
|
||||
await _persist_deleted_verification_tokens(
|
||||
@@ -3338,7 +3350,7 @@ async def _save_deleted_team_records(
|
||||
"""Save deleted team records to the database."""
|
||||
if not records:
|
||||
return
|
||||
await prisma_client.db.litellm_deletedteamtable.create_many(data=records)
|
||||
await DeletedTeamRepository(prisma_client).table.create_many(data=records)
|
||||
|
||||
|
||||
async def _persist_deleted_team_records(
|
||||
@@ -3420,7 +3432,7 @@ async def _add_team_member_budget_table(
|
||||
team_info_response_object: TeamInfoResponseObjectTeamTable,
|
||||
) -> TeamInfoResponseObjectTeamTable:
|
||||
try:
|
||||
team_budget = await prisma_client.db.litellm_budgettable.find_unique(
|
||||
team_budget = await BudgetRepository(prisma_client).table.find_unique(
|
||||
where={"budget_id": team_member_budget_id}
|
||||
)
|
||||
team_info_response_object.team_member_budget_table = team_budget
|
||||
@@ -3489,11 +3501,11 @@ async def team_info(
|
||||
)
|
||||
|
||||
try:
|
||||
team_info: Optional[BaseModel] = (
|
||||
await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id},
|
||||
include={"object_permission": True},
|
||||
)
|
||||
team_info: Optional[BaseModel] = await TeamRepository(
|
||||
prisma_client
|
||||
).table.find_unique(
|
||||
where={"team_id": team_id},
|
||||
include={"object_permission": True},
|
||||
)
|
||||
if team_info is None:
|
||||
raise Exception
|
||||
@@ -3749,7 +3761,7 @@ async def block_team(
|
||||
if prisma_client is None:
|
||||
raise Exception("No DB Connected.")
|
||||
|
||||
existing_team = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
existing_team = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": data.team_id}
|
||||
)
|
||||
if existing_team is None:
|
||||
@@ -3764,7 +3776,7 @@ async def block_team(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
|
||||
record = await prisma_client.db.litellm_teamtable.update(
|
||||
record = await TeamRepository(prisma_client).table.update(
|
||||
where={"team_id": data.team_id}, data={"blocked": True} # type: ignore
|
||||
)
|
||||
|
||||
@@ -3801,7 +3813,7 @@ async def unblock_team(
|
||||
if prisma_client is None:
|
||||
raise Exception("No DB Connected.")
|
||||
|
||||
existing_team = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
existing_team = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": data.team_id}
|
||||
)
|
||||
if existing_team is None:
|
||||
@@ -3816,7 +3828,7 @@ async def unblock_team(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
|
||||
record = await prisma_client.db.litellm_teamtable.update(
|
||||
record = await TeamRepository(prisma_client).table.update(
|
||||
where={"team_id": data.team_id}, data={"blocked": False} # type: ignore
|
||||
)
|
||||
|
||||
@@ -3849,7 +3861,7 @@ async def list_available_teams(
|
||||
return []
|
||||
|
||||
# filter out teams that the user is already a member of
|
||||
user_info = await prisma_client.db.litellm_usertable.find_unique(
|
||||
user_info = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": user_api_key_dict.user_id}
|
||||
)
|
||||
if user_info is None:
|
||||
@@ -3863,7 +3875,7 @@ async def list_available_teams(
|
||||
team for team in available_teams if team not in user_info_correct_type.teams
|
||||
]
|
||||
|
||||
available_teams_db = await prisma_client.db.litellm_teamtable.find_many(
|
||||
available_teams_db = await TeamRepository(prisma_client).table.find_many(
|
||||
where={"team_id": {"in": available_teams}}
|
||||
)
|
||||
|
||||
@@ -4009,7 +4021,7 @@ async def _batch_resolve_access_group_resources(
|
||||
return {}
|
||||
|
||||
unique_ids = list(set(all_access_group_ids))
|
||||
rows = await _prisma_client.db.litellm_accessgrouptable.find_many(
|
||||
rows = await AccessGroupRepository(_prisma_client).table.find_many(
|
||||
where={"access_group_id": {"in": unique_ids}},
|
||||
)
|
||||
|
||||
@@ -4074,7 +4086,7 @@ async def _get_keys_count_by_team(
|
||||
if not page_team_ids:
|
||||
return {}
|
||||
|
||||
grouped = await prisma_client.db.litellm_verificationtoken.group_by(
|
||||
grouped = await VerificationTokenRepository(prisma_client).table.group_by(
|
||||
by=["team_id"],
|
||||
where={"team_id": {"in": page_team_ids}},
|
||||
count={"team_id": True},
|
||||
@@ -4288,25 +4300,25 @@ async def list_team_v2(
|
||||
|
||||
# Get teams with pagination
|
||||
if use_deleted_table:
|
||||
teams = await prisma_client.db.litellm_deletedteamtable.find_many(
|
||||
teams = await DeletedTeamRepository(prisma_client).table.find_many(
|
||||
where=where_conditions,
|
||||
skip=skip,
|
||||
take=page_size,
|
||||
order=order_by if order_by else {"created_at": "desc"}, # Default sort
|
||||
)
|
||||
# Get total count for pagination
|
||||
total_count = await prisma_client.db.litellm_deletedteamtable.count(
|
||||
total_count = await DeletedTeamRepository(prisma_client).table.count(
|
||||
where=where_conditions
|
||||
)
|
||||
else:
|
||||
teams = await prisma_client.db.litellm_teamtable.find_many(
|
||||
teams = await TeamRepository(prisma_client).table.find_many(
|
||||
where=where_conditions,
|
||||
skip=skip,
|
||||
take=page_size,
|
||||
order=order_by if order_by else {"created_at": "desc"}, # Default sort
|
||||
)
|
||||
# Get total count for pagination
|
||||
total_count = await prisma_client.db.litellm_teamtable.count(
|
||||
total_count = await TeamRepository(prisma_client).table.count(
|
||||
where=where_conditions
|
||||
)
|
||||
|
||||
@@ -4412,7 +4424,7 @@ async def _authorize_and_filter_teams(
|
||||
|
||||
if allowed_org_ids is not None:
|
||||
# Org admin: query DB for teams in their orgs
|
||||
org_teams = await prisma_client.db.litellm_teamtable.find_many(
|
||||
org_teams = await TeamRepository(prisma_client).table.find_many(
|
||||
where={"organization_id": {"in": allowed_org_ids}},
|
||||
include={"litellm_model_table": True},
|
||||
)
|
||||
@@ -4427,7 +4439,7 @@ async def _authorize_and_filter_teams(
|
||||
]
|
||||
elif user_id:
|
||||
# Regular user: fetch all and filter by membership (Prisma can't filter JSON arrays)
|
||||
response = await prisma_client.db.litellm_teamtable.find_many(
|
||||
response = await TeamRepository(prisma_client).table.find_many(
|
||||
include={"litellm_model_table": True}
|
||||
)
|
||||
return [
|
||||
@@ -4439,7 +4451,7 @@ async def _authorize_and_filter_teams(
|
||||
else:
|
||||
# Proxy admin: all teams
|
||||
return list(
|
||||
await prisma_client.db.litellm_teamtable.find_many(
|
||||
await TeamRepository(prisma_client).table.find_many(
|
||||
include={"litellm_model_table": True}
|
||||
)
|
||||
)
|
||||
@@ -4500,7 +4512,7 @@ async def list_team(
|
||||
_team_memberships.append(tm)
|
||||
|
||||
# add all keys that belong to the team
|
||||
keys = await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
keys = await VerificationTokenRepository(prisma_client).table.find_many(
|
||||
where={"team_id": team.team_id}
|
||||
)
|
||||
|
||||
@@ -4556,10 +4568,10 @@ async def get_paginated_teams(
|
||||
# Calculate skip for pagination
|
||||
skip = (page - 1) * page_size
|
||||
# Get total count
|
||||
total_count = await prisma_client.db.litellm_teamtable.count()
|
||||
total_count = await TeamRepository(prisma_client).table.count()
|
||||
|
||||
# Get paginated teams
|
||||
teams = await prisma_client.db.litellm_teamtable.find_many(
|
||||
teams = await TeamRepository(prisma_client).table.find_many(
|
||||
skip=skip, take=page_size, order={"team_alias": "asc"} # Sort by team_alias
|
||||
)
|
||||
return teams, total_count
|
||||
@@ -4632,7 +4644,7 @@ async def ui_view_teams(
|
||||
}
|
||||
|
||||
# Query users with pagination and filters
|
||||
teams = await prisma_client.db.litellm_teamtable.find_many(
|
||||
teams = await TeamRepository(prisma_client).table.find_many(
|
||||
where=where_conditions,
|
||||
skip=skip,
|
||||
take=page_size,
|
||||
@@ -4704,7 +4716,7 @@ async def team_model_add(
|
||||
raise HTTPException(status_code=500, detail={"error": "No db connected"})
|
||||
|
||||
# Get existing team
|
||||
team_row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
team_row = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": data.team_id}
|
||||
)
|
||||
|
||||
@@ -4737,7 +4749,7 @@ async def team_model_add(
|
||||
# null them out — see object_permission_utils.validate_key_search_tools_against_team
|
||||
# and the MCP/agent authz paths, which treat a missing object_permission
|
||||
# as "no team-level restriction".
|
||||
updated_team = await prisma_client.db.litellm_teamtable.update(
|
||||
updated_team = await TeamRepository(prisma_client).table.update(
|
||||
where={"team_id": data.team_id},
|
||||
data={"models": updated_models},
|
||||
include={"object_permission": True}, # type: ignore
|
||||
@@ -4791,7 +4803,7 @@ async def team_model_delete(
|
||||
raise HTTPException(status_code=500, detail={"error": "No db connected"})
|
||||
|
||||
# Get existing team
|
||||
team_row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
team_row = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": data.team_id}
|
||||
)
|
||||
|
||||
@@ -4825,7 +4837,7 @@ async def team_model_delete(
|
||||
updated_models = [m for m in current_models if m not in data.models]
|
||||
|
||||
# Update team. See team_model_add for the rationale on `include`.
|
||||
updated_team = await prisma_client.db.litellm_teamtable.update(
|
||||
updated_team = await TeamRepository(prisma_client).table.update(
|
||||
where={"team_id": data.team_id},
|
||||
data={"models": updated_models},
|
||||
include={"object_permission": True}, # type: ignore
|
||||
@@ -4972,7 +4984,7 @@ async def update_team_member_permissions(
|
||||
},
|
||||
)
|
||||
# Update the team member permissions
|
||||
updated_team = await prisma_client.db.litellm_teamtable.update(
|
||||
updated_team = await TeamRepository(prisma_client).table.update(
|
||||
where={"team_id": data.team_id},
|
||||
data={"team_member_permissions": data.team_member_permissions},
|
||||
)
|
||||
@@ -5076,7 +5088,7 @@ async def _append_permissions_to_specific_teams(
|
||||
prisma_client, team_ids: List[str], permissions_to_add: set
|
||||
) -> int:
|
||||
"""Fetch specific teams by ID and append permissions."""
|
||||
teams = await prisma_client.db.litellm_teamtable.find_many(
|
||||
teams = await TeamRepository(prisma_client).table.find_many(
|
||||
where={"team_id": {"in": team_ids}},
|
||||
)
|
||||
|
||||
@@ -5108,7 +5120,7 @@ async def _append_permissions_to_all_teams(
|
||||
find_args["cursor"] = {"team_id": cursor}
|
||||
find_args["skip"] = 1
|
||||
|
||||
teams = await prisma_client.db.litellm_teamtable.find_many(**find_args)
|
||||
teams = await TeamRepository(prisma_client).table.find_many(**find_args)
|
||||
|
||||
if not teams:
|
||||
break
|
||||
@@ -5214,7 +5226,7 @@ async def get_team_daily_activity(
|
||||
where_condition = {}
|
||||
if team_ids_list:
|
||||
where_condition["team_id"] = {"in": list(team_ids_list)}
|
||||
team_aliases = await prisma_client.db.litellm_teamtable.find_many(
|
||||
team_aliases = await TeamRepository(prisma_client).table.find_many(
|
||||
where=where_condition
|
||||
)
|
||||
team_alias_metadata = {
|
||||
@@ -5251,9 +5263,9 @@ async def get_team_daily_activity(
|
||||
# If user does not have full team view, filter by their API keys
|
||||
if not has_full_team_view:
|
||||
# Get all API keys for this user
|
||||
user_keys = await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
where={"user_id": user_api_key_dict.user_id}
|
||||
)
|
||||
user_keys = await VerificationTokenRepository(
|
||||
prisma_client
|
||||
).table.find_many(where={"user_id": user_api_key_dict.user_id})
|
||||
user_api_keys = [key.token for key in user_keys if key.token]
|
||||
# If user has no API keys, return empty result
|
||||
if not user_api_keys:
|
||||
|
||||
@@ -21,6 +21,15 @@ if TYPE_CHECKING:
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.repositories.object_permission_repository import ObjectPermissionRepository
|
||||
from litellm.repositories.table_repositories import (
|
||||
SpendLogsRepository,
|
||||
SpendLogToolIndexRepository,
|
||||
)
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.repositories.verification_token_repository import (
|
||||
VerificationTokenRepository,
|
||||
)
|
||||
from litellm.types.tool_management import (
|
||||
LiteLLM_ToolTableRow,
|
||||
ToolDetailResponse,
|
||||
@@ -256,8 +265,10 @@ async def get_tool_usage_logs(
|
||||
if end_time_filter is not None:
|
||||
where["start_time"]["lte"] = end_time_filter
|
||||
|
||||
total = await prisma_client.db.litellm_spendlogtoolindex.count(where=where)
|
||||
index_rows = await prisma_client.db.litellm_spendlogtoolindex.find_many(
|
||||
total = await SpendLogToolIndexRepository(prisma_client).table.count(
|
||||
where=where
|
||||
)
|
||||
index_rows = await SpendLogToolIndexRepository(prisma_client).table.find_many(
|
||||
where=where,
|
||||
order={"start_time": "desc"},
|
||||
skip=(page - 1) * page_size,
|
||||
@@ -269,7 +280,7 @@ async def get_tool_usage_logs(
|
||||
logs=[], total=total, page=page, page_size=page_size
|
||||
)
|
||||
|
||||
spend_logs = await prisma_client.db.litellm_spendlogs.find_many(
|
||||
spend_logs = await SpendLogsRepository(prisma_client).table.find_many(
|
||||
where={"request_id": {"in": request_ids}}
|
||||
)
|
||||
log_by_id = {s.request_id: s for s in spend_logs}
|
||||
@@ -348,7 +359,7 @@ async def _resolve_key_hash_to_object_permission_id(
|
||||
hashed = key_hash if "sk-" not in (key_hash or "") else hash_token(key_hash)
|
||||
if not hashed:
|
||||
return None
|
||||
row = await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||
row = await VerificationTokenRepository(prisma_client).table.find_unique(
|
||||
where={"token": hashed}
|
||||
)
|
||||
if row is None:
|
||||
@@ -357,18 +368,18 @@ async def _resolve_key_hash_to_object_permission_id(
|
||||
if op_id:
|
||||
return op_id
|
||||
new_id = str(uuid.uuid4())
|
||||
await prisma_client.db.litellm_objectpermissiontable.create(
|
||||
await ObjectPermissionRepository(prisma_client).table.create(
|
||||
data={"object_permission_id": new_id, "blocked_tools": []}
|
||||
)
|
||||
updated_count = await prisma_client.db.litellm_verificationtoken.update_many(
|
||||
updated_count = await VerificationTokenRepository(prisma_client).table.update_many(
|
||||
where={"token": hashed, "object_permission_id": None},
|
||||
data={"object_permission_id": new_id},
|
||||
)
|
||||
if updated_count == 0:
|
||||
await prisma_client.db.litellm_objectpermissiontable.delete(
|
||||
await ObjectPermissionRepository(prisma_client).table.delete(
|
||||
where={"object_permission_id": new_id}
|
||||
)
|
||||
row = await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||
row = await VerificationTokenRepository(prisma_client).table.find_unique(
|
||||
where={"token": hashed}
|
||||
)
|
||||
return getattr(row, "object_permission_id", None) if row else None
|
||||
@@ -383,7 +394,7 @@ async def _resolve_team_id_to_object_permission_id(
|
||||
if not team_id or not team_id.strip():
|
||||
return None
|
||||
team_id_clean = team_id.strip()
|
||||
row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
row = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": team_id_clean},
|
||||
select={"object_permission_id": True},
|
||||
)
|
||||
@@ -393,18 +404,18 @@ async def _resolve_team_id_to_object_permission_id(
|
||||
if op_id:
|
||||
return op_id
|
||||
new_id = str(uuid.uuid4())
|
||||
await prisma_client.db.litellm_objectpermissiontable.create(
|
||||
await ObjectPermissionRepository(prisma_client).table.create(
|
||||
data={"object_permission_id": new_id, "blocked_tools": []}
|
||||
)
|
||||
updated_count = await prisma_client.db.litellm_teamtable.update_many(
|
||||
updated_count = await TeamRepository(prisma_client).table.update_many(
|
||||
where={"team_id": team_id_clean, "object_permission_id": None},
|
||||
data={"object_permission_id": new_id},
|
||||
)
|
||||
if updated_count == 0:
|
||||
await prisma_client.db.litellm_objectpermissiontable.delete(
|
||||
await ObjectPermissionRepository(prisma_client).table.delete(
|
||||
where={"object_permission_id": new_id}
|
||||
)
|
||||
row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
row = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": team_id_clean},
|
||||
select={"object_permission_id": True},
|
||||
)
|
||||
|
||||
@@ -15,8 +15,8 @@ import inspect
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
from html import escape
|
||||
from copy import deepcopy
|
||||
from html import escape
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@@ -39,9 +39,9 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
import litellm
|
||||
from litellm.caching.dual_cache import DualCache
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.caching.dual_cache import DualCache
|
||||
from litellm.constants import (
|
||||
CLI_SSO_CLAIM_MAP,
|
||||
CLI_SSO_CLAIM_MAX_SCALAR_LENGTH,
|
||||
@@ -77,7 +77,6 @@ from litellm.proxy._types import (
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.auth_checks import ExperimentalUIJWTToken, get_user_object
|
||||
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
|
||||
from litellm.proxy.auth.auth_utils import (
|
||||
_get_request_ip_address,
|
||||
_has_user_setup_sso,
|
||||
@@ -92,6 +91,7 @@ from litellm.proxy.common_utils.html_forms.jwt_display_template import (
|
||||
jwt_display_template,
|
||||
)
|
||||
from litellm.proxy.common_utils.html_forms.ui_login import html_form
|
||||
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
|
||||
from litellm.proxy.management_endpoints.sso import CustomMicrosoftSSO
|
||||
from litellm.proxy.management_endpoints.sso_helper_utils import (
|
||||
@@ -110,6 +110,9 @@ from litellm.proxy.utils import (
|
||||
get_custom_url,
|
||||
get_server_root_path,
|
||||
)
|
||||
from litellm.repositories.table_repositories import SSOConfigRepository
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.repositories.user_repository import UserRepository
|
||||
from litellm.secret_managers.main import get_secret_bool, str_to_bool
|
||||
from litellm.types.proxy.management_endpoints.ui_sso import * # noqa: F403, F401
|
||||
from litellm.types.proxy.management_endpoints.ui_sso import (
|
||||
@@ -438,7 +441,7 @@ async def _persist_cli_sso_user_metadata(
|
||||
return
|
||||
|
||||
try:
|
||||
user_row = await prisma_client.db.litellm_usertable.find_unique(
|
||||
user_row = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": user_id}
|
||||
)
|
||||
existing_metadata: Dict[str, Any] = {}
|
||||
@@ -451,7 +454,7 @@ async def _persist_cli_sso_user_metadata(
|
||||
existing_metadata=existing_metadata,
|
||||
attribution_metadata=attribution_metadata,
|
||||
)
|
||||
await prisma_client.db.litellm_usertable.update_many(
|
||||
await UserRepository(prisma_client).table.update_many(
|
||||
where={"user_id": user_id},
|
||||
data={"metadata": merged_metadata},
|
||||
)
|
||||
@@ -859,7 +862,7 @@ async def google_login(
|
||||
if premium_user is not True:
|
||||
# Check if under 'free SSO user' limit
|
||||
if prisma_client is not None:
|
||||
total_users = await prisma_client.db.litellm_usertable.count()
|
||||
total_users = await UserRepository(prisma_client).table.count()
|
||||
if total_users and total_users > 5:
|
||||
raise ProxyException(
|
||||
message="You must be a LiteLLM Enterprise user to use SSO for more than 5 users. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://enterprise.litellm.ai/demo You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this",
|
||||
@@ -1150,7 +1153,7 @@ async def _setup_team_mappings() -> Optional["TeamMappings"]:
|
||||
"Prisma client is None, connect a database to your proxy"
|
||||
)
|
||||
|
||||
sso_db_record = await prisma_client.db.litellm_ssoconfig.find_unique(
|
||||
sso_db_record = await SSOConfigRepository(prisma_client).table.find_unique(
|
||||
where={"id": "sso_config"}
|
||||
)
|
||||
|
||||
@@ -1188,7 +1191,7 @@ async def _setup_role_mappings() -> Optional["RoleMappings"]:
|
||||
"Prisma client is None, connect a database to your proxy"
|
||||
)
|
||||
|
||||
sso_db_record = await prisma_client.db.litellm_ssoconfig.find_unique(
|
||||
sso_db_record = await SSOConfigRepository(prisma_client).table.find_unique(
|
||||
where={"id": "sso_config"}
|
||||
)
|
||||
|
||||
@@ -1755,7 +1758,7 @@ async def _sync_user_role_from_jwt_role_map(
|
||||
|
||||
# Update existing DB record if role differs
|
||||
if user_info is not None and user_info.user_role != mapped_role.value:
|
||||
await prisma_client.db.litellm_usertable.update(
|
||||
await UserRepository(prisma_client).table.update(
|
||||
where={"user_id": user_info.user_id},
|
||||
data={"user_role": mapped_role.value},
|
||||
)
|
||||
@@ -1819,7 +1822,7 @@ async def check_and_update_if_proxy_admin_id(
|
||||
return user_role
|
||||
|
||||
if prisma_client:
|
||||
await prisma_client.db.litellm_usertable.update(
|
||||
await UserRepository(prisma_client).table.update(
|
||||
where={"user_id": user_id},
|
||||
data={"user_role": LitellmUserRoles.PROXY_ADMIN.value},
|
||||
)
|
||||
@@ -1976,7 +1979,7 @@ async def _fetch_cli_sso_team_details(
|
||||
team_details: List[Dict[str, Any]] = []
|
||||
try:
|
||||
if teams:
|
||||
prisma_teams = await prisma_client.db.litellm_teamtable.find_many(
|
||||
prisma_teams = await TeamRepository(prisma_client).table.find_many(
|
||||
where={"team_id": {"in": teams}}
|
||||
)
|
||||
for team_row in prisma_teams:
|
||||
@@ -2884,7 +2887,7 @@ class SSOAuthenticationHandler:
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
await prisma_client.db.litellm_usertable.update_many(
|
||||
await UserRepository(prisma_client).table.update_many(
|
||||
where={"user_id": user_id}, data=update_data
|
||||
)
|
||||
else:
|
||||
@@ -2986,7 +2989,7 @@ class SSOAuthenticationHandler:
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
try:
|
||||
team_obj = await prisma_client.db.litellm_teamtable.find_first(
|
||||
team_obj = await TeamRepository(prisma_client).table.find_first(
|
||||
where={"team_id": litellm_team_id}
|
||||
)
|
||||
verbose_proxy_logger.debug(f"Team object: {team_obj}")
|
||||
|
||||
@@ -19,6 +19,11 @@ from pydantic import BaseModel
|
||||
|
||||
from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.repositories.table_repositories import DailyTagSpendRepository
|
||||
from litellm.repositories.user_repository import UserRepository
|
||||
from litellm.repositories.verification_token_repository import (
|
||||
VerificationTokenRepository,
|
||||
)
|
||||
|
||||
# Constants for analytics periods
|
||||
MAX_DAYS = 7 # Number of days to show in DAU analytics
|
||||
@@ -676,7 +681,7 @@ async def get_per_user_analytics(
|
||||
where_clause["tag"] = {"contains": tag_filter}
|
||||
|
||||
# Get all tag records in the date range with optional tag filtering
|
||||
tag_records = await prisma_client.db.litellm_dailytagspend.find_many(
|
||||
tag_records = await DailyTagSpendRepository(prisma_client).table.find_many(
|
||||
where=where_clause
|
||||
)
|
||||
|
||||
@@ -693,9 +698,9 @@ async def get_per_user_analytics(
|
||||
)
|
||||
|
||||
# Lookup user_id for each api_key
|
||||
api_key_records = await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
where={"token": {"in": list(api_keys)}}
|
||||
)
|
||||
api_key_records = await VerificationTokenRepository(
|
||||
prisma_client
|
||||
).table.find_many(where={"token": {"in": list(api_keys)}})
|
||||
|
||||
# Create mapping from api_key to user_id
|
||||
api_key_to_user_id = {
|
||||
@@ -704,7 +709,7 @@ async def get_per_user_analytics(
|
||||
|
||||
# Get user emails for the user_ids
|
||||
user_ids = list(set(api_key_to_user_id.values()))
|
||||
user_records = await prisma_client.db.litellm_usertable.find_many(
|
||||
user_records = await UserRepository(prisma_client).table.find_many(
|
||||
where={"user_id": {"in": user_ids}}
|
||||
)
|
||||
|
||||
|
||||
@@ -27,6 +27,11 @@ from pydantic import BaseModel
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.repositories.table_repositories import (
|
||||
WorkflowEventRepository,
|
||||
WorkflowMessageRepository,
|
||||
WorkflowRunRepository,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -96,13 +101,13 @@ class WorkflowMessageCreateRequest(BaseModel):
|
||||
async def _get_next_sequence_number(prisma_client: Any, run_id: str, table: str) -> int:
|
||||
"""Return MAX(sequence_number) + 1 for the given run, for either events or messages."""
|
||||
if table == "events":
|
||||
rows = await prisma_client.db.litellm_workflowevent.find_many(
|
||||
rows = await WorkflowEventRepository(prisma_client).table.find_many(
|
||||
where={"run_id": run_id},
|
||||
order={"sequence_number": "desc"},
|
||||
take=1,
|
||||
)
|
||||
else:
|
||||
rows = await prisma_client.db.litellm_workflowmessage.find_many(
|
||||
rows = await WorkflowMessageRepository(prisma_client).table.find_many(
|
||||
where={"run_id": run_id},
|
||||
order={"sequence_number": "desc"},
|
||||
take=1,
|
||||
@@ -116,7 +121,7 @@ async def _require_run(
|
||||
user_api_key_dict: Optional[UserAPIKeyAuth] = None,
|
||||
) -> Any:
|
||||
"""Return the run or raise 404. For non-admin callers, also enforce key ownership."""
|
||||
run = await prisma_client.db.litellm_workflowrun.find_unique(
|
||||
run = await WorkflowRunRepository(prisma_client).table.find_unique(
|
||||
where={"run_id": run_id}
|
||||
)
|
||||
if run is None:
|
||||
@@ -163,7 +168,7 @@ async def create_workflow_run(
|
||||
create_data["input"] = _json(data.input)
|
||||
if data.metadata is not None:
|
||||
create_data["metadata"] = _json(data.metadata)
|
||||
run = await prisma_client.db.litellm_workflowrun.create(data=create_data)
|
||||
run = await WorkflowRunRepository(prisma_client).table.create(data=create_data)
|
||||
return run
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception("Error creating workflow run: %s", e)
|
||||
@@ -206,7 +211,7 @@ async def list_workflow_runs(
|
||||
where["created_by"] = caller
|
||||
|
||||
try:
|
||||
runs = await prisma_client.db.litellm_workflowrun.find_many(
|
||||
runs = await WorkflowRunRepository(prisma_client).table.find_many(
|
||||
where=where,
|
||||
order={"created_at": "desc"},
|
||||
take=limit,
|
||||
@@ -235,7 +240,7 @@ async def get_workflow_run(
|
||||
)
|
||||
|
||||
try:
|
||||
run = await prisma_client.db.litellm_workflowrun.find_unique(
|
||||
run = await WorkflowRunRepository(prisma_client).table.find_unique(
|
||||
where={"run_id": run_id},
|
||||
include={"events": {"order_by": {"sequence_number": "desc"}, "take": 1}},
|
||||
)
|
||||
@@ -286,7 +291,7 @@ async def update_workflow_run(
|
||||
await _require_run(prisma_client, run_id, user_api_key_dict)
|
||||
|
||||
try:
|
||||
run = await prisma_client.db.litellm_workflowrun.update(
|
||||
run = await WorkflowRunRepository(prisma_client).table.update(
|
||||
where={"run_id": run_id},
|
||||
data=update,
|
||||
)
|
||||
@@ -391,7 +396,7 @@ async def list_workflow_events(
|
||||
await _require_run(prisma_client, run_id, user_api_key_dict)
|
||||
|
||||
try:
|
||||
events = await prisma_client.db.litellm_workflowevent.find_many(
|
||||
events = await WorkflowEventRepository(prisma_client).table.find_many(
|
||||
where={"run_id": run_id},
|
||||
order={"sequence_number": "asc"},
|
||||
take=limit,
|
||||
@@ -436,7 +441,9 @@ async def append_workflow_message(
|
||||
}
|
||||
if data.session_id is not None:
|
||||
msg_data["session_id"] = data.session_id
|
||||
msg = await prisma_client.db.litellm_workflowmessage.create(data=msg_data)
|
||||
msg = await WorkflowMessageRepository(prisma_client).table.create(
|
||||
data=msg_data
|
||||
)
|
||||
return msg
|
||||
|
||||
except Exception as e:
|
||||
@@ -481,7 +488,7 @@ async def list_workflow_messages(
|
||||
await _require_run(prisma_client, run_id, user_api_key_dict)
|
||||
|
||||
try:
|
||||
messages = await prisma_client.db.litellm_workflowmessage.find_many(
|
||||
messages = await WorkflowMessageRepository(prisma_client).table.find_many(
|
||||
where={"run_id": run_id},
|
||||
order={"sequence_number": "asc"},
|
||||
take=limit,
|
||||
|
||||
@@ -18,6 +18,7 @@ from litellm.proxy._types import (
|
||||
Optional,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.repositories.table_repositories import AuditLogRepository
|
||||
from litellm.types.utils import StandardAuditLogPayload
|
||||
|
||||
_audit_log_callback_cache: Dict[str, CustomLogger] = {}
|
||||
@@ -244,7 +245,7 @@ async def create_audit_log_for_update(request_data: LiteLLM_AuditLogs):
|
||||
_request_data = request_data.model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
await prisma_client.db.litellm_auditlog.create(
|
||||
await AuditLogRepository(prisma_client).table.create(
|
||||
data={
|
||||
**_request_data, # type: ignore
|
||||
}
|
||||
|
||||
@@ -12,6 +12,8 @@ from litellm._logging import verbose_proxy_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.repositories.object_permission_repository import ObjectPermissionRepository
|
||||
from litellm.repositories.table_repositories import MCPServerRepository
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy._types import (
|
||||
@@ -48,10 +50,10 @@ async def attach_object_permission_to_dict(
|
||||
|
||||
object_permission_id = data_dict.get("object_permission_id")
|
||||
if object_permission_id:
|
||||
object_permission = (
|
||||
await prisma_client.db.litellm_objectpermissiontable.find_unique(
|
||||
where={"object_permission_id": object_permission_id},
|
||||
)
|
||||
object_permission = await ObjectPermissionRepository(
|
||||
prisma_client
|
||||
).table.find_unique(
|
||||
where={"object_permission_id": object_permission_id},
|
||||
)
|
||||
if object_permission:
|
||||
# Convert to dict if needed
|
||||
@@ -106,10 +108,10 @@ async def handle_update_object_permission_common(
|
||||
)
|
||||
existing_object_permissions_dict: Dict = {}
|
||||
|
||||
existing_object_permission = (
|
||||
await prisma_client.db.litellm_objectpermissiontable.find_unique(
|
||||
where={"object_permission_id": object_permission_id_to_use},
|
||||
)
|
||||
existing_object_permission = await ObjectPermissionRepository(
|
||||
prisma_client
|
||||
).table.find_unique(
|
||||
where={"object_permission_id": object_permission_id_to_use},
|
||||
)
|
||||
|
||||
# Update the object permission
|
||||
@@ -137,14 +139,14 @@ async def handle_update_object_permission_common(
|
||||
#########################################################
|
||||
# Commit the update to the LiteLLM_ObjectPermissionTable
|
||||
#########################################################
|
||||
created_object_permission_row = (
|
||||
await prisma_client.db.litellm_objectpermissiontable.upsert(
|
||||
where={"object_permission_id": object_permission_id_to_use},
|
||||
data={
|
||||
"create": existing_object_permissions_dict,
|
||||
"update": existing_object_permissions_dict,
|
||||
},
|
||||
)
|
||||
created_object_permission_row = await ObjectPermissionRepository(
|
||||
prisma_client
|
||||
).table.upsert(
|
||||
where={"object_permission_id": object_permission_id_to_use},
|
||||
data={
|
||||
"create": existing_object_permissions_dict,
|
||||
"update": existing_object_permissions_dict,
|
||||
},
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
@@ -183,7 +185,7 @@ async def _set_object_permission(
|
||||
clean_data["mcp_tool_permissions"]
|
||||
)
|
||||
|
||||
created_permission = await prisma_client.db.litellm_objectpermissiontable.create(
|
||||
created_permission = await ObjectPermissionRepository(prisma_client).table.create(
|
||||
data=clean_data
|
||||
)
|
||||
|
||||
@@ -220,7 +222,7 @@ async def _get_db_mcp_servers_by_identifiers(
|
||||
return []
|
||||
|
||||
identifier_list = list(identifiers)
|
||||
return await prisma_client.db.litellm_mcpservertable.find_many(
|
||||
return await MCPServerRepository(prisma_client).table.find_many(
|
||||
where={
|
||||
"OR": [
|
||||
{"server_id": {"in": identifier_list}},
|
||||
|
||||
@@ -4,6 +4,7 @@ from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm.proxy._types import CommonProxyErrors, InvitationNew, UserAPIKeyAuth
|
||||
from litellm.repositories.table_repositories import InvitationLinkRepository
|
||||
|
||||
|
||||
async def create_invitation_for_user(
|
||||
@@ -25,7 +26,7 @@ async def create_invitation_for_user(
|
||||
expires_at = current_time + timedelta(days=7)
|
||||
|
||||
try:
|
||||
response = await prisma_client.db.litellm_invitationlink.create(
|
||||
response = await InvitationLinkRepository(prisma_client).table.create(
|
||||
data={
|
||||
"user_id": data.user_id,
|
||||
"created_at": current_time,
|
||||
|
||||
@@ -9,9 +9,8 @@ from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.otel.model.config import is_otel_v2_enabled
|
||||
from litellm._uuid import uuid
|
||||
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
|
||||
from litellm.integrations.otel.model.config import is_otel_v2_enabled
|
||||
from litellm.proxy._types import ( # key request types; user request types; team request types; customer request types
|
||||
BudgetNewRequest,
|
||||
DeleteCustomerRequest,
|
||||
@@ -32,7 +31,11 @@ from litellm.proxy._types import ( # key request types; user request types; tea
|
||||
VirtualKeyEvent,
|
||||
)
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.repositories.budget_repository import BudgetRepository
|
||||
from litellm.repositories.table_repositories import TeamMembershipRepository
|
||||
from litellm.repositories.user_repository import UserRepository
|
||||
|
||||
|
||||
def get_new_internal_user_defaults(
|
||||
@@ -111,7 +114,7 @@ async def handle_budget_for_entity(
|
||||
budget_row.model_dump(exclude_none=True)
|
||||
)
|
||||
|
||||
_budget = await prisma_client.db.litellm_budgettable.create(
|
||||
_budget = await BudgetRepository(prisma_client).table.create(
|
||||
data={
|
||||
**new_budget_data, # type: ignore
|
||||
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
@@ -174,7 +177,7 @@ async def _clone_team_default_budget_for_member(
|
||||
so the member starts with the team default's values but gets their own
|
||||
private budget row (which can be edited independently).
|
||||
"""
|
||||
default_budget = await prisma_client.db.litellm_budgettable.find_unique(
|
||||
default_budget = await BudgetRepository(prisma_client).table.find_unique(
|
||||
where={"budget_id": default_team_budget_id}
|
||||
)
|
||||
if default_budget is None:
|
||||
@@ -202,7 +205,7 @@ async def _clone_team_default_budget_for_member(
|
||||
cloned_data["budget_duration"]
|
||||
)
|
||||
|
||||
new_budget = await prisma_client.db.litellm_budgettable.create(data=cloned_data)
|
||||
new_budget = await BudgetRepository(prisma_client).table.create(data=cloned_data)
|
||||
return new_budget.budget_id
|
||||
|
||||
|
||||
@@ -229,7 +232,7 @@ async def add_new_member(
|
||||
## ADD TEAM ID, to USER TABLE IF NEW ##
|
||||
if new_member.user_id is not None:
|
||||
new_user_defaults = get_new_internal_user_defaults(user_id=new_member.user_id)
|
||||
_returned_user = await prisma_client.db.litellm_usertable.upsert(
|
||||
_returned_user = await UserRepository(prisma_client).table.upsert(
|
||||
where={"user_id": new_member.user_id},
|
||||
data={
|
||||
"update": {"teams": {"push": [team_id]}},
|
||||
@@ -259,7 +262,7 @@ async def add_new_member(
|
||||
returned_user = LiteLLM_UserTable(**_returned_user.model_dump())
|
||||
elif len(existing_user_row) == 1:
|
||||
user_info = existing_user_row[0]
|
||||
_returned_user = await prisma_client.db.litellm_usertable.update(
|
||||
_returned_user = await UserRepository(prisma_client).table.update(
|
||||
where={"user_id": user_info.user_id}, # type: ignore
|
||||
data={"teams": {"push": [team_id]}},
|
||||
)
|
||||
@@ -284,7 +287,7 @@ async def add_new_member(
|
||||
budget_data["max_budget"] = max_budget_in_team
|
||||
if allowed_models is not None:
|
||||
budget_data["allowed_models"] = allowed_models
|
||||
response = await prisma_client.db.litellm_budgettable.create(data=budget_data)
|
||||
response = await BudgetRepository(prisma_client).table.create(data=budget_data)
|
||||
|
||||
_budget_id = response.budget_id
|
||||
elif default_team_budget_id is not None:
|
||||
@@ -303,15 +306,15 @@ async def add_new_member(
|
||||
_budget_id = None
|
||||
|
||||
if _budget_id and returned_user is not None and returned_user.user_id is not None:
|
||||
_returned_team_membership = (
|
||||
await prisma_client.db.litellm_teammembership.create(
|
||||
data={
|
||||
"team_id": team_id,
|
||||
"user_id": returned_user.user_id,
|
||||
"budget_id": _budget_id,
|
||||
},
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
_returned_team_membership = await TeamMembershipRepository(
|
||||
prisma_client
|
||||
).table.create(
|
||||
data={
|
||||
"team_id": team_id,
|
||||
"user_id": returned_user.user_id,
|
||||
"budget_id": _budget_id,
|
||||
},
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
|
||||
returned_team_membership = LiteLLM_TeamMembership(
|
||||
|
||||
@@ -29,6 +29,8 @@ from litellm.proxy._types import (
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.repositories.table_repositories import MemoryRepository
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.types.memory_management import (
|
||||
LiteLLM_MemoryRow,
|
||||
MemoryCreateRequest,
|
||||
@@ -173,7 +175,7 @@ async def _is_team_admin_for(
|
||||
)
|
||||
|
||||
try:
|
||||
team_obj = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
team_obj = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -304,7 +306,7 @@ async def create_memory(
|
||||
create_data["metadata"] = _serialize_metadata_for_prisma(body.metadata)
|
||||
|
||||
try:
|
||||
row = await prisma_client.db.litellm_memorytable.create(data=create_data)
|
||||
row = await MemoryRepository(prisma_client).table.create(data=create_data)
|
||||
except Exception as e:
|
||||
# Key is globally unique. Any duplicate → 409.
|
||||
if _is_unique_violation(e):
|
||||
@@ -364,8 +366,8 @@ async def list_memory(
|
||||
where = {"AND": [key_filter, vis]}
|
||||
|
||||
try:
|
||||
total = await prisma_client.db.litellm_memorytable.count(where=where)
|
||||
rows = await prisma_client.db.litellm_memorytable.find_many(
|
||||
total = await MemoryRepository(prisma_client).table.count(where=where)
|
||||
rows = await MemoryRepository(prisma_client).table.find_many(
|
||||
where=where,
|
||||
order={"updated_at": "desc"},
|
||||
skip=(page - 1) * page_size,
|
||||
@@ -386,7 +388,7 @@ async def _find_memory_for_caller(
|
||||
key_filter: dict = {"key": key}
|
||||
vis = _visibility_filter(user_api_key_dict)
|
||||
where: dict = key_filter if vis is None else {"AND": [key_filter, vis]}
|
||||
rows = await prisma_client.db.litellm_memorytable.find_many(
|
||||
rows = await MemoryRepository(prisma_client).table.find_many(
|
||||
where=where, take=1, order={"updated_at": "desc"}
|
||||
)
|
||||
if not rows:
|
||||
@@ -475,7 +477,7 @@ async def upsert_memory(
|
||||
# their team) — otherwise a teammate could overwrite a personal
|
||||
# entry through the OR-based visibility filter.
|
||||
await _assert_write_access(prisma_client, existing, user_api_key_dict)
|
||||
row = await prisma_client.db.litellm_memorytable.update(
|
||||
row = await MemoryRepository(prisma_client).table.update(
|
||||
where={"memory_id": existing.memory_id},
|
||||
data=data,
|
||||
)
|
||||
@@ -503,7 +505,7 @@ async def upsert_memory(
|
||||
if body.metadata is not None:
|
||||
create_data["metadata"] = _serialize_metadata_for_prisma(body.metadata)
|
||||
try:
|
||||
row = await prisma_client.db.litellm_memorytable.create(
|
||||
row = await MemoryRepository(prisma_client).table.create(
|
||||
data=create_data
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -524,7 +526,7 @@ async def upsert_memory(
|
||||
await _assert_write_access(
|
||||
prisma_client, existing_after_race, user_api_key_dict
|
||||
)
|
||||
row = await prisma_client.db.litellm_memorytable.update(
|
||||
row = await MemoryRepository(prisma_client).table.update(
|
||||
where={"memory_id": existing_after_race.memory_id},
|
||||
data=data,
|
||||
)
|
||||
@@ -554,7 +556,7 @@ async def delete_memory(
|
||||
# Visibility != write authority — see the upsert handler for the rationale.
|
||||
await _assert_write_access(prisma_client, row, user_api_key_dict)
|
||||
try:
|
||||
await prisma_client.db.litellm_memorytable.delete(
|
||||
await MemoryRepository(prisma_client).table.delete(
|
||||
where={"memory_id": row.memory_id}
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -5,6 +5,10 @@ from dataclasses import dataclass, field
|
||||
from types import MappingProxyType
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional, Union
|
||||
|
||||
from litellm.repositories.table_repositories import (
|
||||
ManagedFileRepository,
|
||||
ManagedObjectRepository,
|
||||
)
|
||||
from litellm.types.utils import SpecialEnums
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -697,7 +701,7 @@ async def resolve_input_file_id_to_unified(response, prisma_client) -> None:
|
||||
and prisma_client
|
||||
):
|
||||
try:
|
||||
managed_file = await prisma_client.db.litellm_managedfiletable.find_first(
|
||||
managed_file = await ManagedFileRepository(prisma_client).table.find_first(
|
||||
where={"flat_model_file_ids": {"has": response.input_file_id}}
|
||||
)
|
||||
if managed_file:
|
||||
@@ -719,7 +723,7 @@ async def resolve_output_file_ids_to_unified(response, prisma_client) -> None:
|
||||
if not raw_id or _is_base64_encoded_unified_file_id(raw_id):
|
||||
continue
|
||||
try:
|
||||
managed_file = await prisma_client.db.litellm_managedfiletable.find_first(
|
||||
managed_file = await ManagedFileRepository(prisma_client).table.find_first(
|
||||
where={"flat_model_file_ids": {"has": raw_id}}
|
||||
)
|
||||
if managed_file:
|
||||
@@ -821,6 +825,7 @@ async def get_batch_from_database(
|
||||
- response_batch: Parsed LiteLLMBatch object (or None)
|
||||
"""
|
||||
import json
|
||||
|
||||
from litellm.types.utils import LiteLLMBatch
|
||||
|
||||
if managed_files_obj is None or not unified_batch_id:
|
||||
@@ -830,7 +835,7 @@ async def get_batch_from_database(
|
||||
if not prisma_client:
|
||||
return None, None
|
||||
|
||||
db_batch_object = await prisma_client.db.litellm_managedobjecttable.find_first(
|
||||
db_batch_object = await ManagedObjectRepository(prisma_client).table.find_first(
|
||||
where={"unified_object_id": batch_id}
|
||||
)
|
||||
|
||||
@@ -942,7 +947,7 @@ async def update_batch_in_database(
|
||||
update_data["batch_processed"] = True
|
||||
|
||||
try:
|
||||
await prisma_client.db.litellm_managedobjecttable.update(
|
||||
await ManagedObjectRepository(prisma_client).table.update(
|
||||
where={"unified_object_id": batch_id},
|
||||
data=update_data,
|
||||
)
|
||||
@@ -958,7 +963,7 @@ async def update_batch_in_database(
|
||||
f"batch_processed column not found, retrying update without it: {col_err}"
|
||||
)
|
||||
update_data.pop("batch_processed", None)
|
||||
await prisma_client.db.litellm_managedobjecttable.update(
|
||||
await ManagedObjectRepository(prisma_client).table.update(
|
||||
where={"unified_object_id": batch_id},
|
||||
data=update_data,
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ from fastapi import (
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
|
||||
import litellm
|
||||
from litellm import CreateFileRequest, get_secret_str
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
@@ -37,15 +38,6 @@ from litellm.proxy.common_utils.openai_endpoint_utils import (
|
||||
get_custom_llm_provider_from_request_headers,
|
||||
get_custom_llm_provider_from_request_query,
|
||||
)
|
||||
from litellm.proxy.utils import ProxyLogging, is_known_model
|
||||
from litellm.router import Router
|
||||
from litellm.types.llms.openai import (
|
||||
CREATE_FILE_REQUESTS_PURPOSE,
|
||||
FileExpiresAfter,
|
||||
OpenAIFileObject,
|
||||
OpenAIFilesPurpose,
|
||||
)
|
||||
|
||||
from litellm.proxy.openai_files_endpoints.common_utils import (
|
||||
_is_base64_encoded_unified_file_id,
|
||||
encode_file_id_with_model,
|
||||
@@ -54,6 +46,15 @@ from litellm.proxy.openai_files_endpoints.common_utils import (
|
||||
handle_model_based_routing,
|
||||
prepare_data_with_credentials,
|
||||
)
|
||||
from litellm.proxy.utils import ProxyLogging, is_known_model
|
||||
from litellm.repositories.table_repositories import ManagedFileRepository
|
||||
from litellm.router import Router
|
||||
from litellm.types.llms.openai import (
|
||||
CREATE_FILE_REQUESTS_PURPOSE,
|
||||
FileExpiresAfter,
|
||||
OpenAIFileObject,
|
||||
OpenAIFilesPurpose,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -666,7 +667,7 @@ async def get_file_content( # noqa: PLR0915
|
||||
managed_files_obj, "prisma_client", None
|
||||
):
|
||||
prisma_client = getattr(managed_files_obj, "prisma_client")
|
||||
db_file = await prisma_client.db.litellm_managedfiletable.find_first(
|
||||
db_file = await ManagedFileRepository(prisma_client).table.find_first(
|
||||
where={"unified_file_id": file_id}
|
||||
)
|
||||
if db_file and db_file.storage_backend and db_file.storage_url:
|
||||
|
||||
@@ -43,6 +43,10 @@ from litellm.llms.base_llm.managed_resources.isolation import (
|
||||
can_access_resource,
|
||||
)
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.repositories.table_repositories import (
|
||||
ManagedFileRepository,
|
||||
ManagedObjectRepository,
|
||||
)
|
||||
from litellm.types.llms.openai import OpenAIFileObject
|
||||
|
||||
from .managed_id_codec import ManagedIdPayload, decode, is_managed, new_managed_id
|
||||
@@ -323,7 +327,7 @@ async def _resolve_one(
|
||||
)
|
||||
if not found and prisma_client is not None:
|
||||
try:
|
||||
db_row = await prisma_client.db.litellm_managedfiletable.find_first(
|
||||
db_row = await ManagedFileRepository(prisma_client).table.find_first(
|
||||
where={"unified_file_id": managed_id}
|
||||
)
|
||||
if db_row is not None:
|
||||
@@ -339,7 +343,7 @@ async def _resolve_one(
|
||||
# Object table (batches, responses)
|
||||
if prisma_client is not None:
|
||||
try:
|
||||
obj_row = await prisma_client.db.litellm_managedobjecttable.find_first(
|
||||
obj_row = await ManagedObjectRepository(prisma_client).table.find_first(
|
||||
where={"unified_object_id": managed_id}
|
||||
)
|
||||
if obj_row is not None:
|
||||
@@ -399,7 +403,7 @@ async def _guard_raw_provider_id(
|
||||
# id and scope to the current provider in the application layer (same as
|
||||
# _mint_or_reuse_file's dedup).
|
||||
try:
|
||||
candidates = await prisma_client.db.litellm_managedfiletable.find_many(
|
||||
candidates = await ManagedFileRepository(prisma_client).table.find_many(
|
||||
where={"flat_model_file_ids": {"has": raw_id}},
|
||||
)
|
||||
except Exception:
|
||||
@@ -425,7 +429,7 @@ async def _guard_raw_provider_id(
|
||||
# Object rows store model_object_id as "passthrough:{provider}:{raw}", so
|
||||
# the lookup is exact and already provider-scoped.
|
||||
try:
|
||||
existing = await prisma_client.db.litellm_managedobjecttable.find_first(
|
||||
existing = await ManagedObjectRepository(prisma_client).table.find_first(
|
||||
where={"model_object_id": f"passthrough:{provider}:{raw_id}"}
|
||||
)
|
||||
except Exception:
|
||||
@@ -492,7 +496,7 @@ async def _mint_or_reuse_file(
|
||||
# reuse a stable row instead of minting duplicate rows on every call.
|
||||
if prisma_client is not None:
|
||||
try:
|
||||
candidates = await prisma_client.db.litellm_managedfiletable.find_many(
|
||||
candidates = await ManagedFileRepository(prisma_client).table.find_many(
|
||||
where={"flat_model_file_ids": {"has": raw_id}},
|
||||
order={"created_at": "asc"},
|
||||
)
|
||||
@@ -627,7 +631,7 @@ async def _mint_or_reuse_object(
|
||||
# the batch's latest state (e.g. output_file_id / error_file_id that
|
||||
# were null at creation but populated once the batch completed).
|
||||
try:
|
||||
await prisma_client.db.litellm_managedobjecttable.update(
|
||||
await ManagedObjectRepository(prisma_client).table.update(
|
||||
where={"unified_object_id": existing.unified_object_id},
|
||||
data={
|
||||
"file_object": json.dumps(body_snapshot),
|
||||
@@ -647,7 +651,7 @@ async def _mint_or_reuse_object(
|
||||
|
||||
# Dedup: look up by the namespaced key — guaranteed unique per provider.
|
||||
try:
|
||||
existing = await prisma_client.db.litellm_managedobjecttable.find_first(
|
||||
existing = await ManagedObjectRepository(prisma_client).table.find_first(
|
||||
where={"model_object_id": namespaced_model_object_id}
|
||||
)
|
||||
except Exception:
|
||||
@@ -666,7 +670,7 @@ async def _mint_or_reuse_object(
|
||||
raw_id.split("_", 1)[0],
|
||||
)
|
||||
try:
|
||||
await prisma_client.db.litellm_managedobjecttable.upsert(
|
||||
await ManagedObjectRepository(prisma_client).table.upsert(
|
||||
where={"unified_object_id": managed_id},
|
||||
data={
|
||||
"create": {
|
||||
@@ -690,7 +694,7 @@ async def _mint_or_reuse_object(
|
||||
# the winner's managed ID so both callers converge on one ID instead of
|
||||
# the loser silently keeping the raw id.
|
||||
try:
|
||||
raced = await prisma_client.db.litellm_managedobjecttable.find_first(
|
||||
raced = await ManagedObjectRepository(prisma_client).table.find_first(
|
||||
where={"model_object_id": namespaced_model_object_id}
|
||||
)
|
||||
except Exception:
|
||||
@@ -883,9 +887,9 @@ async def _build_list_where_with_cursor(
|
||||
return where, fetch_order
|
||||
|
||||
cursor_table = (
|
||||
prisma_client.db.litellm_managedfiletable
|
||||
ManagedFileRepository(prisma_client).table
|
||||
if resource_kind == "files"
|
||||
else prisma_client.db.litellm_managedobjecttable
|
||||
else ManagedObjectRepository(prisma_client).table
|
||||
)
|
||||
cursor_field = (
|
||||
"unified_file_id" if resource_kind == "files" else "unified_object_id"
|
||||
@@ -932,12 +936,12 @@ async def _fetch_list_rows(
|
||||
# across rows that share a created_at timestamp.
|
||||
try:
|
||||
if resource_kind == "files":
|
||||
return await prisma_client.db.litellm_managedfiletable.find_many(
|
||||
return await ManagedFileRepository(prisma_client).table.find_many(
|
||||
where=where,
|
||||
order=[{"created_at": fetch_order}, {"unified_file_id": fetch_order}],
|
||||
take=fetch_limit,
|
||||
)
|
||||
return await prisma_client.db.litellm_managedobjecttable.find_many(
|
||||
return await ManagedObjectRepository(prisma_client).table.find_many(
|
||||
where={**where, "file_purpose": "batch"},
|
||||
order=[{"created_at": fetch_order}, {"unified_object_id": fetch_order}],
|
||||
take=fetch_limit,
|
||||
|
||||
@@ -60,12 +60,13 @@ from litellm.proxy.common_utils.http_parsing_utils import (
|
||||
)
|
||||
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
|
||||
from litellm.proxy.utils import normalize_route_for_root_path
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
|
||||
EndpointType,
|
||||
LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY,
|
||||
LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY,
|
||||
EndpointType,
|
||||
PassthroughStandardLoggingPayload,
|
||||
)
|
||||
|
||||
@@ -1002,9 +1003,7 @@ async def pass_through_request( # noqa: PLR0915
|
||||
is_passthrough_list_route,
|
||||
list_passthrough_ids_from_db,
|
||||
)
|
||||
from litellm.proxy.proxy_server import (
|
||||
prisma_client as _list_prisma,
|
||||
)
|
||||
from litellm.proxy.proxy_server import prisma_client as _list_prisma
|
||||
|
||||
if (
|
||||
is_passthrough_list_route(
|
||||
@@ -2875,7 +2874,7 @@ async def _filter_endpoints_by_team_allowed_routes(
|
||||
HTTPException: If team is not found
|
||||
"""
|
||||
# retrieve team from db
|
||||
team = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
team = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": team_id},
|
||||
)
|
||||
if team is None:
|
||||
|
||||
@@ -9,6 +9,7 @@ from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.repositories.table_repositories import PolicyAttachmentRepository
|
||||
from litellm.types.proxy.policy_engine import (
|
||||
PolicyAttachment,
|
||||
PolicyAttachmentCreateRequest,
|
||||
@@ -278,21 +279,21 @@ class AttachmentRegistry:
|
||||
PolicyAttachmentDBResponse with the created attachment
|
||||
"""
|
||||
try:
|
||||
created_attachment = (
|
||||
await prisma_client.db.litellm_policyattachmenttable.create(
|
||||
data={
|
||||
"policy_name": attachment_request.policy_name,
|
||||
"scope": attachment_request.scope,
|
||||
"teams": attachment_request.teams or [],
|
||||
"keys": attachment_request.keys or [],
|
||||
"models": attachment_request.models or [],
|
||||
"tags": attachment_request.tags or [],
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
"created_by": created_by,
|
||||
"updated_by": created_by,
|
||||
}
|
||||
)
|
||||
created_attachment = await PolicyAttachmentRepository(
|
||||
prisma_client
|
||||
).table.create(
|
||||
data={
|
||||
"policy_name": attachment_request.policy_name,
|
||||
"scope": attachment_request.scope,
|
||||
"teams": attachment_request.teams or [],
|
||||
"keys": attachment_request.keys or [],
|
||||
"models": attachment_request.models or [],
|
||||
"tags": attachment_request.tags or [],
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
"created_by": created_by,
|
||||
"updated_by": created_by,
|
||||
}
|
||||
)
|
||||
|
||||
# Also add to in-memory registry
|
||||
@@ -340,17 +341,15 @@ class AttachmentRegistry:
|
||||
"""
|
||||
try:
|
||||
# Get attachment before deleting
|
||||
attachment = (
|
||||
await prisma_client.db.litellm_policyattachmenttable.find_unique(
|
||||
where={"attachment_id": attachment_id}
|
||||
)
|
||||
)
|
||||
attachment = await PolicyAttachmentRepository(
|
||||
prisma_client
|
||||
).table.find_unique(where={"attachment_id": attachment_id})
|
||||
|
||||
if attachment is None:
|
||||
raise Exception(f"Attachment with ID {attachment_id} not found")
|
||||
|
||||
# Delete from DB
|
||||
await prisma_client.db.litellm_policyattachmenttable.delete(
|
||||
await PolicyAttachmentRepository(prisma_client).table.delete(
|
||||
where={"attachment_id": attachment_id}
|
||||
)
|
||||
|
||||
@@ -379,11 +378,9 @@ class AttachmentRegistry:
|
||||
PolicyAttachmentDBResponse if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
attachment = (
|
||||
await prisma_client.db.litellm_policyattachmenttable.find_unique(
|
||||
where={"attachment_id": attachment_id}
|
||||
)
|
||||
)
|
||||
attachment = await PolicyAttachmentRepository(
|
||||
prisma_client
|
||||
).table.find_unique(where={"attachment_id": attachment_id})
|
||||
|
||||
if attachment is None:
|
||||
return None
|
||||
@@ -419,10 +416,10 @@ class AttachmentRegistry:
|
||||
List of PolicyAttachmentDBResponse objects
|
||||
"""
|
||||
try:
|
||||
attachments = (
|
||||
await prisma_client.db.litellm_policyattachmenttable.find_many(
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
attachments = await PolicyAttachmentRepository(
|
||||
prisma_client
|
||||
).table.find_many(
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
|
||||
return [
|
||||
|
||||
@@ -12,6 +12,7 @@ from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.repositories.table_repositories import PolicyRepository
|
||||
from litellm.types.proxy.policy_engine import (
|
||||
GuardrailPipeline,
|
||||
PipelineStep,
|
||||
@@ -295,7 +296,7 @@ class PolicyRegistry:
|
||||
validated_pipeline = GuardrailPipeline(**policy_request.pipeline)
|
||||
data["pipeline"] = json.dumps(validated_pipeline.model_dump())
|
||||
|
||||
created_policy = await prisma_client.db.litellm_policytable.create(
|
||||
created_policy = await PolicyRepository(prisma_client).table.create(
|
||||
data=data
|
||||
)
|
||||
|
||||
@@ -347,7 +348,7 @@ class PolicyRegistry:
|
||||
Exception: If policy is not in draft status (only drafts are editable).
|
||||
"""
|
||||
try:
|
||||
existing = await prisma_client.db.litellm_policytable.find_unique(
|
||||
existing = await PolicyRepository(prisma_client).table.find_unique(
|
||||
where={"policy_id": policy_id}
|
||||
)
|
||||
if existing is None:
|
||||
@@ -382,7 +383,7 @@ class PolicyRegistry:
|
||||
validated_pipeline = GuardrailPipeline(**policy_request.pipeline)
|
||||
update_data["pipeline"] = json.dumps(validated_pipeline.model_dump())
|
||||
|
||||
updated_policy = await prisma_client.db.litellm_policytable.update(
|
||||
updated_policy = await PolicyRepository(prisma_client).table.update(
|
||||
where={"policy_id": policy_id},
|
||||
data=update_data,
|
||||
)
|
||||
@@ -413,7 +414,7 @@ class PolicyRegistry:
|
||||
Dict with "message" and optional "warning" if production was deleted.
|
||||
"""
|
||||
try:
|
||||
policy = await prisma_client.db.litellm_policytable.find_unique(
|
||||
policy = await PolicyRepository(prisma_client).table.find_unique(
|
||||
where={"policy_id": policy_id}
|
||||
)
|
||||
|
||||
@@ -424,7 +425,7 @@ class PolicyRegistry:
|
||||
policy_name = policy.policy_name
|
||||
|
||||
# Delete from DB
|
||||
await prisma_client.db.litellm_policytable.delete(
|
||||
await PolicyRepository(prisma_client).table.delete(
|
||||
where={"policy_id": policy_id}
|
||||
)
|
||||
|
||||
@@ -461,7 +462,7 @@ class PolicyRegistry:
|
||||
PolicyDBResponse if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
policy = await prisma_client.db.litellm_policytable.find_unique(
|
||||
policy = await PolicyRepository(prisma_client).table.find_unique(
|
||||
where={"policy_id": policy_id}
|
||||
)
|
||||
|
||||
@@ -512,7 +513,7 @@ class PolicyRegistry:
|
||||
if version_status is not None:
|
||||
where["version_status"] = version_status
|
||||
|
||||
policies = await prisma_client.db.litellm_policytable.find_many(
|
||||
policies = await PolicyRepository(prisma_client).table.find_many(
|
||||
where=where if where else None,
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
@@ -554,7 +555,7 @@ class PolicyRegistry:
|
||||
self.add_policy(policy_response.policy_name, policy)
|
||||
|
||||
self._policies_by_id = {}
|
||||
non_production = await prisma_client.db.litellm_policytable.find_many(
|
||||
non_production = await PolicyRepository(prisma_client).table.find_many(
|
||||
where={"version_status": {"in": ["draft", "published"]}},
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
@@ -654,7 +655,7 @@ class PolicyRegistry:
|
||||
PolicyVersionListResponse with policy_name and list of versions
|
||||
"""
|
||||
try:
|
||||
rows = await prisma_client.db.litellm_policytable.find_many(
|
||||
rows = await PolicyRepository(prisma_client).table.find_many(
|
||||
where={"policy_name": policy_name},
|
||||
order={"version_number": "desc"},
|
||||
)
|
||||
@@ -690,7 +691,7 @@ class PolicyRegistry:
|
||||
"""
|
||||
try:
|
||||
if source_policy_id is not None:
|
||||
source = await prisma_client.db.litellm_policytable.find_unique(
|
||||
source = await PolicyRepository(prisma_client).table.find_unique(
|
||||
where={"policy_id": source_policy_id}
|
||||
)
|
||||
if source is None:
|
||||
@@ -701,7 +702,7 @@ class PolicyRegistry:
|
||||
)
|
||||
else:
|
||||
# Find current production version for this policy_name
|
||||
prod = await prisma_client.db.litellm_policytable.find_first(
|
||||
prod = await PolicyRepository(prisma_client).table.find_first(
|
||||
where={
|
||||
"policy_name": policy_name,
|
||||
"version_status": "production",
|
||||
@@ -714,7 +715,7 @@ class PolicyRegistry:
|
||||
source = prod
|
||||
|
||||
# Next version number
|
||||
latest = await prisma_client.db.litellm_policytable.find_first(
|
||||
latest = await PolicyRepository(prisma_client).table.find_first(
|
||||
where={"policy_name": policy_name},
|
||||
order={"version_number": "desc"},
|
||||
)
|
||||
@@ -722,7 +723,7 @@ class PolicyRegistry:
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
# Set is_latest=False on all existing versions for this policy_name
|
||||
await prisma_client.db.litellm_policytable.update_many(
|
||||
await PolicyRepository(prisma_client).table.update_many(
|
||||
where={"policy_name": policy_name},
|
||||
data={"is_latest": False},
|
||||
)
|
||||
@@ -758,7 +759,7 @@ class PolicyRegistry:
|
||||
else source.pipeline
|
||||
)
|
||||
|
||||
created = await prisma_client.db.litellm_policytable.create(data=data)
|
||||
created = await PolicyRepository(prisma_client).table.create(data=data)
|
||||
return _row_to_policy_db_response(created)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error creating new version: {e}")
|
||||
@@ -794,7 +795,7 @@ class PolicyRegistry:
|
||||
f"Invalid status '{new_status}'. Use 'published' or 'production'."
|
||||
)
|
||||
|
||||
row = await prisma_client.db.litellm_policytable.find_unique(
|
||||
row = await PolicyRepository(prisma_client).table.find_unique(
|
||||
where={"policy_id": policy_id}
|
||||
)
|
||||
if row is None:
|
||||
@@ -809,7 +810,7 @@ class PolicyRegistry:
|
||||
raise Exception(
|
||||
f"Only draft versions can be published. Current status: '{current}'."
|
||||
)
|
||||
updated = await prisma_client.db.litellm_policytable.update(
|
||||
updated = await PolicyRepository(prisma_client).table.update(
|
||||
where={"policy_id": policy_id},
|
||||
data={
|
||||
"version_status": "published",
|
||||
@@ -832,7 +833,7 @@ class PolicyRegistry:
|
||||
)
|
||||
|
||||
# Demote current production to published
|
||||
await prisma_client.db.litellm_policytable.update_many(
|
||||
await PolicyRepository(prisma_client).table.update_many(
|
||||
where={
|
||||
"policy_name": policy_name,
|
||||
"version_status": "production",
|
||||
@@ -845,7 +846,7 @@ class PolicyRegistry:
|
||||
)
|
||||
|
||||
# Promote this version to production
|
||||
updated = await prisma_client.db.litellm_policytable.update(
|
||||
updated = await PolicyRepository(prisma_client).table.update(
|
||||
where={"policy_id": policy_id},
|
||||
data={
|
||||
"version_status": "production",
|
||||
@@ -895,10 +896,10 @@ class PolicyRegistry:
|
||||
PolicyVersionCompareResponse with both versions and field_diffs
|
||||
"""
|
||||
try:
|
||||
a = await prisma_client.db.litellm_policytable.find_unique(
|
||||
a = await PolicyRepository(prisma_client).table.find_unique(
|
||||
where={"policy_id": policy_id_a}
|
||||
)
|
||||
b = await prisma_client.db.litellm_policytable.find_unique(
|
||||
b = await PolicyRepository(prisma_client).table.find_unique(
|
||||
where={"policy_id": policy_id_b}
|
||||
)
|
||||
if a is None:
|
||||
@@ -950,7 +951,7 @@ class PolicyRegistry:
|
||||
Dict with success message
|
||||
"""
|
||||
try:
|
||||
await prisma_client.db.litellm_policytable.delete_many(
|
||||
await PolicyRepository(prisma_client).table.delete_many(
|
||||
where={"policy_name": policy_name}
|
||||
)
|
||||
self.remove_policy(policy_name)
|
||||
|
||||
@@ -16,6 +16,10 @@ from litellm.proxy.auth.route_checks import RouteChecks
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry
|
||||
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.repositories.verification_token_repository import (
|
||||
VerificationTokenRepository,
|
||||
)
|
||||
from litellm.types.proxy.policy_engine import (
|
||||
AttachmentImpactResponse,
|
||||
PolicyAttachmentCreateRequest,
|
||||
@@ -76,7 +80,7 @@ def _get_tags_from_metadata(metadata: object, json_metadata: object = None) -> l
|
||||
|
||||
async def _fetch_all_teams(prisma_client: object) -> list:
|
||||
"""Fetch teams from DB once. Reuse the result across tag and alias lookups."""
|
||||
return await prisma_client.db.litellm_teamtable.find_many( # type: ignore
|
||||
return await TeamRepository(prisma_client).table.find_many( # type: ignore
|
||||
where={},
|
||||
order={"created_at": "desc"},
|
||||
take=MAX_POLICY_ESTIMATE_IMPACT_ROWS,
|
||||
@@ -159,7 +163,7 @@ async def _find_affected_by_team_patterns(
|
||||
new_keys: list = []
|
||||
unnamed_keys_count = 0
|
||||
if matched_team_ids:
|
||||
keys = await prisma_client.db.litellm_verificationtoken.find_many( # type: ignore
|
||||
keys = await VerificationTokenRepository(prisma_client).table.find_many( # type: ignore
|
||||
where={"team_id": {"in": matched_team_ids}},
|
||||
order={"created_at": "desc"},
|
||||
take=MAX_POLICY_ESTIMATE_IMPACT_ROWS,
|
||||
@@ -182,7 +186,7 @@ async def _find_affected_keys_by_alias(
|
||||
|
||||
affected: list = []
|
||||
|
||||
keys = await prisma_client.db.litellm_verificationtoken.find_many( # type: ignore
|
||||
keys = await VerificationTokenRepository(prisma_client).table.find_many( # type: ignore
|
||||
where=_build_alias_where("key_alias", key_patterns),
|
||||
order={"created_at": "desc"},
|
||||
take=MAX_POLICY_ESTIMATE_IMPACT_ROWS,
|
||||
@@ -367,7 +371,7 @@ async def estimate_attachment_impact(
|
||||
|
||||
# Tag-based impact
|
||||
if tag_patterns:
|
||||
keys = await prisma_client.db.litellm_verificationtoken.find_many( # type: ignore
|
||||
keys = await VerificationTokenRepository(prisma_client).table.find_many( # type: ignore
|
||||
where={},
|
||||
order={"created_at": "desc"},
|
||||
take=MAX_POLICY_ESTIMATE_IMPACT_ROWS,
|
||||
|
||||
@@ -12,6 +12,10 @@ Validates:
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.repositories.verification_token_repository import (
|
||||
VerificationTokenRepository,
|
||||
)
|
||||
from litellm.types.proxy.policy_engine import (
|
||||
Policy,
|
||||
PolicyValidationError,
|
||||
@@ -95,7 +99,7 @@ class PolicyValidator:
|
||||
return True # Can't validate without DB, assume valid
|
||||
|
||||
try:
|
||||
team = await self.prisma_client.db.litellm_teamtable.find_first(
|
||||
team = await TeamRepository(self.prisma_client).table.find_first(
|
||||
where={"team_alias": team_alias},
|
||||
)
|
||||
return team is not None
|
||||
@@ -119,7 +123,9 @@ class PolicyValidator:
|
||||
return True # Can't validate without DB, assume valid
|
||||
|
||||
try:
|
||||
key = await self.prisma_client.db.litellm_verificationtoken.find_first(
|
||||
key = await VerificationTokenRepository(
|
||||
self.prisma_client
|
||||
).table.find_first(
|
||||
where={"key_alias": key_alias},
|
||||
)
|
||||
return key is not None
|
||||
|
||||
@@ -22,6 +22,7 @@ from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKey
|
||||
from litellm.proxy.auth.auth_utils import is_request_body_safe
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.path_utils import safe_filename
|
||||
from litellm.repositories.table_repositories import PromptRepository
|
||||
from litellm.types.prompts.init_prompts import (
|
||||
ListPromptsResponse,
|
||||
PromptInfo,
|
||||
@@ -208,7 +209,7 @@ async def get_next_version_for_prompt(
|
||||
Returns:
|
||||
Next version number (1 if no versions exist, max_version + 1 otherwise)
|
||||
"""
|
||||
existing_prompts = await prisma_client.db.litellm_prompttable.find_many(
|
||||
existing_prompts = await PromptRepository(prisma_client).table.find_many(
|
||||
where={"prompt_id": prompt_id, "environment": environment}
|
||||
)
|
||||
|
||||
@@ -441,7 +442,7 @@ async def get_prompt_versions(
|
||||
where_clause: Dict[str, Any] = {"prompt_id": base_prompt_id}
|
||||
if environment:
|
||||
where_clause["environment"] = environment
|
||||
db_prompts = await prisma_client.db.litellm_prompttable.find_many(
|
||||
db_prompts = await PromptRepository(prisma_client).table.find_many(
|
||||
where=where_clause,
|
||||
order={"version": "desc"},
|
||||
)
|
||||
@@ -612,7 +613,7 @@ async def get_prompt_info(
|
||||
# Query all environments this prompt exists in (lightweight: distinct on environment)
|
||||
all_environments: List[str] = []
|
||||
if prisma_client is not None:
|
||||
all_prompt_rows = await prisma_client.db.litellm_prompttable.find_many(
|
||||
all_prompt_rows = await PromptRepository(prisma_client).table.find_many(
|
||||
where={"prompt_id": base_prompt_id},
|
||||
distinct=["environment"],
|
||||
)
|
||||
@@ -634,7 +635,7 @@ async def get_prompt_info(
|
||||
}
|
||||
if requested_version is not None:
|
||||
where_clause["version"] = requested_version
|
||||
env_prompts = await prisma_client.db.litellm_prompttable.find_many(
|
||||
env_prompts = await PromptRepository(prisma_client).table.find_many(
|
||||
where=where_clause,
|
||||
order={"version": "desc"},
|
||||
take=1,
|
||||
@@ -752,7 +753,7 @@ async def create_prompt(
|
||||
)
|
||||
|
||||
# Store prompt in db with version
|
||||
prompt_db_entry = await prisma_client.db.litellm_prompttable.create(
|
||||
prompt_db_entry = await PromptRepository(prisma_client).table.create(
|
||||
data={
|
||||
"prompt_id": request.prompt_id,
|
||||
"version": new_version,
|
||||
@@ -848,7 +849,7 @@ async def update_prompt(
|
||||
)
|
||||
|
||||
# Check if any version of this prompt exists (in any environment)
|
||||
existing_prompts = await prisma_client.db.litellm_prompttable.find_many(
|
||||
existing_prompts = await PromptRepository(prisma_client).table.find_many(
|
||||
where={"prompt_id": base_prompt_id}
|
||||
)
|
||||
|
||||
@@ -877,7 +878,7 @@ async def update_prompt(
|
||||
)
|
||||
|
||||
# Store new version in db
|
||||
prompt_db_entry = await prisma_client.db.litellm_prompttable.create(
|
||||
prompt_db_entry = await PromptRepository(prisma_client).table.create(
|
||||
data={
|
||||
"prompt_id": base_prompt_id,
|
||||
"version": new_version,
|
||||
@@ -993,7 +994,7 @@ async def delete_prompt(
|
||||
delete_where["environment"] = environment
|
||||
|
||||
# Delete versions from the database (scoped to environment if provided)
|
||||
await prisma_client.db.litellm_prompttable.delete_many(where=delete_where)
|
||||
await PromptRepository(prisma_client).table.delete_many(where=delete_where)
|
||||
|
||||
# Remove matching prompts from memory — scope to environment if provided
|
||||
if environment:
|
||||
@@ -1105,7 +1106,7 @@ async def patch_prompt(
|
||||
if requested_version is not None:
|
||||
find_where["version"] = requested_version
|
||||
|
||||
db_rows = await prisma_client.db.litellm_prompttable.find_many(
|
||||
db_rows = await PromptRepository(prisma_client).table.find_many(
|
||||
where=find_where,
|
||||
order={"version": "desc"},
|
||||
take=1,
|
||||
@@ -1163,7 +1164,7 @@ async def patch_prompt(
|
||||
update_data["created_by"] = user_api_key_dict.user_id
|
||||
|
||||
# Update by primary key (id) to target exactly one row
|
||||
updated_prompt_db_entry = await prisma_client.db.litellm_prompttable.update(
|
||||
updated_prompt_db_entry = await PromptRepository(prisma_client).table.update(
|
||||
where={"id": target_row.id},
|
||||
data=update_data,
|
||||
)
|
||||
|
||||
@@ -48,6 +48,7 @@ from litellm.constants import (
|
||||
AIOHTTP_TTL_DNS_CACHE,
|
||||
AUDIO_SPEECH_CHUNK_SIZE,
|
||||
BASE_MCP_ROUTE,
|
||||
DAILY_TAG_SPEND_BATCH_MULTIPLIER,
|
||||
DEFAULT_MAX_RECURSE_DEPTH,
|
||||
DEFAULT_SHARED_HEALTH_CHECK_LOCK_TTL,
|
||||
DEFAULT_SHARED_HEALTH_CHECK_TTL,
|
||||
@@ -56,13 +57,13 @@ from litellm.constants import (
|
||||
LITELLM_SETTINGS_SAFE_DB_OVERRIDES,
|
||||
LITELLM_UI_ALLOW_HEADERS,
|
||||
LITELLM_UI_SESSION_DURATION,
|
||||
DAILY_TAG_SPEND_BATCH_MULTIPLIER,
|
||||
)
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
_init_custom_logger_compatible_class,
|
||||
)
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy._types import (
|
||||
UI_TEAM_ID,
|
||||
CallbackDelete,
|
||||
CallInfo,
|
||||
CommonProxyErrors,
|
||||
@@ -79,8 +80,8 @@ from litellm.proxy._types import (
|
||||
InvitationModel,
|
||||
InvitationNew,
|
||||
InvitationUpdate,
|
||||
Litellm_EntityType,
|
||||
LiteLLM_EndUserTable,
|
||||
Litellm_EntityType,
|
||||
LiteLLM_JWTAuth,
|
||||
LiteLLM_TagTable,
|
||||
LiteLLM_TeamTable,
|
||||
@@ -96,7 +97,6 @@ from litellm.proxy._types import (
|
||||
TeamDefaultSettings,
|
||||
TokenCountRequest,
|
||||
TransformRequestBody,
|
||||
UI_TEAM_ID,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.common_utils.cache_pydantic_utils import CacheCodec
|
||||
@@ -212,8 +212,6 @@ from litellm import Router
|
||||
from litellm._logging import verbose_proxy_logger, verbose_router_logger
|
||||
from litellm.caching.caching import DualCache, RedisCache
|
||||
from litellm.caching.redis_cluster_cache import RedisClusterCache
|
||||
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
|
||||
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
|
||||
from litellm.constants import (
|
||||
_REALTIME_BODY_CACHE_SIZE,
|
||||
APSCHEDULER_COALESCE,
|
||||
@@ -247,8 +245,8 @@ from litellm.litellm_core_utils.sensitive_data_masker import (
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy._lazy_features import attach_lazy_features
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.analytics_endpoints.analytics_endpoints import (
|
||||
router as analytics_router,
|
||||
)
|
||||
@@ -308,6 +306,8 @@ from litellm.proxy.common_utils.openai_endpoint_utils import (
|
||||
from litellm.proxy.common_utils.proxy_state import ProxyState
|
||||
from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
|
||||
from litellm.proxy.common_utils.swagger_utils import ERROR_RESPONSES
|
||||
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
|
||||
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
|
||||
from litellm.proxy.container_endpoints.endpoints import router as container_router
|
||||
from litellm.proxy.credential_endpoints.endpoints import router as credential_router
|
||||
from litellm.proxy.db.db_transaction_queue.spend_log_cleanup import SpendLogCleanup
|
||||
@@ -361,7 +361,9 @@ from litellm.proxy.management_endpoints.fallback_management_endpoints import (
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
||||
router as internal_user_router,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import user_update
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
||||
user_update,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
delete_verification_tokens,
|
||||
duration_in_seconds,
|
||||
@@ -398,10 +400,6 @@ from litellm.proxy.management_endpoints.team_endpoints import (
|
||||
update_team,
|
||||
validate_membership,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.workflow_management_endpoints import (
|
||||
router as workflow_management_router,
|
||||
)
|
||||
from litellm.proxy.memory.memory_endpoints import router as memory_router
|
||||
from litellm.proxy.management_endpoints.ui_sso import (
|
||||
get_disabled_non_admin_personal_key_creation,
|
||||
)
|
||||
@@ -409,7 +407,11 @@ from litellm.proxy.management_endpoints.ui_sso import router as ui_sso_router
|
||||
from litellm.proxy.management_endpoints.user_agent_analytics_endpoints import (
|
||||
router as user_agent_analytics_router,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.workflow_management_endpoints import (
|
||||
router as workflow_management_router,
|
||||
)
|
||||
from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update
|
||||
from litellm.proxy.memory.memory_endpoints import router as memory_router
|
||||
from litellm.proxy.middleware.in_flight_requests_middleware import (
|
||||
InFlightRequestsMiddleware,
|
||||
)
|
||||
@@ -417,12 +419,13 @@ from litellm.proxy.middleware.prometheus_auth_middleware import PrometheusAuthMi
|
||||
from litellm.proxy.middleware.request_size_limit_middleware import (
|
||||
RequestSizeLimitMiddleware,
|
||||
)
|
||||
from litellm.proxy.shutdown.graceful_shutdown_manager import GracefulShutdownManager
|
||||
from litellm.proxy.ocr_endpoints.endpoints import router as ocr_router
|
||||
from litellm.proxy.openai_files_endpoints.files_endpoints import (
|
||||
router as openai_files_router,
|
||||
)
|
||||
from litellm.proxy.openai_files_endpoints.files_endpoints import set_files_config
|
||||
from litellm.proxy.openai_files_endpoints.files_endpoints import (
|
||||
set_files_config,
|
||||
)
|
||||
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
|
||||
passthrough_endpoint_router,
|
||||
)
|
||||
@@ -444,6 +447,7 @@ from litellm.proxy.rerank_endpoints.endpoints import router as rerank_router
|
||||
from litellm.proxy.response_api_endpoints.endpoints import router as response_router
|
||||
from litellm.proxy.route_llm_request import route_request
|
||||
from litellm.proxy.search_endpoints.endpoints import router as search_router
|
||||
from litellm.proxy.shutdown.graceful_shutdown_manager import GracefulShutdownManager
|
||||
from litellm.proxy.spend_tracking.spend_management_endpoints import (
|
||||
router as spend_management_router,
|
||||
)
|
||||
@@ -478,6 +482,7 @@ from litellm.proxy.utils import (
|
||||
update_spend,
|
||||
)
|
||||
from litellm.proxy.video_endpoints.endpoints import router as video_router
|
||||
from litellm.repositories.credentials_repository import CredentialsRepository
|
||||
from litellm.router import (
|
||||
AssistantsTypedDict,
|
||||
Deployment,
|
||||
@@ -511,7 +516,9 @@ from litellm.types.proxy.management_endpoints.ui_sso import (
|
||||
LiteLLM_UpperboundKeyGenerateParams,
|
||||
)
|
||||
from litellm.types.realtime import RealtimeQueryParams
|
||||
from litellm.types.router import DeploymentTypedDict
|
||||
from litellm.types.router import (
|
||||
DeploymentTypedDict,
|
||||
)
|
||||
from litellm.types.router import ModelInfo as RouterModelInfo
|
||||
from litellm.types.router import (
|
||||
RouterGeneralSettings,
|
||||
@@ -4248,13 +4255,13 @@ class ProxyConfig:
|
||||
)
|
||||
setattr(litellm, key, value)
|
||||
if key in {"s3_audit_callback_params", "s3_callback_params"}:
|
||||
from litellm.proxy.management_helpers.audit_logs import (
|
||||
reset_audit_log_callback_cache,
|
||||
)
|
||||
from litellm.integrations.s3_v2 import S3Logger as S3V2Logger
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
_in_memory_loggers,
|
||||
)
|
||||
from litellm.integrations.s3_v2 import S3Logger as S3V2Logger
|
||||
from litellm.proxy.management_helpers.audit_logs import (
|
||||
reset_audit_log_callback_cache,
|
||||
)
|
||||
|
||||
reset_audit_log_callback_cache()
|
||||
_in_memory_loggers[:] = [
|
||||
@@ -5335,7 +5342,7 @@ class ProxyConfig:
|
||||
4. Update router settings
|
||||
"""
|
||||
if llm_router is not None and prisma_client is not None:
|
||||
db_router_settings = await prisma_client.db.litellm_config.find_first(
|
||||
db_router_settings = await ConfigRepository(prisma_client).table.find_first(
|
||||
where={"param_name": "router_settings"}
|
||||
)
|
||||
|
||||
@@ -5761,7 +5768,7 @@ class ProxyConfig:
|
||||
|
||||
async def _get_models_from_db(self, prisma_client: PrismaClient) -> list:
|
||||
try:
|
||||
new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
|
||||
new_models = await ModelRepository(prisma_client).table.find_many()
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy_server.py::add_deployment() - Error getting new models from DB - {}".format(
|
||||
@@ -5975,7 +5982,7 @@ class ProxyConfig:
|
||||
"""
|
||||
|
||||
try:
|
||||
sso_settings = await prisma_client.db.litellm_ssoconfig.find_unique(
|
||||
sso_settings = await SSOConfigRepository(prisma_client).table.find_unique(
|
||||
where={"id": "sso_config"}
|
||||
)
|
||||
if sso_settings is not None:
|
||||
@@ -6011,9 +6018,9 @@ class ProxyConfig:
|
||||
)
|
||||
|
||||
try:
|
||||
db_record = await prisma_client.db.litellm_configoverrides.find_unique(
|
||||
where={"config_type": "hashicorp_vault"}
|
||||
)
|
||||
db_record = await ConfigOverridesRepository(
|
||||
prisma_client
|
||||
).table.find_unique(where={"config_type": "hashicorp_vault"})
|
||||
|
||||
if db_record is None or db_record.config_value is None:
|
||||
if self._last_hashicorp_vault_config is not None:
|
||||
@@ -6130,7 +6137,7 @@ class ProxyConfig:
|
||||
last_model_cost_map_reload = current_time.isoformat()
|
||||
|
||||
# Clear force reload flag in database
|
||||
await prisma_client.db.litellm_config.upsert(
|
||||
await ConfigRepository(prisma_client).table.upsert(
|
||||
where={"param_name": "model_cost_map_reload_config"},
|
||||
data={
|
||||
"create": {
|
||||
@@ -6239,7 +6246,7 @@ class ProxyConfig:
|
||||
last_anthropic_beta_headers_reload = current_time.isoformat()
|
||||
|
||||
# Clear force reload flag in database
|
||||
await prisma_client.db.litellm_config.upsert(
|
||||
await ConfigRepository(prisma_client).table.upsert(
|
||||
where={"param_name": "anthropic_beta_headers_reload_config"},
|
||||
data={
|
||||
"create": {
|
||||
@@ -6299,7 +6306,7 @@ class ProxyConfig:
|
||||
from litellm.types.prompts.init_prompts import PromptSpec
|
||||
|
||||
try:
|
||||
prompts_in_db = await prisma_client.db.litellm_prompttable.find_many()
|
||||
prompts_in_db = await PromptRepository(prisma_client).table.find_many()
|
||||
for prompt in prompts_in_db:
|
||||
# Convert DB object to dict and create versioned prompt_id
|
||||
prompt_spec = self._get_prompt_spec_for_db_prompt(db_prompt=prompt)
|
||||
@@ -6589,7 +6596,7 @@ class ProxyConfig:
|
||||
|
||||
async def get_credentials(self, prisma_client: PrismaClient):
|
||||
try:
|
||||
credentials = await prisma_client.db.litellm_credentialstable.find_many()
|
||||
credentials = await CredentialsRepository(prisma_client).find_all()
|
||||
credentials = [self.decrypt_credentials(cred) for cred in credentials]
|
||||
await self.delete_credentials(
|
||||
credentials
|
||||
@@ -7352,7 +7359,7 @@ class ProxyStartupEvent:
|
||||
# spend cap blocks forever once it's hit.
|
||||
if prisma_client is not None and litellm.budget_duration is not None:
|
||||
try:
|
||||
await prisma_client.db.litellm_usertable.update_many(
|
||||
await UserRepository(prisma_client).table.update_many(
|
||||
where={
|
||||
"user_id": litellm_proxy_budget_name,
|
||||
"budget_reset_at": None,
|
||||
@@ -7420,7 +7427,7 @@ class ProxyStartupEvent:
|
||||
|
||||
if prisma_client is None:
|
||||
return
|
||||
db_record = await prisma_client.db.litellm_uisettings.find_unique(
|
||||
db_record = await UISettingsRepository(prisma_client).table.find_unique(
|
||||
where={"id": "ui_settings"}
|
||||
)
|
||||
if db_record and db_record.ui_settings:
|
||||
@@ -7569,7 +7576,7 @@ class ProxyStartupEvent:
|
||||
# but YAML config has False.
|
||||
if store_model_in_db is not True and prisma_client is not None:
|
||||
try:
|
||||
_db_gs_record = await prisma_client.db.litellm_config.find_first(
|
||||
_db_gs_record = await ConfigRepository(prisma_client).table.find_first(
|
||||
where={"param_name": "general_settings"}
|
||||
)
|
||||
if _db_gs_record is not None and isinstance(
|
||||
@@ -10361,6 +10368,18 @@ async def run_thread(
|
||||
# )
|
||||
# async def get_available_routes(user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)):
|
||||
from litellm.llms.base_llm.base_utils import BaseTokenCounter
|
||||
from litellm.repositories.config_repository import ConfigRepository
|
||||
from litellm.repositories.model_repository import ModelRepository
|
||||
from litellm.repositories.table_repositories import (
|
||||
AccessGroupRepository,
|
||||
ConfigOverridesRepository,
|
||||
InvitationLinkRepository,
|
||||
PromptRepository,
|
||||
SSOConfigRepository,
|
||||
UISettingsRepository,
|
||||
)
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.repositories.user_repository import UserRepository
|
||||
|
||||
|
||||
def _get_provider_token_counter(
|
||||
@@ -10666,7 +10685,7 @@ async def _check_if_model_is_user_added(
|
||||
id = model.get("model_info", {}).get("id", None)
|
||||
if id is None:
|
||||
continue
|
||||
db_model = await prisma_client.db.litellm_proxymodeltable.find_unique(
|
||||
db_model = await ModelRepository(prisma_client).table.find_unique(
|
||||
where={"model_id": id}
|
||||
)
|
||||
if db_model is not None:
|
||||
@@ -10723,7 +10742,7 @@ async def non_admin_all_models(
|
||||
|
||||
if user_api_key_dict.user_id:
|
||||
try:
|
||||
user_row = await prisma_client.db.litellm_usertable.find_unique(
|
||||
user_row = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": user_api_key_dict.user_id}
|
||||
)
|
||||
except Exception:
|
||||
@@ -10823,7 +10842,7 @@ async def _add_access_group_models_to_team_models(
|
||||
return team_models
|
||||
|
||||
# Single batch fetch for all access groups
|
||||
access_group_rows = await prisma_client.db.litellm_accessgrouptable.find_many(
|
||||
access_group_rows = await AccessGroupRepository(prisma_client).table.find_many(
|
||||
where={"access_group_id": {"in": list(all_access_group_ids)}}
|
||||
)
|
||||
ag_model_map: Dict[str, List[str]] = {
|
||||
@@ -10865,13 +10884,13 @@ async def get_all_team_models(
|
||||
team_db_objects_typed: List[LiteLLM_TeamTable] = []
|
||||
|
||||
if user_teams == "*":
|
||||
team_db_objects = await prisma_client.db.litellm_teamtable.find_many()
|
||||
team_db_objects = await TeamRepository(prisma_client).table.find_many()
|
||||
team_db_objects_typed = [
|
||||
LiteLLM_TeamTable(**team_db_object.model_dump())
|
||||
for team_db_object in team_db_objects
|
||||
]
|
||||
else:
|
||||
team_db_objects = await prisma_client.db.litellm_teamtable.find_many(
|
||||
team_db_objects = await TeamRepository(prisma_client).table.find_many(
|
||||
where={"team_id": {"in": user_teams}}
|
||||
)
|
||||
|
||||
@@ -10938,7 +10957,7 @@ async def get_all_team_and_direct_access_models(
|
||||
exclude_team_models=True
|
||||
) # has access to all models
|
||||
elif user_api_key_dict.user_id is not None:
|
||||
user_db_object = await prisma_client.db.litellm_usertable.find_unique(
|
||||
user_db_object = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": user_api_key_dict.user_id}
|
||||
)
|
||||
if user_db_object is not None:
|
||||
@@ -11082,7 +11101,7 @@ async def _get_caller_byok_team_scope(
|
||||
if user_id is None:
|
||||
return set()
|
||||
try:
|
||||
user_row = await prisma_client.db.litellm_usertable.find_unique(
|
||||
user_row = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": user_id}
|
||||
)
|
||||
except Exception:
|
||||
@@ -11143,13 +11162,13 @@ async def _fetch_db_models_for_search(
|
||||
else:
|
||||
take_limit = max(0, page * size - router_models_count)
|
||||
|
||||
db_models_total_count = await prisma_client.db.litellm_proxymodeltable.count(
|
||||
db_models_total_count = await ModelRepository(prisma_client).table.count(
|
||||
where=db_where_condition
|
||||
)
|
||||
|
||||
db_models_raw: list = []
|
||||
if take_limit > 0:
|
||||
db_models_raw = await prisma_client.db.litellm_proxymodeltable.find_many(
|
||||
db_models_raw = await ModelRepository(prisma_client).table.find_many(
|
||||
where=db_where_condition,
|
||||
take=take_limit,
|
||||
)
|
||||
@@ -11484,7 +11503,7 @@ async def _load_team_object_for_model_filter(
|
||||
) -> Optional[LiteLLM_TeamTable]:
|
||||
"""Load team row from DB; returns None if missing or on error."""
|
||||
try:
|
||||
team_db_object = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
team_db_object = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
if team_db_object is None:
|
||||
@@ -11547,7 +11566,7 @@ async def _gather_team_accessible_model_ids(
|
||||
_resolved_names = _team_models_resolve_to_names(
|
||||
team_object.models, access_groups
|
||||
)
|
||||
db_models = await prisma_client.db.litellm_proxymodeltable.find_many(
|
||||
db_models = await ModelRepository(prisma_client).table.find_many(
|
||||
where={"model_name": {"in": _resolved_names}}
|
||||
)
|
||||
for db_model in db_models:
|
||||
@@ -11586,7 +11605,7 @@ async def _authorize_team_id_query(
|
||||
detail={"error": "Not authorized to view this team's models"},
|
||||
)
|
||||
try:
|
||||
user_row = await prisma_client.db.litellm_usertable.find_unique(
|
||||
user_row = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": user_id}
|
||||
)
|
||||
except Exception:
|
||||
@@ -11693,7 +11712,7 @@ async def _find_model_by_id(
|
||||
# If not found in config, search in database
|
||||
if found_model is None:
|
||||
try:
|
||||
db_model = await prisma_client.db.litellm_proxymodeltable.find_unique(
|
||||
db_model = await ModelRepository(prisma_client).table.find_unique(
|
||||
where={"model_id": model_id}
|
||||
)
|
||||
if db_model:
|
||||
@@ -12845,7 +12864,7 @@ async def alerting_settings(
|
||||
)
|
||||
|
||||
## get general settings from db
|
||||
db_general_settings = await prisma_client.db.litellm_config.find_first(
|
||||
db_general_settings = await ConfigRepository(prisma_client).table.find_first(
|
||||
where={"param_name": "general_settings"}
|
||||
)
|
||||
|
||||
@@ -13384,7 +13403,7 @@ async def onboarding(invite_link: str, request: Request):
|
||||
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
||||
)
|
||||
|
||||
invite_obj = await prisma_client.db.litellm_invitationlink.find_unique(
|
||||
invite_obj = await InvitationLinkRepository(prisma_client).table.find_unique(
|
||||
where={"id": invite_link}
|
||||
)
|
||||
if invite_obj is None:
|
||||
@@ -13408,7 +13427,7 @@ async def onboarding(invite_link: str, request: Request):
|
||||
)
|
||||
|
||||
### GET USER OBJECT ###
|
||||
user_obj = await prisma_client.db.litellm_usertable.find_unique(
|
||||
user_obj = await UserRepository(prisma_client).table.find_unique(
|
||||
where={"user_id": invite_obj.user_id}
|
||||
)
|
||||
|
||||
@@ -13513,7 +13532,7 @@ async def _rollback_onboarding_invite_claim(
|
||||
return
|
||||
|
||||
try:
|
||||
await prisma_client.db.litellm_invitationlink.update_many(
|
||||
await InvitationLinkRepository(prisma_client).table.update_many(
|
||||
where={"id": invitation_link, "is_accepted": True},
|
||||
data={
|
||||
"accepted_at": None,
|
||||
@@ -13547,10 +13566,10 @@ async def _generate_onboarding_ui_session_token(user_obj: Any) -> str:
|
||||
)
|
||||
key = response["token"] # type: ignore
|
||||
|
||||
from litellm.types.proxy.ui_sso import ReturnedUITokenObject
|
||||
|
||||
import jwt
|
||||
|
||||
from litellm.types.proxy.ui_sso import ReturnedUITokenObject
|
||||
|
||||
disabled_non_admin_personal_key_creation = (
|
||||
get_disabled_non_admin_personal_key_creation()
|
||||
)
|
||||
@@ -13596,7 +13615,7 @@ async def claim_onboarding_link(data: InvitationClaim, request: Request):
|
||||
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
||||
)
|
||||
|
||||
invite_obj = await prisma_client.db.litellm_invitationlink.find_unique(
|
||||
invite_obj = await InvitationLinkRepository(prisma_client).table.find_unique(
|
||||
where={"id": data.invitation_link}
|
||||
)
|
||||
if invite_obj is None:
|
||||
@@ -13956,7 +13975,7 @@ async def invitation_info(
|
||||
},
|
||||
)
|
||||
|
||||
response = await prisma_client.db.litellm_invitationlink.find_unique(
|
||||
response = await InvitationLinkRepository(prisma_client).table.find_unique(
|
||||
where={"id": invitation_id}
|
||||
)
|
||||
|
||||
@@ -14010,7 +14029,7 @@ async def invitation_update(
|
||||
)
|
||||
|
||||
current_time = litellm.utils.get_utc_datetime()
|
||||
response = await prisma_client.db.litellm_invitationlink.update(
|
||||
response = await InvitationLinkRepository(prisma_client).table.update(
|
||||
where={"id": data.invitation_id},
|
||||
data={
|
||||
"id": data.invitation_id,
|
||||
@@ -14081,7 +14100,7 @@ async def invitation_delete(
|
||||
|
||||
# Org admins can only delete invitations they created
|
||||
if is_other_admin and not is_proxy_admin:
|
||||
invitation = await prisma_client.db.litellm_invitationlink.find_unique(
|
||||
invitation = await InvitationLinkRepository(prisma_client).table.find_unique(
|
||||
where={"id": data.invitation_id}
|
||||
)
|
||||
if invitation is None:
|
||||
@@ -14097,7 +14116,7 @@ async def invitation_delete(
|
||||
},
|
||||
)
|
||||
|
||||
response = await prisma_client.db.litellm_invitationlink.delete(
|
||||
response = await InvitationLinkRepository(prisma_client).table.delete(
|
||||
where={"id": data.invitation_id}
|
||||
)
|
||||
|
||||
@@ -14139,7 +14158,7 @@ async def update_config( # noqa: PLR0915
|
||||
raise Exception("No DB Connected")
|
||||
|
||||
async def _read_section(param_name: str) -> dict:
|
||||
row = await prisma_client.db.litellm_config.find_first(
|
||||
row = await ConfigRepository(prisma_client).table.find_first(
|
||||
where={"param_name": param_name}
|
||||
)
|
||||
if row is None or row.param_value is None:
|
||||
@@ -14148,7 +14167,7 @@ async def update_config( # noqa: PLR0915
|
||||
|
||||
async def _upsert_section(param_name: str, value: dict) -> None:
|
||||
serialized = json.dumps(value)
|
||||
await prisma_client.db.litellm_config.upsert(
|
||||
await ConfigRepository(prisma_client).table.upsert(
|
||||
where={"param_name": param_name},
|
||||
data={
|
||||
"create": {"param_name": param_name, "param_value": serialized},
|
||||
@@ -14323,7 +14342,7 @@ async def update_config_general_settings(
|
||||
)
|
||||
|
||||
## get general settings from db
|
||||
db_general_settings = await prisma_client.db.litellm_config.find_first(
|
||||
db_general_settings = await ConfigRepository(prisma_client).table.find_first(
|
||||
where={"param_name": "general_settings"}
|
||||
)
|
||||
### update value
|
||||
@@ -14337,7 +14356,7 @@ async def update_config_general_settings(
|
||||
|
||||
general_settings[data.field_name] = data.field_value
|
||||
|
||||
response = await prisma_client.db.litellm_config.upsert(
|
||||
response = await ConfigRepository(prisma_client).table.upsert(
|
||||
where={"param_name": "general_settings"},
|
||||
data={
|
||||
"create": {"param_name": "general_settings", "param_value": json.dumps(general_settings)}, # type: ignore
|
||||
@@ -14387,7 +14406,7 @@ async def get_config_general_settings(
|
||||
)
|
||||
|
||||
## get general settings from db
|
||||
db_general_settings = await prisma_client.db.litellm_config.find_first(
|
||||
db_general_settings = await ConfigRepository(prisma_client).table.find_first(
|
||||
where={"param_name": "general_settings"}
|
||||
)
|
||||
### pop the value
|
||||
@@ -14450,7 +14469,7 @@ async def get_config_list(
|
||||
)
|
||||
|
||||
## get general settings from db
|
||||
db_general_settings = await prisma_client.db.litellm_config.find_first(
|
||||
db_general_settings = await ConfigRepository(prisma_client).table.find_first(
|
||||
where={"param_name": "general_settings"}
|
||||
)
|
||||
|
||||
@@ -14604,7 +14623,7 @@ async def delete_config_general_settings(
|
||||
)
|
||||
|
||||
## get general settings from db
|
||||
db_general_settings = await prisma_client.db.litellm_config.find_first(
|
||||
db_general_settings = await ConfigRepository(prisma_client).table.find_first(
|
||||
where={"param_name": "general_settings"}
|
||||
)
|
||||
### pop the value
|
||||
@@ -14621,7 +14640,7 @@ async def delete_config_general_settings(
|
||||
|
||||
general_settings.pop(data.field_name, None)
|
||||
|
||||
response = await prisma_client.db.litellm_config.upsert(
|
||||
response = await ConfigRepository(prisma_client).table.upsert(
|
||||
where={"param_name": "general_settings"},
|
||||
data={
|
||||
"create": {"param_name": "general_settings", "param_value": json.dumps(general_settings)}, # type: ignore
|
||||
@@ -14975,14 +14994,14 @@ async def reload_model_cost_map(
|
||||
last_model_cost_map_reload = current_time.isoformat()
|
||||
|
||||
# Set force reload flag in database for other pods, preserving existing interval_hours
|
||||
existing_config = await prisma_client.db.litellm_config.find_unique(
|
||||
existing_config = await ConfigRepository(prisma_client).table.find_unique(
|
||||
where={"param_name": "model_cost_map_reload_config"}
|
||||
)
|
||||
existing_interval = None
|
||||
if existing_config and existing_config.param_value:
|
||||
existing_interval = existing_config.param_value.get("interval_hours")
|
||||
|
||||
await prisma_client.db.litellm_config.upsert(
|
||||
await ConfigRepository(prisma_client).table.upsert(
|
||||
where={"param_name": "model_cost_map_reload_config"},
|
||||
data={
|
||||
"create": {
|
||||
@@ -15052,7 +15071,7 @@ async def schedule_model_cost_map_reload(
|
||||
)
|
||||
|
||||
# Update database with new reload configuration
|
||||
await prisma_client.db.litellm_config.upsert(
|
||||
await ConfigRepository(prisma_client).table.upsert(
|
||||
where={"param_name": "model_cost_map_reload_config"},
|
||||
data={
|
||||
"create": {
|
||||
@@ -15119,7 +15138,7 @@ async def cancel_model_cost_map_reload(
|
||||
)
|
||||
|
||||
# Remove reload configuration from database
|
||||
await prisma_client.db.litellm_config.delete(
|
||||
await ConfigRepository(prisma_client).table.delete(
|
||||
where={"param_name": "model_cost_map_reload_config"}
|
||||
)
|
||||
await invalidate_config_param("model_cost_map_reload_config")
|
||||
@@ -15178,7 +15197,7 @@ async def get_model_cost_map_reload_status(
|
||||
}
|
||||
|
||||
# Get reload configuration from database
|
||||
config_record = await prisma_client.db.litellm_config.find_unique(
|
||||
config_record = await ConfigRepository(prisma_client).table.find_unique(
|
||||
where={"param_name": "model_cost_map_reload_config"}
|
||||
)
|
||||
|
||||
@@ -15329,7 +15348,7 @@ async def reload_anthropic_beta_headers(
|
||||
last_anthropic_beta_headers_reload = current_time.isoformat()
|
||||
|
||||
# Set force reload flag in database for other pods, preserving existing interval_hours
|
||||
existing_beta_config = await prisma_client.db.litellm_config.find_unique(
|
||||
existing_beta_config = await ConfigRepository(prisma_client).table.find_unique(
|
||||
where={"param_name": "anthropic_beta_headers_reload_config"}
|
||||
)
|
||||
existing_beta_interval = None
|
||||
@@ -15338,7 +15357,7 @@ async def reload_anthropic_beta_headers(
|
||||
"interval_hours"
|
||||
)
|
||||
|
||||
await prisma_client.db.litellm_config.upsert(
|
||||
await ConfigRepository(prisma_client).table.upsert(
|
||||
where={"param_name": "anthropic_beta_headers_reload_config"},
|
||||
data={
|
||||
"create": {
|
||||
@@ -15412,7 +15431,7 @@ async def schedule_anthropic_beta_headers_reload(
|
||||
)
|
||||
|
||||
# Update database with new reload configuration
|
||||
await prisma_client.db.litellm_config.upsert(
|
||||
await ConfigRepository(prisma_client).table.upsert(
|
||||
where={"param_name": "anthropic_beta_headers_reload_config"},
|
||||
data={
|
||||
"create": {
|
||||
@@ -15479,7 +15498,7 @@ async def cancel_anthropic_beta_headers_reload(
|
||||
)
|
||||
|
||||
# Remove reload configuration from database
|
||||
await prisma_client.db.litellm_config.delete(
|
||||
await ConfigRepository(prisma_client).table.delete(
|
||||
where={"param_name": "anthropic_beta_headers_reload_config"}
|
||||
)
|
||||
await invalidate_config_param("anthropic_beta_headers_reload_config")
|
||||
@@ -15539,7 +15558,7 @@ async def get_anthropic_beta_headers_reload_status(
|
||||
}
|
||||
|
||||
# Get reload configuration from database
|
||||
config_record = await prisma_client.db.litellm_config.find_unique(
|
||||
config_record = await ConfigRepository(prisma_client).table.find_unique(
|
||||
where={"param_name": "anthropic_beta_headers_reload_config"}
|
||||
)
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@ import re
|
||||
from importlib.resources import files
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import litellm
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.get_blog_posts import (
|
||||
BlogPost,
|
||||
@@ -17,6 +17,7 @@ from litellm.litellm_core_utils.get_blog_posts import (
|
||||
from litellm.proxy._types import (
|
||||
CommonProxyErrors,
|
||||
)
|
||||
from litellm.repositories.table_repositories import ClaudeCodePluginRepository
|
||||
from litellm.types.agents import AgentCard
|
||||
from litellm.types.mcp import MCPPublicServer
|
||||
from litellm.types.proxy.management_endpoints.model_management_endpoints import (
|
||||
@@ -159,14 +160,14 @@ def _load_endpoints() -> List[Dict[str, Any]]:
|
||||
)
|
||||
async def public_model_hub():
|
||||
import litellm
|
||||
from litellm.proxy.health_endpoints._health_endpoints import (
|
||||
_convert_health_check_to_dict,
|
||||
)
|
||||
from litellm.proxy.proxy_server import (
|
||||
_get_model_group_info,
|
||||
llm_router,
|
||||
prisma_client,
|
||||
)
|
||||
from litellm.proxy.health_endpoints._health_endpoints import (
|
||||
_convert_health_check_to_dict,
|
||||
)
|
||||
|
||||
if llm_router is None:
|
||||
raise HTTPException(
|
||||
@@ -266,7 +267,7 @@ async def public_skill_hub():
|
||||
|
||||
try:
|
||||
prisma_client = await _get_prisma_client()
|
||||
plugins = await prisma_client.db.litellm_claudecodeplugintable.find_many(
|
||||
plugins = await ClaudeCodePluginRepository(prisma_client).table.find_many(
|
||||
where={"enabled": True}
|
||||
)
|
||||
items = []
|
||||
|
||||
@@ -17,16 +17,17 @@ import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.auth_utils import is_request_body_safe
|
||||
from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth, user_api_key_auth
|
||||
from litellm.proxy.common_utils.http_parsing_utils import (
|
||||
_read_request_body,
|
||||
_safe_get_request_headers,
|
||||
get_form_data,
|
||||
)
|
||||
from litellm.proxy.auth.auth_utils import is_request_body_safe
|
||||
from litellm.proxy.vector_store_endpoints.utils import (
|
||||
assert_user_can_access_vector_store_id,
|
||||
)
|
||||
from litellm.repositories.table_repositories import ManagedVectorStoresRepository
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -230,11 +231,9 @@ async def _save_vector_store_to_db_from_rag_ingest(
|
||||
|
||||
try:
|
||||
# Check if vector store already exists in database
|
||||
existing_vector_store = (
|
||||
await prisma_client.db.litellm_managedvectorstorestable.find_unique(
|
||||
where={"vector_store_id": vector_store_id}
|
||||
)
|
||||
)
|
||||
existing_vector_store = await ManagedVectorStoresRepository(
|
||||
prisma_client
|
||||
).table.find_unique(where={"vector_store_id": vector_store_id})
|
||||
|
||||
# Only create if it doesn't exist
|
||||
if existing_vector_store is None:
|
||||
@@ -289,7 +288,7 @@ async def _save_vector_store_to_db_from_rag_ingest(
|
||||
# Update the vector store
|
||||
from litellm.proxy.utils import safe_dumps
|
||||
|
||||
await prisma_client.db.litellm_managedvectorstorestable.update(
|
||||
await ManagedVectorStoresRepository(prisma_client).table.update(
|
||||
where={"vector_store_id": vector_store_id},
|
||||
data={"vector_store_metadata": safe_dumps(existing_metadata)},
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import List, Optional
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.repositories.table_repositories import SearchToolsRepository
|
||||
from litellm.types.search import SearchTool
|
||||
|
||||
|
||||
@@ -63,16 +64,16 @@ class SearchToolRegistry:
|
||||
search_tool_info: str = safe_dumps(search_tool.get("search_tool_info", {}))
|
||||
|
||||
# Create search tool in DB
|
||||
created_search_tool = (
|
||||
await prisma_client.db.litellm_searchtoolstable.create(
|
||||
data={
|
||||
"search_tool_name": search_tool_name,
|
||||
"litellm_params": litellm_params,
|
||||
"search_tool_info": search_tool_info,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
}
|
||||
)
|
||||
created_search_tool = await SearchToolsRepository(
|
||||
prisma_client
|
||||
).table.create(
|
||||
data={
|
||||
"search_tool_name": search_tool_name,
|
||||
"litellm_params": litellm_params,
|
||||
"search_tool_info": search_tool_info,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
}
|
||||
)
|
||||
|
||||
# Add search_tool_id to the returned search tool object
|
||||
@@ -101,15 +102,15 @@ class SearchToolRegistry:
|
||||
"""
|
||||
try:
|
||||
# Get search tool before deletion for response
|
||||
existing_tool = await prisma_client.db.litellm_searchtoolstable.find_unique(
|
||||
where={"search_tool_id": search_tool_id}
|
||||
)
|
||||
existing_tool = await SearchToolsRepository(
|
||||
prisma_client
|
||||
).table.find_unique(where={"search_tool_id": search_tool_id})
|
||||
|
||||
if not existing_tool:
|
||||
raise Exception(f"Search tool with ID {search_tool_id} not found")
|
||||
|
||||
# Delete from DB
|
||||
await prisma_client.db.litellm_searchtoolstable.delete(
|
||||
await SearchToolsRepository(prisma_client).table.delete(
|
||||
where={"search_tool_id": search_tool_id}
|
||||
)
|
||||
|
||||
@@ -145,16 +146,16 @@ class SearchToolRegistry:
|
||||
search_tool_info: str = safe_dumps(search_tool.get("search_tool_info", {}))
|
||||
|
||||
# Update in DB
|
||||
updated_search_tool = (
|
||||
await prisma_client.db.litellm_searchtoolstable.update(
|
||||
where={"search_tool_id": search_tool_id},
|
||||
data={
|
||||
"search_tool_name": search_tool_name,
|
||||
"litellm_params": litellm_params,
|
||||
"search_tool_info": search_tool_info,
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
updated_search_tool = await SearchToolsRepository(
|
||||
prisma_client
|
||||
).table.update(
|
||||
where={"search_tool_id": search_tool_id},
|
||||
data={
|
||||
"search_tool_name": search_tool_name,
|
||||
"litellm_params": litellm_params,
|
||||
"search_tool_info": search_tool_info,
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
|
||||
# Convert to dict with ISO formatted datetimes
|
||||
@@ -179,10 +180,10 @@ class SearchToolRegistry:
|
||||
List of search tool configurations
|
||||
"""
|
||||
try:
|
||||
search_tools_from_db = (
|
||||
await prisma_client.db.litellm_searchtoolstable.find_many(
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
search_tools_from_db = await SearchToolsRepository(
|
||||
prisma_client
|
||||
).table.find_many(
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
|
||||
search_tools: List[SearchTool] = []
|
||||
@@ -214,7 +215,7 @@ class SearchToolRegistry:
|
||||
Search tool configuration or None if not found
|
||||
"""
|
||||
try:
|
||||
search_tool = await prisma_client.db.litellm_searchtoolstable.find_unique(
|
||||
search_tool = await SearchToolsRepository(prisma_client).table.find_unique(
|
||||
where={"search_tool_id": search_tool_id}
|
||||
)
|
||||
|
||||
@@ -244,7 +245,7 @@ class SearchToolRegistry:
|
||||
Search tool configuration or None if not found
|
||||
"""
|
||||
try:
|
||||
search_tool = await prisma_client.db.litellm_searchtoolstable.find_unique(
|
||||
search_tool = await SearchToolsRepository(prisma_client).table.find_unique(
|
||||
where={"search_tool_name": search_tool_name}
|
||||
)
|
||||
|
||||
|
||||
@@ -6,11 +6,12 @@ from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
|
||||
from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||
decrypt_value_helper,
|
||||
encrypt_value_helper,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
|
||||
from litellm.repositories.config_repository import ConfigRepository
|
||||
from litellm.types.proxy.cloudzero_endpoints import (
|
||||
CloudZeroExportRequest,
|
||||
CloudZeroExportResponse,
|
||||
@@ -53,7 +54,7 @@ async def _set_cloudzero_settings(api_key: str, connection_id: str, timezone: st
|
||||
"timezone": timezone,
|
||||
}
|
||||
|
||||
await prisma_client.db.litellm_config.upsert(
|
||||
await ConfigRepository(prisma_client).table.upsert(
|
||||
where={"param_name": "cloudzero_settings"},
|
||||
data={
|
||||
"create": {
|
||||
@@ -80,7 +81,7 @@ async def _get_cloudzero_settings():
|
||||
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
||||
)
|
||||
|
||||
cloudzero_config = await prisma_client.db.litellm_config.find_first(
|
||||
cloudzero_config = await ConfigRepository(prisma_client).table.find_first(
|
||||
where={"param_name": "cloudzero_settings"}
|
||||
)
|
||||
if cloudzero_config is None or cloudzero_config.param_value is None:
|
||||
@@ -282,7 +283,7 @@ async def is_cloudzero_setup_in_db() -> bool:
|
||||
return False
|
||||
|
||||
# Check for CloudZero settings in database
|
||||
cloudzero_config = await prisma_client.db.litellm_config.find_first(
|
||||
cloudzero_config = await ConfigRepository(prisma_client).table.find_first(
|
||||
where={"param_name": "cloudzero_settings"}
|
||||
)
|
||||
|
||||
@@ -548,7 +549,7 @@ async def delete_cloudzero_settings(
|
||||
)
|
||||
|
||||
# Check if CloudZero settings exist
|
||||
cloudzero_config = await prisma_client.db.litellm_config.find_first(
|
||||
cloudzero_config = await ConfigRepository(prisma_client).table.find_first(
|
||||
where={"param_name": "cloudzero_settings"}
|
||||
)
|
||||
|
||||
@@ -560,7 +561,7 @@ async def delete_cloudzero_settings(
|
||||
|
||||
# Delete only the CloudZero settings entry
|
||||
# This uses a specific where clause to target only the cloudzero_settings row
|
||||
await prisma_client.db.litellm_config.delete(
|
||||
await ConfigRepository(prisma_client).table.delete(
|
||||
where={"param_name": "cloudzero_settings"}
|
||||
)
|
||||
|
||||
|
||||
@@ -21,6 +21,11 @@ from litellm.proxy.spend_tracking.spend_tracking_utils import (
|
||||
get_spend_by_team_and_customer,
|
||||
)
|
||||
from litellm.proxy.utils import handle_exception_on_proxy
|
||||
from litellm.repositories.table_repositories import SpendLogsRepository
|
||||
from litellm.repositories.team_repository import TeamRepository
|
||||
from litellm.repositories.verification_token_repository import (
|
||||
VerificationTokenRepository,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.proxy_server import PrismaClient
|
||||
@@ -2010,7 +2015,7 @@ async def ui_view_spend_logs( # noqa: PLR0915
|
||||
order_direction = (sort_order or "desc").lower()
|
||||
|
||||
# Get total count of records
|
||||
total_records = await prisma_client.db.litellm_spendlogs.count(
|
||||
total_records = await SpendLogsRepository(prisma_client).table.count(
|
||||
where=where_conditions,
|
||||
)
|
||||
|
||||
@@ -2374,7 +2379,7 @@ async def view_spend_logs( # noqa: PLR0915
|
||||
# Check if user wants unsummarized data
|
||||
if not summarize:
|
||||
# Return filtered individual log entries (similar to UI endpoint)
|
||||
data = await prisma_client.db.litellm_spendlogs.find_many(
|
||||
data = await SpendLogsRepository(prisma_client).table.find_many(
|
||||
where=filter_query, # type: ignore
|
||||
order={
|
||||
"startTime": "desc",
|
||||
@@ -2384,7 +2389,7 @@ async def view_spend_logs( # noqa: PLR0915
|
||||
|
||||
# Legacy behavior: return summarized data (when summarize=true)
|
||||
# SQL query
|
||||
response = await prisma_client.db.litellm_spendlogs.group_by(
|
||||
response = await SpendLogsRepository(prisma_client).table.group_by(
|
||||
by=["api_key", "user", "model", "startTime"],
|
||||
where=filter_query, # type: ignore
|
||||
sum={
|
||||
@@ -2462,7 +2467,7 @@ async def view_spend_logs( # noqa: PLR0915
|
||||
)
|
||||
return spend_logs
|
||||
|
||||
data = await prisma_client.db.litellm_spendlogs.find_many(
|
||||
data = await SpendLogsRepository(prisma_client).table.find_many(
|
||||
where=scoped_filter, # type: ignore
|
||||
order={"startTime": "desc"},
|
||||
)
|
||||
@@ -2514,10 +2519,10 @@ async def global_spend_reset():
|
||||
code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
await prisma_client.db.litellm_verificationtoken.update_many(
|
||||
await VerificationTokenRepository(prisma_client).table.update_many(
|
||||
data={"spend": 0.0}, where={}
|
||||
)
|
||||
await prisma_client.db.litellm_teamtable.update_many(data={"spend": 0.0}, where={})
|
||||
await TeamRepository(prisma_client).table.update_many(data={"spend": 0.0}, where={})
|
||||
|
||||
return {
|
||||
"message": "Spend for all API Keys and Teams reset successfully",
|
||||
@@ -3384,7 +3389,7 @@ async def ui_view_session_spend_logs(
|
||||
skip = (page - 1) * page_size
|
||||
|
||||
# Get total count for pagination metadata
|
||||
total_records = await prisma_client.db.litellm_spendlogs.count(
|
||||
total_records = await SpendLogsRepository(prisma_client).table.count(
|
||||
where=where_conditions
|
||||
)
|
||||
|
||||
@@ -3485,7 +3490,7 @@ async def _build_ui_spend_logs_response(
|
||||
# is bounded by page_size (typically 25-50 distinct session IDs).
|
||||
# If performance degrades at scale, consider short-lived caching or
|
||||
# folding the count into the main query via a window function.
|
||||
counts = await prisma_client.db.litellm_spendlogs.group_by(
|
||||
counts = await SpendLogsRepository(prisma_client).table.group_by(
|
||||
by=["session_id"],
|
||||
where={"session_id": {"in": session_ids}},
|
||||
count={"session_id": True},
|
||||
@@ -3572,7 +3577,7 @@ async def _can_team_member_view_log(
|
||||
|
||||
if team_id is None:
|
||||
return False
|
||||
team_row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
team_row = await TeamRepository(prisma_client).table.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
if team_row is None:
|
||||
@@ -3614,7 +3619,7 @@ async def _assert_user_can_view_request_id(
|
||||
permitted teams (admin or ``/spend/logs`` permission).
|
||||
Raises HTTP 403 if not.
|
||||
"""
|
||||
row = await prisma_client.db.litellm_spendlogs.find_unique(
|
||||
row = await SpendLogsRepository(prisma_client).table.find_unique(
|
||||
where={"request_id": request_id},
|
||||
include=None,
|
||||
)
|
||||
@@ -3669,7 +3674,7 @@ async def _get_permitted_team_ids_for_spend_logs(
|
||||
if user_obj is None or not user_obj.teams:
|
||||
return []
|
||||
|
||||
team_rows = await prisma_client.db.litellm_teamtable.find_many(
|
||||
team_rows = await TeamRepository(prisma_client).table.find_many(
|
||||
where={"team_id": {"in": user_obj.teams}}
|
||||
)
|
||||
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import json
|
||||
|
||||
import litellm
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
|
||||
from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||
decrypt_value_helper,
|
||||
encrypt_value_helper,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
|
||||
from litellm.repositories.config_repository import ConfigRepository
|
||||
from litellm.types.proxy.vantage_endpoints import (
|
||||
VantageDryRunRequest,
|
||||
VantageExportRequest,
|
||||
@@ -60,7 +61,7 @@ async def _set_vantage_settings(api_key: str, integration_token: str, base_url:
|
||||
"base_url": base_url,
|
||||
}
|
||||
|
||||
await prisma_client.db.litellm_config.upsert(
|
||||
await ConfigRepository(prisma_client).table.upsert(
|
||||
where={"param_name": VANTAGE_SETTINGS_PARAM_NAME},
|
||||
data={
|
||||
"create": {
|
||||
@@ -82,7 +83,7 @@ async def _get_vantage_settings():
|
||||
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
||||
)
|
||||
|
||||
vantage_config = await prisma_client.db.litellm_config.find_first(
|
||||
vantage_config = await ConfigRepository(prisma_client).table.find_first(
|
||||
where={"param_name": VANTAGE_SETTINGS_PARAM_NAME}
|
||||
)
|
||||
if vantage_config is None or vantage_config.param_value is None:
|
||||
@@ -265,7 +266,7 @@ async def is_vantage_setup_in_db() -> bool:
|
||||
if prisma_client is None:
|
||||
return False
|
||||
|
||||
vantage_config = await prisma_client.db.litellm_config.find_first(
|
||||
vantage_config = await ConfigRepository(prisma_client).table.find_first(
|
||||
where={"param_name": VANTAGE_SETTINGS_PARAM_NAME}
|
||||
)
|
||||
|
||||
@@ -553,7 +554,7 @@ async def delete_vantage_settings(
|
||||
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
||||
)
|
||||
|
||||
vantage_config = await prisma_client.db.litellm_config.find_first(
|
||||
vantage_config = await ConfigRepository(prisma_client).table.find_first(
|
||||
where={"param_name": VANTAGE_SETTINGS_PARAM_NAME}
|
||||
)
|
||||
|
||||
@@ -563,7 +564,7 @@ async def delete_vantage_settings(
|
||||
detail={"error": "Vantage settings not found"},
|
||||
)
|
||||
|
||||
await prisma_client.db.litellm_config.delete(
|
||||
await ConfigRepository(prisma_client).table.delete(
|
||||
where={"param_name": VANTAGE_SETTINGS_PARAM_NAME}
|
||||
)
|
||||
|
||||
|
||||
@@ -12,6 +12,12 @@ from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.sensitive_data_masker import mask_sensitive_keys
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.repositories.config_repository import ConfigRepository
|
||||
from litellm.repositories.table_repositories import (
|
||||
DailyTagSpendRepository,
|
||||
SSOConfigRepository,
|
||||
UISettingsRepository,
|
||||
)
|
||||
from litellm.types.proxy.management_endpoints.ui_sso import (
|
||||
DefaultTeamSSOParams,
|
||||
InProductNudgeResponse,
|
||||
@@ -665,7 +671,7 @@ async def get_sso_settings():
|
||||
)
|
||||
|
||||
# Get SSO config from dedicated table
|
||||
sso_db_record = await prisma_client.db.litellm_ssoconfig.find_unique(
|
||||
sso_db_record = await SSOConfigRepository(prisma_client).table.find_unique(
|
||||
where={"id": "sso_config"}
|
||||
)
|
||||
|
||||
@@ -836,7 +842,7 @@ async def update_sso_settings(sso_config: SSOConfig):
|
||||
)
|
||||
|
||||
# Save to dedicated SSO table
|
||||
await prisma_client.db.litellm_ssoconfig.upsert(
|
||||
await SSOConfigRepository(prisma_client).table.upsert(
|
||||
where={"id": "sso_config"},
|
||||
data={
|
||||
"create": {
|
||||
@@ -851,7 +857,7 @@ async def update_sso_settings(sso_config: SSOConfig):
|
||||
|
||||
# Remove SSO-related env vars from config.environment_variables
|
||||
try:
|
||||
env_var_entry = await prisma_client.db.litellm_config.find_unique(
|
||||
env_var_entry = await ConfigRepository(prisma_client).table.find_unique(
|
||||
where={"param_name": "environment_variables"}
|
||||
)
|
||||
|
||||
@@ -872,7 +878,7 @@ async def update_sso_settings(sso_config: SSOConfig):
|
||||
if key not in env_vars_to_remove
|
||||
}
|
||||
|
||||
await prisma_client.db.litellm_config.update(
|
||||
await ConfigRepository(prisma_client).table.update(
|
||||
where={"param_name": "environment_variables"},
|
||||
data={
|
||||
"param_value": json.dumps(filtered_env_vars, default=str),
|
||||
@@ -1123,7 +1129,7 @@ async def get_in_product_nudges():
|
||||
detail={"error": "Database not connected. Please connect a database."},
|
||||
)
|
||||
|
||||
db_record = await prisma_client.db.litellm_dailytagspend.find_first(
|
||||
db_record = await DailyTagSpendRepository(prisma_client).table.find_first(
|
||||
where={"tag": "User-Agent: claude-cli"}
|
||||
)
|
||||
|
||||
@@ -1155,7 +1161,7 @@ async def get_ui_settings_cached() -> Dict[str, Any]:
|
||||
if prisma_client is None:
|
||||
return {}
|
||||
|
||||
db_record = await prisma_client.db.litellm_uisettings.find_unique(
|
||||
db_record = await UISettingsRepository(prisma_client).table.find_unique(
|
||||
where={"id": "ui_settings"}
|
||||
)
|
||||
ui_settings: Dict[str, Any] = {}
|
||||
@@ -1196,7 +1202,7 @@ async def get_ui_settings():
|
||||
|
||||
ui_settings: Dict[str, Any] = {}
|
||||
|
||||
db_record = await prisma_client.db.litellm_uisettings.find_unique(
|
||||
db_record = await UISettingsRepository(prisma_client).table.find_unique(
|
||||
where={"id": "ui_settings"}
|
||||
)
|
||||
|
||||
@@ -1309,7 +1315,7 @@ async def update_ui_settings(
|
||||
# Merge with existing persisted settings so a partial PATCH doesn't
|
||||
# overwrite fields the caller didn't send.
|
||||
existing: dict = {}
|
||||
db_existing = await prisma_client.db.litellm_uisettings.find_unique(
|
||||
db_existing = await UISettingsRepository(prisma_client).table.find_unique(
|
||||
where={"id": "ui_settings"}
|
||||
)
|
||||
if db_existing and db_existing.ui_settings:
|
||||
@@ -1318,7 +1324,7 @@ async def update_ui_settings(
|
||||
|
||||
ui_settings = {**existing, **incoming}
|
||||
|
||||
await prisma_client.db.litellm_uisettings.upsert(
|
||||
await UISettingsRepository(prisma_client).table.upsert(
|
||||
where={"id": "ui_settings"},
|
||||
data={
|
||||
"create": {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user