From 5e2db7eee4e30e1c5b698e52ef4cc411860a1ca5 Mon Sep 17 00:00:00 2001 From: Yassin Kortam Date: Sat, 6 Jun 2026 20:59:33 -0700 Subject: [PATCH] feat(litellm): add models and repository layers (#29686) --- .github/workflows/test-unit-misc.yml | 2 + ARCHITECTURE.md | 18 + .../SlackAlerting/slack_alerting.py | 6 +- litellm/integrations/email_alerting.py | 3 +- litellm/integrations/prometheus.py | 21 +- litellm/llms/litellm_proxy/skills/handler.py | 9 +- litellm/models/__init__.py | 66 + litellm/models/access_group.py | 26 + litellm/models/base.py | 38 + litellm/models/budget.py | 56 + litellm/models/config.py | 15 + litellm/models/credentials.py | 31 + litellm/models/end_user.py | 35 + litellm/models/managed_files.py | 62 + litellm/models/mcp_server.py | 103 + litellm/models/model.py | 59 + litellm/models/object_permission.py | 26 + litellm/models/organization.py | 31 + litellm/models/organization_membership.py | 40 + litellm/models/project.py | 41 + litellm/models/skills.py | 30 + litellm/models/spend_logs.py | 50 + litellm/models/tag.py | 36 + litellm/models/team.py | 154 ++ litellm/models/team_membership.py | 32 + litellm/models/user.py | 70 + litellm/models/verification_token.py | 74 + .../mcp_server/auth/user_api_key_auth_mcp.py | 10 +- litellm/proxy/_experimental/mcp_server/db.py | 125 +- .../mcp_server/mcp_server_manager.py | 5 +- .../_experimental/mcp_server/toolset_db.py | 13 +- litellm/proxy/_types.py | 755 +----- .../proxy/agent_endpoints/agent_registry.py | 19 +- .../auth/agent_permission_handler.py | 5 +- litellm/proxy/agent_endpoints/endpoints.py | 23 +- .../claude_code_marketplace.py | 25 +- litellm/proxy/auth/auth_checks.py | 74 +- litellm/proxy/auth/handle_jwt.py | 5 +- litellm/proxy/auth/login_utils.py | 5 +- litellm/proxy/auth/model_checks.py | 3 +- litellm/proxy/auth/user_api_key_auth.py | 9 +- .../expired_ui_session_key_cleanup_manager.py | 7 +- .../common_utils/key_rotation_manager.py | 42 +- .../proxy/common_utils/reset_budget_job.py | 24 +- 
.../proxy/container_endpoints/ownership.py | 9 +- .../proxy/credential_endpoints/endpoints.py | 16 +- litellm/proxy/db/spend_counter_reseed.py | 26 +- litellm/proxy/db/spend_log_tool_index.py | 3 +- litellm/proxy/db/tool_registry_writer.py | 26 +- .../proxy/guardrails/guardrail_endpoints.py | 26 +- .../proxy/guardrails/guardrail_registry.py | 31 +- litellm/proxy/guardrails/usage_endpoints.py | 44 +- litellm/proxy/guardrails/usage_tracking.py | 8 +- .../hooks/user_management_event_hooks.py | 5 +- .../access_group_endpoints.py | 5 +- .../budget_management_endpoints.py | 15 +- .../cache_settings_endpoints.py | 9 +- .../common_daily_activity.py | 16 +- .../management_endpoints/common_utils.py | 9 +- .../config_override_endpoints.py | 11 +- .../customer_endpoints.py | 56 +- .../fallback_management_endpoints.py | 5 +- .../internal_user_endpoints.py | 95 +- .../jwt_key_mapping_endpoints.py | 17 +- .../key_management_endpoints.py | 151 +- .../mcp_management_endpoints.py | 14 +- ...model_access_group_management_endpoints.py | 19 +- .../model_management_endpoints.py | 33 +- .../organization_endpoints.py | 207 +- .../scim/scim_transformations.py | 3 +- .../management_endpoints/scim/scim_v2.py | 86 +- .../tag_management_endpoints.py | 34 +- .../team_callback_endpoints.py | 5 +- .../management_endpoints/team_endpoints.py | 212 +- .../tool_management_endpoints.py | 37 +- litellm/proxy/management_endpoints/ui_sso.py | 29 +- .../user_agent_analytics_endpoints.py | 15 +- .../workflow_management_endpoints.py | 27 +- .../proxy/management_helpers/audit_logs.py | 3 +- .../object_permission_utils.py | 38 +- .../management_helpers/user_invitation.py | 3 +- litellm/proxy/management_helpers/utils.py | 37 +- litellm/proxy/memory/memory_endpoints.py | 20 +- .../openai_files_endpoints/common_utils.py | 15 +- .../openai_files_endpoints/files_endpoints.py | 21 +- .../managed_id_rewriter.py | 30 +- .../pass_through_endpoints.py | 9 +- .../policy_engine/attachment_registry.py | 57 +- 
.../proxy/policy_engine/policy_registry.py | 43 +- .../policy_engine/policy_resolve_endpoints.py | 12 +- .../proxy/policy_engine/policy_validator.py | 10 +- litellm/proxy/prompts/prompt_endpoints.py | 21 +- litellm/proxy/proxy_server.py | 165 +- .../public_endpoints/public_endpoints.py | 11 +- litellm/proxy/rag_endpoints/endpoints.py | 13 +- .../search_endpoints/search_tool_registry.py | 61 +- .../spend_tracking/cloudzero_endpoints.py | 13 +- .../spend_management_endpoints.py | 27 +- .../proxy/spend_tracking/vantage_endpoints.py | 15 +- .../proxy_setting_endpoints.py | 24 +- litellm/proxy/utils.py | 123 +- .../proxy/vector_store_endpoints/endpoints.py | 12 +- .../management_endpoints.py | 28 +- litellm/repositories/__init__.py | 127 + litellm/repositories/base_repository.py | 117 + litellm/repositories/budget_repository.py | 99 + litellm/repositories/config_repository.py | 241 ++ .../repositories/credentials_repository.py | 61 + litellm/repositories/model_repository.py | 171 ++ .../object_permission_repository.py | 110 + .../repositories/organization_repository.py | 103 + litellm/repositories/project_repository.py | 129 + litellm/repositories/table_repositories.py | 215 ++ litellm/repositories/team_repository.py | 351 +++ litellm/repositories/user_repository.py | 229 ++ .../verification_token_repository.py | 375 +++ .../adaptive_router/adaptive_router.py | 3 +- .../adaptive_router/update_queue.py | 8 +- .../types/mcp_server/mcp_server_manager.py | 3 +- litellm/types/utils.py | 25 +- .../vector_stores/vector_store_registry.py | 26 +- tests/test_litellm/models/test_models.py | 542 ++++ .../repositories/test_repositories.py | 2184 +++++++++++++++++ ui/litellm-dashboard/src/lib/http/schema.d.ts | 179 +- 124 files changed, 7846 insertions(+), 1850 deletions(-) create mode 100644 litellm/models/__init__.py create mode 100644 litellm/models/access_group.py create mode 100644 litellm/models/base.py create mode 100644 litellm/models/budget.py create mode 100644 
litellm/models/config.py create mode 100644 litellm/models/credentials.py create mode 100644 litellm/models/end_user.py create mode 100644 litellm/models/managed_files.py create mode 100644 litellm/models/mcp_server.py create mode 100644 litellm/models/model.py create mode 100644 litellm/models/object_permission.py create mode 100644 litellm/models/organization.py create mode 100644 litellm/models/organization_membership.py create mode 100644 litellm/models/project.py create mode 100644 litellm/models/skills.py create mode 100644 litellm/models/spend_logs.py create mode 100644 litellm/models/tag.py create mode 100644 litellm/models/team.py create mode 100644 litellm/models/team_membership.py create mode 100644 litellm/models/user.py create mode 100644 litellm/models/verification_token.py create mode 100644 litellm/repositories/__init__.py create mode 100644 litellm/repositories/base_repository.py create mode 100644 litellm/repositories/budget_repository.py create mode 100644 litellm/repositories/config_repository.py create mode 100644 litellm/repositories/credentials_repository.py create mode 100644 litellm/repositories/model_repository.py create mode 100644 litellm/repositories/object_permission_repository.py create mode 100644 litellm/repositories/organization_repository.py create mode 100644 litellm/repositories/project_repository.py create mode 100644 litellm/repositories/table_repositories.py create mode 100644 litellm/repositories/team_repository.py create mode 100644 litellm/repositories/user_repository.py create mode 100644 litellm/repositories/verification_token_repository.py create mode 100644 tests/test_litellm/models/test_models.py create mode 100644 tests/test_litellm/repositories/test_repositories.py diff --git a/.github/workflows/test-unit-misc.yml b/.github/workflows/test-unit-misc.yml index 9add77ff42..a7363ac3b4 100644 --- a/.github/workflows/test-unit-misc.yml +++ b/.github/workflows/test-unit-misc.yml @@ -28,6 +28,8 @@ jobs: 
tests/test_litellm/completion_extras tests/test_litellm/containers tests/test_litellm/experimental_mcp_client + tests/test_litellm/models + tests/test_litellm/repositories tests/test_litellm/images tests/test_litellm/interactions tests/test_litellm/passthrough diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index c114a838d6..3d2fa3e51c 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -240,6 +240,24 @@ graph LR 7. `DBSpendUpdateWriter.update_database()` queues spend increments to Redis 8. Background job `update_spend` flushes queued spend to PostgreSQL every 60s +### Data Access Layer (Models & Repositories) + +Database entities and the operations on them live in two packages at the root of `litellm/` so both the gateway (`proxy/`) and the SDK can use them without importing proxy internals: + +- `litellm/models/` holds the canonical Pydantic definitions for every persisted entity (`LiteLLM_VerificationToken`, `LiteLLM_TeamTable`, `LiteLLM_UserTable`, etc.). `proxy/_types.py` re-exports these for backwards compatibility, so existing imports keep working. +- `litellm/repositories/` holds the data-access layer. `BaseRepository[T]` provides the generic CRUD (`find_by_id`, `find_many`, `create`, `update`, `delete`, `count`, `exists`); entity repositories such as `VerificationTokenRepository`, `TeamRepository`, and `UserRepository` add domain-specific queries and writes on top of it. + +Conventions to follow when touching this layer: + +| Concern | How it's handled | +|---------|------------------| +| JSON columns | Prisma `Json` columns are stored as JSON strings. Repositories `json.dumps()` on write and `json.loads()` on read (see `_to_model` and the `_build_*_data` helpers). | +| Archive-then-delete | `delete_team` / `delete_token` copy the row into the `LiteLLM_Deleted*` table and delete the original inside a single `prisma_client.db.tx()` transaction. Archive payloads are built explicitly so only columns that exist on the archive table are written. 
| +| Column vs. field names | Where a model field differs from its DB column (for example `org_id` maps to the `organization_id` column), the repository translates in both directions rather than relying on Pydantic to guess. | +| Array mutations | Adds use Prisma's atomic `push` (`add_member`, `add_admin`, `add_models`) to avoid read-modify-write races. Removals fall back to read-modify-write because Prisma has no atomic array remove. | + +To add a new entity, define the model under `litellm/models/`, re-export it from `proxy/_types.py` if existing code imports it from there, and add a repository under `litellm/repositories/` (subclass `BaseRepository` for plain CRUD, or add bespoke methods when the entity needs encryption, archiving, or atomic array updates). Mirror the tests in `tests/test_litellm/repositories/`. + --- ## 2. SDK Request Flow diff --git a/litellm/integrations/SlackAlerting/slack_alerting.py b/litellm/integrations/SlackAlerting/slack_alerting.py index 0ec17bbea5..390af2cb6e 100644 --- a/litellm/integrations/SlackAlerting/slack_alerting.py +++ b/litellm/integrations/SlackAlerting/slack_alerting.py @@ -37,6 +37,8 @@ from litellm.proxy._types import ( VirtualKeyEvent, WebhookEvent, ) +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.user_repository import UserRepository from litellm.types.integrations.slack_alerting import * from ..email_templates.templates import * @@ -1231,7 +1233,7 @@ Model Info: and recipient_user_id is not None and prisma_client is not None ): - user_row = await prisma_client.db.litellm_usertable.find_unique( + user_row = await UserRepository(prisma_client).table.find_unique( where={"user_id": recipient_user_id} ) @@ -1263,7 +1265,7 @@ Model Info: team_id = webhook_event.team_id team_name = "Default Team" if team_id is not None and prisma_client is not None: - team_row = await prisma_client.db.litellm_teamtable.find_unique( + team_row = await 
TeamRepository(prisma_client).table.find_unique( where={"team_id": team_id} ) if team_row is not None: diff --git a/litellm/integrations/email_alerting.py b/litellm/integrations/email_alerting.py index b45b9aa7f5..b721dc5046 100644 --- a/litellm/integrations/email_alerting.py +++ b/litellm/integrations/email_alerting.py @@ -7,6 +7,7 @@ from typing import List, Optional from litellm._logging import verbose_logger, verbose_proxy_logger from litellm.proxy._types import WebhookEvent +from litellm.repositories.team_repository import TeamRepository # we use this for the email header, please send a test email if you change this. verify it looks good on email LITELLM_LOGO_URL = "https://litellm-listing.s3.amazonaws.com/litellm_logo.png" @@ -24,7 +25,7 @@ async def get_all_team_member_emails(team_id: Optional[str] = None) -> list: if prisma_client is None: raise Exception("Not connected to DB!") - team_row = await prisma_client.db.litellm_teamtable.find_unique( + team_row = await TeamRepository(prisma_client).table.find_unique( where={ "team_id": team_id, } diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index d2af95cd4c..2119527a8e 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -29,13 +29,13 @@ from litellm.exceptions import ( validate_rate_limit_type, ) from litellm.integrations.custom_logger import CustomLogger -from litellm.integrations.prometheus_helpers.bounded_prometheus_series_tracker import ( - BoundedPrometheusSeriesTracker, -) from litellm.integrations.prometheus_helpers import ( PrometheusLabelFactoryContext, _get_cached_end_user_id_for_cost_tracking, ) +from litellm.integrations.prometheus_helpers.bounded_prometheus_series_tracker import ( + BoundedPrometheusSeriesTracker, +) from litellm.litellm_core_utils.core_helpers import ( get_litellm_metadata_from_kwargs, get_metadata_variable_name_from_kwargs, @@ -46,6 +46,9 @@ from litellm.proxy._types import ( LiteLLM_UserTable, 
UserAPIKeyAuth, ) +from litellm.repositories.organization_repository import OrganizationRepository +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.user_repository import UserRepository from litellm.types.integrations.prometheus import * from litellm.types.integrations.prometheus import ( _sanitize_prometheus_label_name, @@ -3278,12 +3281,12 @@ class PrometheusLogger(CustomLogger): page_size: int, page: int ) -> Tuple[List[LiteLLM_UserTable], Optional[int]]: skip = (page - 1) * page_size - users = await prisma_client.db.litellm_usertable.find_many( + users = await UserRepository(prisma_client).table.find_many( skip=skip, take=page_size, order={"created_at": "desc"}, ) - total_count = await prisma_client.db.litellm_usertable.count() + total_count = await UserRepository(prisma_client).table.count() return users, total_count await self._initialize_budget_metrics( @@ -3306,13 +3309,13 @@ class PrometheusLogger(CustomLogger): async def fetch_orgs(page_size: int, page: int) -> Tuple[list, Optional[int]]: skip = (page - 1) * page_size - orgs = await prisma_client.db.litellm_organizationtable.find_many( + orgs = await OrganizationRepository(prisma_client).table.find_many( skip=skip, take=page_size, order={"created_at": "desc"}, include={"litellm_budget_table": True}, ) - total_count = await prisma_client.db.litellm_organizationtable.count() + total_count = await OrganizationRepository(prisma_client).table.count() return orgs, total_count await self._initialize_budget_metrics( @@ -3380,14 +3383,14 @@ class PrometheusLogger(CustomLogger): try: # Get total user count - total_users = await prisma_client.db.litellm_usertable.count() + total_users = await UserRepository(prisma_client).table.count() self.litellm_total_users_metric.set(total_users) verbose_logger.debug( f"Prometheus: set litellm_total_users to {total_users}" ) # Get total team count - total_teams = await prisma_client.db.litellm_teamtable.count() + total_teams = await 
TeamRepository(prisma_client).table.count() self.litellm_teams_count_metric.set(total_teams) verbose_logger.debug( f"Prometheus: set litellm_teams_count to {total_teams}" diff --git a/litellm/llms/litellm_proxy/skills/handler.py b/litellm/llms/litellm_proxy/skills/handler.py index 37aabd8b47..7b259c1ed6 100644 --- a/litellm/llms/litellm_proxy/skills/handler.py +++ b/litellm/llms/litellm_proxy/skills/handler.py @@ -17,6 +17,7 @@ from litellm.proxy.common_utils.resource_ownership import ( is_proxy_admin, user_can_access_resource_owner, ) +from litellm.repositories.table_repositories import SkillsRepository # Skills are looked up on every chat completion that has skills enabled # (`SkillsInjectionHook` calls ``fetch_skill_from_db``). 60s LRU/TTL cache @@ -107,7 +108,7 @@ class LiteLLMSkillsHandler: f"LiteLLMSkillsHandler: Creating skill {skill_id} with title={data.display_title}" ) - new_skill = await prisma_client.db.litellm_skillstable.create(data=skill_data) + new_skill = await SkillsRepository(prisma_client).table.create(data=skill_data) return _prisma_skill_to_litellm(new_skill) @staticmethod @@ -133,7 +134,7 @@ class LiteLLMSkillsHandler: return [] find_many_kwargs["where"] = {"created_by": {"in": owner_scopes}} - skills = await prisma_client.db.litellm_skillstable.find_many( + skills = await SkillsRepository(prisma_client).table.find_many( **find_many_kwargs ) return [_prisma_skill_to_litellm(s) for s in skills] @@ -150,7 +151,7 @@ class LiteLLMSkillsHandler: return cached prisma_client = await LiteLLMSkillsHandler._get_prisma_client() - skill = await prisma_client.db.litellm_skillstable.find_unique( + skill = await SkillsRepository(prisma_client).table.find_unique( where={"skill_id": skill_id} ) _SKILL_CACHE.set_cache( @@ -189,7 +190,7 @@ class LiteLLMSkillsHandler: ): raise ValueError(f"Skill not found: {skill_id}") - await prisma_client.db.litellm_skillstable.delete(where={"skill_id": skill_id}) + await 
SkillsRepository(prisma_client).table.delete(where={"skill_id": skill_id}) _SKILL_CACHE.set_cache(skill_id, _NEGATIVE_SKILL_SENTINEL) return {"id": skill_id, "type": "skill_deleted"} diff --git a/litellm/models/__init__.py b/litellm/models/__init__.py new file mode 100644 index 0000000000..7e2d2c0ed9 --- /dev/null +++ b/litellm/models/__init__.py @@ -0,0 +1,66 @@ +""" +Domain models for LiteLLM backend. +""" + +from litellm.models.access_group import LiteLLM_AccessGroupTable +from litellm.models.budget import ( + LiteLLM_BudgetTable, + LiteLLM_BudgetTableFull, + LiteLLM_TeamMemberTable, +) +from litellm.models.config import LiteLLM_Config +from litellm.models.credentials import ( + CreateCredentialItem, + CredentialBase, + CredentialItem, +) +from litellm.models.end_user import LiteLLM_EndUserTable +from litellm.models.managed_files import ( + LiteLLM_ManagedFileTable, + LiteLLM_ManagedObjectTable, + LiteLLM_ManagedVectorStoresTable, + LiteLLM_ManagedVectorStoreTable, +) +from litellm.models.mcp_server import LiteLLM_MCPServerTable +from litellm.models.model import LiteLLM_ProxyModelTable +from litellm.models.object_permission import LiteLLM_ObjectPermissionTable +from litellm.models.organization import LiteLLM_OrganizationTable +from litellm.models.organization_membership import LiteLLM_OrganizationMembershipTable +from litellm.models.project import LiteLLM_ProjectTable +from litellm.models.skills import LiteLLM_SkillsTable +from litellm.models.spend_logs import LiteLLM_ErrorLogs, LiteLLM_SpendLogs +from litellm.models.tag import LiteLLM_TagTable +from litellm.models.team import LiteLLM_TeamTable +from litellm.models.team_membership import LiteLLM_TeamMembership +from litellm.models.user import LiteLLM_UserTable +from litellm.models.verification_token import LiteLLM_VerificationToken + +__all__ = [ + "LiteLLM_AccessGroupTable", + "LiteLLM_BudgetTable", + "LiteLLM_BudgetTableFull", + "LiteLLM_TeamMemberTable", + "LiteLLM_Config", + "CredentialBase", + 
"CredentialItem", + "CreateCredentialItem", + "LiteLLM_EndUserTable", + "LiteLLM_ManagedFileTable", + "LiteLLM_ManagedObjectTable", + "LiteLLM_ManagedVectorStoreTable", + "LiteLLM_ManagedVectorStoresTable", + "LiteLLM_MCPServerTable", + "LiteLLM_ProxyModelTable", + "LiteLLM_ObjectPermissionTable", + "LiteLLM_OrganizationTable", + "LiteLLM_OrganizationMembershipTable", + "LiteLLM_ProjectTable", + "LiteLLM_SkillsTable", + "LiteLLM_ErrorLogs", + "LiteLLM_SpendLogs", + "LiteLLM_TagTable", + "LiteLLM_TeamTable", + "LiteLLM_TeamMembership", + "LiteLLM_UserTable", + "LiteLLM_VerificationToken", +] diff --git a/litellm/models/access_group.py b/litellm/models/access_group.py new file mode 100644 index 0000000000..682e779e53 --- /dev/null +++ b/litellm/models/access_group.py @@ -0,0 +1,26 @@ +""" +Access group table model. + +Canonical definition for ``litellm_accessgrouptable``. Re-exported from +``litellm.proxy._types`` for backwards compatibility. +""" + +from datetime import datetime +from typing import List, Optional + +from litellm.types.llms.base import LiteLLMPydanticObjectBase + + +class LiteLLM_AccessGroupTable(LiteLLMPydanticObjectBase): + access_group_id: str + access_group_name: str + description: Optional[str] = None + access_model_names: List[str] = [] + access_mcp_server_ids: List[str] = [] + access_agent_ids: List[str] = [] + assigned_team_ids: List[str] = [] + assigned_key_ids: List[str] = [] + created_at: Optional[datetime] = None + created_by: Optional[str] = None + updated_at: Optional[datetime] = None + updated_by: Optional[str] = None diff --git a/litellm/models/base.py b/litellm/models/base.py new file mode 100644 index 0000000000..01981297bd --- /dev/null +++ b/litellm/models/base.py @@ -0,0 +1,38 @@ +""" +Base model class for domain models. 
+""" + +from datetime import datetime +from typing import Any, Dict, Optional + +from pydantic import BaseModel, ConfigDict + + +class DomainModel(BaseModel): + """Base class for all domain models.""" + + model_config = ConfigDict( + from_attributes=True, + protected_namespaces=(), + extra="ignore", + ) + + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + + @classmethod + def from_db_record(cls, record: Any) -> "DomainModel": + """Create a domain model from a database record.""" + if record is None: + raise ValueError("Cannot create domain model from None record") + if isinstance(record, dict): + return cls(**record) + if hasattr(record, "model_dump") and callable(record.model_dump): + return cls(**record.model_dump()) + if hasattr(record, "dict") and callable(record.dict): + return cls(**record.dict()) + return cls(**dict(record)) + + def to_db_dict(self, exclude_unset: bool = False) -> Dict[str, Any]: + """Convert domain model to a dictionary for database operations.""" + return self.model_dump(exclude_none=True, exclude_unset=exclude_unset) diff --git a/litellm/models/budget.py b/litellm/models/budget.py new file mode 100644 index 0000000000..e7dfe2f8fb --- /dev/null +++ b/litellm/models/budget.py @@ -0,0 +1,56 @@ +""" +Budget table model. + +Canonical definition for ``litellm_budgettable``. Re-exported from +``litellm.proxy._types`` for backwards compatibility. +""" + +from datetime import datetime +from typing import List, Optional + +from pydantic import ConfigDict + +from litellm.types.llms.base import LiteLLMPydanticObjectBase + + +class LiteLLM_BudgetTable(LiteLLMPydanticObjectBase): + """Represents user-controllable params for a LiteLLM_BudgetTable record. + + Budget-write paths use `model_fields.keys()` on this class as an allowlist + for user input. Keep server-managed fields (e.g. `budget_reset_at`) on + `LiteLLM_BudgetTableFull` so they aren't user-settable. 
+ """ + + budget_id: Optional[str] = None + soft_budget: Optional[float] = None + max_budget: Optional[float] = None + max_parallel_requests: Optional[int] = None + tpm_limit: Optional[int] = None + rpm_limit: Optional[int] = None + model_max_budget: Optional[dict] = None + budget_duration: Optional[str] = None + allowed_models: Optional[List[str]] = ( + None # per-member model scope; empty = inherit team models + ) + + model_config = ConfigDict(protected_namespaces=()) + + +class LiteLLM_BudgetTableFull(LiteLLM_BudgetTable): + """LiteLLM_BudgetTable + server-managed fields returned on API responses.""" + + budget_reset_at: Optional[datetime] = None + created_at: datetime + + +class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable): + """ + Used to track spend of a user_id within a team_id + """ + + spend: Optional[float] = None + user_id: Optional[str] = None + team_id: Optional[str] = None + budget_id: Optional[str] = None + + model_config = ConfigDict(protected_namespaces=()) diff --git a/litellm/models/config.py b/litellm/models/config.py new file mode 100644 index 0000000000..99b5c5692f --- /dev/null +++ b/litellm/models/config.py @@ -0,0 +1,15 @@ +""" +Config table model. + +Canonical definition for ``litellm_config``. Re-exported from +``litellm.proxy._types`` for backwards compatibility. +""" + +from typing import Dict + +from litellm.types.llms.base import LiteLLMPydanticObjectBase + + +class LiteLLM_Config(LiteLLMPydanticObjectBase): + param_name: str + param_value: Dict diff --git a/litellm/models/credentials.py b/litellm/models/credentials.py new file mode 100644 index 0000000000..b74ea055d2 --- /dev/null +++ b/litellm/models/credentials.py @@ -0,0 +1,31 @@ +""" +Credential table models. + +These are the canonical credential types for the proxy. They live in the model +layer; ``litellm.types.utils`` re-exports them for backwards compatibility. 
+""" + +from typing import Optional + +from pydantic import BaseModel, model_validator + + +class CredentialBase(BaseModel): + credential_name: str + credential_info: dict + + +class CredentialItem(CredentialBase): + credential_values: dict + + +class CreateCredentialItem(CredentialBase): + credential_values: Optional[dict] = None + model_id: Optional[str] = None + + @model_validator(mode="before") + @classmethod + def check_credential_params(cls, values): + if not values.get("credential_values") and not values.get("model_id"): + raise ValueError("Either credential_values or model_id must be set") + return values diff --git a/litellm/models/end_user.py b/litellm/models/end_user.py new file mode 100644 index 0000000000..15fd03ec2c --- /dev/null +++ b/litellm/models/end_user.py @@ -0,0 +1,35 @@ +""" +End-user table model. + +Canonical definition for ``litellm_endusertable``. Re-exported from +``litellm.proxy._types`` for backwards compatibility. +""" + +from typing import Literal, Optional + +from pydantic import ConfigDict, model_validator + +from litellm.models.budget import LiteLLM_BudgetTable +from litellm.models.object_permission import LiteLLM_ObjectPermissionTable +from litellm.types.llms.base import LiteLLMPydanticObjectBase + + +class LiteLLM_EndUserTable(LiteLLMPydanticObjectBase): + user_id: str + blocked: bool + alias: Optional[str] = None + spend: float = 0.0 + allowed_model_region: Optional[Literal["eu", "us"]] = None + default_model: Optional[str] = None + litellm_budget_table: Optional[LiteLLM_BudgetTable] = None + object_permission_id: Optional[str] = None + object_permission: Optional[LiteLLM_ObjectPermissionTable] = None + + @model_validator(mode="before") + @classmethod + def set_model_info(cls, values): + if values.get("spend") is None: + values.update({"spend": 0.0}) + return values + + model_config = ConfigDict(protected_namespaces=()) diff --git a/litellm/models/managed_files.py b/litellm/models/managed_files.py new file mode 100644 index 
0000000000..2415476886 --- /dev/null +++ b/litellm/models/managed_files.py @@ -0,0 +1,62 @@ +""" +Managed file, object, and vector store table models. + +Canonical definitions for the ``litellm_managed*`` tables. Re-exported from +``litellm.proxy._types`` for backwards compatibility. +""" + +from datetime import datetime +from typing import Any, Dict, List, Literal, Optional, Union + +from litellm.types.llms.base import LiteLLMPydanticObjectBase +from litellm.types.llms.openai import OpenAIFileObject, ResponsesAPIResponse +from litellm.types.utils import LiteLLMBatch, LiteLLMFineTuningJob + + +class LiteLLM_ManagedFileTable(LiteLLMPydanticObjectBase): + unified_file_id: str + file_object: Optional[OpenAIFileObject] = None + model_mappings: Dict[str, str] + flat_model_file_ids: List[str] + created_by: Optional[str] = None + team_id: Optional[str] = None + updated_by: Optional[str] = None + storage_backend: Optional[str] = None + storage_url: Optional[str] = None + + +class LiteLLM_ManagedObjectTable(LiteLLMPydanticObjectBase): + unified_object_id: str + model_object_id: str + file_purpose: Literal["batch", "fine-tune", "response", "container"] + file_object: Union[LiteLLMBatch, LiteLLMFineTuningJob, ResponsesAPIResponse] + created_by: Optional[str] = None + team_id: Optional[str] = None + + +class LiteLLM_ManagedVectorStoreTable(LiteLLMPydanticObjectBase): + """Table for managing vector stores with target_model_names support.""" + + unified_resource_id: str + resource_object: Optional[Any] = None + model_mappings: Dict[str, str] + flat_model_resource_ids: List[str] + created_by: Optional[str] = None + team_id: Optional[str] = None + updated_by: Optional[str] = None + storage_backend: Optional[str] = None + storage_url: Optional[str] = None + + +class LiteLLM_ManagedVectorStoresTable(LiteLLMPydanticObjectBase): + vector_store_id: str + custom_llm_provider: str + vector_store_name: Optional[str] + vector_store_description: Optional[str] + vector_store_metadata: 
Optional[Dict[str, Any]] + created_at: Optional[datetime] + updated_at: Optional[datetime] + litellm_credential_name: Optional[str] + litellm_params: Optional[Dict[str, Any]] + team_id: Optional[str] + user_id: Optional[str] diff --git a/litellm/models/mcp_server.py b/litellm/models/mcp_server.py new file mode 100644 index 0000000000..3d03eff6df --- /dev/null +++ b/litellm/models/mcp_server.py @@ -0,0 +1,103 @@ +""" +MCP server table model. + +Canonical definition for ``litellm_mcpservertable``. Re-exported from +``litellm.proxy._types`` for backwards compatibility. +""" + +import enum +from datetime import datetime +from typing import Dict, List, Literal, Optional + +from pydantic import Field + +from litellm.types.llms.base import LiteLLMPydanticObjectBase +from litellm.types.mcp import MCPAuthType, MCPCredentials, MCPTransportType +from litellm.types.mcp_server.mcp_server_manager import MCPInfo + + +class MCPEnvVarScope(str, enum.Enum): + """Scope for an MCP server environment variable. + + - ``global``: value is provided by the admin and used for all users. + - ``user``: each user must provide their own value via the per-user + env-var endpoint. The admin-supplied ``value`` is treated as a + placeholder/hint and is not used at request time. + """ + + global_ = "global" + user = "user" + + +class MCPEnvVar(LiteLLMPydanticObjectBase): + """One environment variable for an MCP server. + + Variables can be interpolated into ``static_headers`` using ``${NAME}`` + syntax. ``scope=global`` values are stored on the server. ``scope=user`` + values are stored per-user in ``LiteLLM_MCPUserEnvVars`` and supplied by + each user. 
+ """ + + name: str + value: str = "" + scope: MCPEnvVarScope = MCPEnvVarScope.global_ + description: Optional[str] = None + + +class LiteLLM_MCPServerTable(LiteLLMPydanticObjectBase): + """Represents a LiteLLM_MCPServerTable record""" + + server_id: str + server_name: Optional[str] = None + alias: Optional[str] = None + description: Optional[str] = None + url: Optional[str] = None + spec_path: Optional[str] = None + transport: MCPTransportType + auth_type: Optional[MCPAuthType] = None + credentials: Optional[MCPCredentials] = None + instructions: Optional[str] = None + created_at: Optional[datetime] = None + created_by: Optional[str] = None + updated_at: Optional[datetime] = None + updated_by: Optional[str] = None + teams: List[Dict[str, Optional[str]]] = Field(default_factory=list) + mcp_access_groups: List[str] = Field(default_factory=list) + allowed_tools: List[str] = Field(default_factory=list) + tool_name_to_display_name: Optional[Dict[str, str]] = None + tool_name_to_description: Optional[Dict[str, str]] = None + extra_headers: List[str] = Field(default_factory=list) + mcp_info: Optional[MCPInfo] = None + static_headers: Optional[Dict[str, str]] = None + env_vars: Optional[List[MCPEnvVar]] = None + status: Optional[Literal["healthy", "unhealthy", "unknown"]] = Field( + default="unknown", + description="Health status: 'healthy', 'unhealthy', 'unknown'", + ) + last_health_check: Optional[datetime] = None + health_check_error: Optional[str] = None + command: Optional[str] = None + args: List[str] = Field(default_factory=list) + env: Dict[str, str] = Field(default_factory=dict) + authorization_url: Optional[str] = None + token_url: Optional[str] = None + registration_url: Optional[str] = None + oauth2_flow: Optional[Literal["client_credentials", "authorization_code"]] = None + allow_all_keys: bool = False + available_on_public_internet: bool = True + delegate_auth_to_upstream: bool = False + oauth_passthrough: bool = False + is_byok: bool = False + 
byok_description: List[str] = Field(default_factory=list) + byok_api_key_help_url: Optional[str] = None + has_user_credential: Optional[bool] = None + source_url: Optional[str] = None + timeout: Optional[float] = None + approval_status: Optional[str] = Field( + default="active", + description="Approval status: 'pending_review', 'active', 'rejected'", + ) + submitted_by: Optional[str] = None + submitted_at: Optional[datetime] = None + reviewed_at: Optional[datetime] = None + review_notes: Optional[str] = None diff --git a/litellm/models/model.py b/litellm/models/model.py new file mode 100644 index 0000000000..7657e4d30f --- /dev/null +++ b/litellm/models/model.py @@ -0,0 +1,59 @@ +""" +Proxy model table model. + +Canonical definition for ``litellm_proxymodeltable``. Re-exported from +``litellm.proxy._types`` for backwards compatibility. +""" + +import json +from datetime import datetime +from typing import Optional + +from pydantic import ConfigDict, model_validator + +from litellm.types.llms.base import LiteLLMPydanticObjectBase + + +class LiteLLM_ProxyModelTable(LiteLLMPydanticObjectBase): + model_id: str + model_name: str + litellm_params: dict + model_info: Optional[dict] = None + blocked: bool = False + created_at: Optional[datetime] = None + created_by: Optional[str] = None + updated_at: Optional[datetime] = None + updated_by: Optional[str] = None + + model_config = ConfigDict(protected_namespaces=()) + + @model_validator(mode="before") + @classmethod + def check_potential_json_str(cls, values): + if isinstance(values.get("litellm_params"), str): + try: + values["litellm_params"] = json.loads(values["litellm_params"]) + except json.JSONDecodeError: + pass + if isinstance(values.get("model_info"), str): + try: + values["model_info"] = json.loads(values["model_info"]) + except json.JSONDecodeError: + pass + return values + + @property + def is_blocked(self) -> bool: + return self.blocked + + @property + def team_id(self) -> Optional[str]: + if self.model_info: + 
return self.model_info.get("team_id") + return None + + @property + def team_public_model_name(self) -> Optional[str]: + if self.model_info: + return self.model_info.get("team_public_model_name") + return None diff --git a/litellm/models/object_permission.py b/litellm/models/object_permission.py new file mode 100644 index 0000000000..6c0d100046 --- /dev/null +++ b/litellm/models/object_permission.py @@ -0,0 +1,26 @@ +""" +Object permission table model. + +Canonical definition for ``litellm_objectpermissiontable``. Re-exported from +``litellm.proxy._types`` for backwards compatibility. +""" + +from typing import Dict, List, Optional + +from litellm.types.llms.base import LiteLLMPydanticObjectBase + + +class LiteLLM_ObjectPermissionTable(LiteLLMPydanticObjectBase): + """Represents a LiteLLM_ObjectPermissionTable record""" + + object_permission_id: str + mcp_servers: Optional[List[str]] = [] + mcp_access_groups: Optional[List[str]] = [] + mcp_tool_permissions: Optional[Dict[str, List[str]]] = None + vector_stores: Optional[List[str]] = [] + agents: Optional[List[str]] = [] + agent_access_groups: Optional[List[str]] = [] + models: Optional[List[str]] = [] + mcp_toolsets: Optional[List[str]] = None + blocked_tools: Optional[List[str]] = [] + search_tools: Optional[List[str]] = [] diff --git a/litellm/models/organization.py b/litellm/models/organization.py new file mode 100644 index 0000000000..8b2d95c3e0 --- /dev/null +++ b/litellm/models/organization.py @@ -0,0 +1,31 @@ +""" +Organization table model. + +Canonical definition for ``litellm_organizationtable``. Re-exported from +``litellm.proxy._types`` for backwards compatibility. 
+""" + +from typing import List, Optional + +from litellm.models.budget import LiteLLM_BudgetTable +from litellm.models.object_permission import LiteLLM_ObjectPermissionTable +from litellm.models.user import LiteLLM_UserTable +from litellm.types.llms.base import LiteLLMPydanticObjectBase + + +class LiteLLM_OrganizationTable(LiteLLMPydanticObjectBase): + """Represents user-controllable params for a LiteLLM_OrganizationTable record""" + + organization_id: Optional[str] = None + organization_alias: Optional[str] = None + budget_id: str + spend: float = 0.0 + metadata: Optional[dict] = None + models: List[str] = [] + model_spend: Optional[dict] = {} + created_by: str + updated_by: str + users: Optional[List[LiteLLM_UserTable]] = None + litellm_budget_table: Optional[LiteLLM_BudgetTable] = None + object_permission: Optional[LiteLLM_ObjectPermissionTable] = None + object_permission_id: Optional[str] = None diff --git a/litellm/models/organization_membership.py b/litellm/models/organization_membership.py new file mode 100644 index 0000000000..9957c0c21a --- /dev/null +++ b/litellm/models/organization_membership.py @@ -0,0 +1,40 @@ +""" +Organization membership table model. + +Canonical definition for ``litellm_organizationmembership``. Re-exported from +``litellm.proxy._types`` for backwards compatibility. 
+""" + +from datetime import datetime +from typing import Any, Optional + +from pydantic import ConfigDict, model_validator + +from litellm.models.budget import LiteLLM_BudgetTable +from litellm.types.llms.base import LiteLLMPydanticObjectBase + + +class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase): + """Tracks which organizations a user belongs to and their spend within it.""" + + user_id: str + organization_id: str + user_role: Optional[str] = None + spend: float = 0.0 + budget_id: Optional[str] = None + created_at: datetime + updated_at: datetime + user: Optional[Any] = None + litellm_budget_table: Optional[LiteLLM_BudgetTable] = None + user_email: Optional[str] = None + + model_config = ConfigDict(protected_namespaces=()) + + @model_validator(mode="after") + def populate_user_email(self) -> "LiteLLM_OrganizationMembershipTable": + if self.user_email is None and self.user is not None: + if isinstance(self.user, dict): + self.user_email = self.user.get("user_email") + else: + self.user_email = getattr(self.user, "user_email", None) + return self diff --git a/litellm/models/project.py b/litellm/models/project.py new file mode 100644 index 0000000000..083c7ee3cc --- /dev/null +++ b/litellm/models/project.py @@ -0,0 +1,41 @@ +""" +Project table model. + +Canonical definition for ``litellm_projecttable``. Re-exported from +``litellm.proxy._types`` for backwards compatibility. 
+""" + +from datetime import datetime +from typing import List, Optional + +from litellm.models.budget import LiteLLM_BudgetTable +from litellm.models.object_permission import LiteLLM_ObjectPermissionTable +from litellm.types.llms.base import LiteLLMPydanticObjectBase + + +class LiteLLM_ProjectTable(LiteLLMPydanticObjectBase): + """Database model representation for project""" + + project_id: str + project_alias: Optional[str] = None + description: Optional[str] = None + team_id: Optional[str] = None + budget_id: Optional[str] = None + metadata: Optional[dict] = None + models: List[str] = [] + spend: float = 0.0 + model_spend: Optional[dict] = None + model_rpm_limit: Optional[dict] = None + model_tpm_limit: Optional[dict] = None + blocked: bool = False + object_permission_id: Optional[str] = None + created_by: Optional[str] = None + updated_by: Optional[str] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + litellm_budget_table: Optional[LiteLLM_BudgetTable] = None + object_permission: Optional[LiteLLM_ObjectPermissionTable] = None + + @property + def is_blocked(self) -> bool: + return self.blocked diff --git a/litellm/models/skills.py b/litellm/models/skills.py new file mode 100644 index 0000000000..62091c0ca0 --- /dev/null +++ b/litellm/models/skills.py @@ -0,0 +1,30 @@ +""" +Skills table model. + +Canonical definition for ``litellm_skillstable``. Re-exported from +``litellm.proxy._types`` for backwards compatibility. 
+""" + +from datetime import datetime +from typing import Any, Dict, Optional + +from litellm.types.llms.base import LiteLLMPydanticObjectBase + + +class LiteLLM_SkillsTable(LiteLLMPydanticObjectBase): + """Represents a LiteLLM_SkillsTable record""" + + skill_id: str + display_title: Optional[str] = None + description: Optional[str] = None + instructions: Optional[str] = None + source: str = "custom" + latest_version: Optional[str] = None + file_content: Optional[bytes] = None + file_name: Optional[str] = None + file_type: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None + created_at: Optional[datetime] = None + created_by: Optional[str] = None + updated_at: Optional[datetime] = None + updated_by: Optional[str] = None diff --git a/litellm/models/spend_logs.py b/litellm/models/spend_logs.py new file mode 100644 index 0000000000..96bd328c3c --- /dev/null +++ b/litellm/models/spend_logs.py @@ -0,0 +1,50 @@ +""" +Spend and error log table models. + +Canonical definitions for ``litellm_spendlogs`` and ``litellm_errorlogs``. +Re-exported from ``litellm.proxy._types`` for backwards compatibility. 
+"""
+
+from datetime import datetime
+from typing import Optional, Union
+
+from pydantic import Field, Json
+
+from litellm._uuid import uuid
+from litellm.types.llms.base import LiteLLMPydanticObjectBase
+
+
+class LiteLLM_SpendLogs(LiteLLMPydanticObjectBase):
+    request_id: str
+    api_key: str
+    model: Optional[str] = ""
+    api_base: Optional[str] = ""
+    call_type: str
+    spend: Optional[float] = 0.0
+    total_tokens: Optional[int] = 0
+    prompt_tokens: Optional[int] = 0
+    completion_tokens: Optional[int] = 0
+    startTime: Union[str, datetime, None]
+    endTime: Union[str, datetime, None]
+    user: Optional[str] = ""
+    metadata: Optional[Json] = {}
+    cache_hit: Optional[str] = "False"
+    cache_key: Optional[str] = None
+    request_tags: Optional[Json] = None
+    requester_ip_address: Optional[str] = None
+    messages: Optional[Union[str, list, dict]]
+    response: Optional[Union[str, list, dict]]
+
+
+class LiteLLM_ErrorLogs(LiteLLMPydanticObjectBase):
+    request_id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()))  # new id per record, not one shared import-time value
+    api_base: Optional[str] = ""
+    model_group: Optional[str] = ""
+    litellm_model_name: Optional[str] = ""
+    model_id: Optional[str] = ""
+    request_kwargs: Optional[dict] = {}
+    exception_type: Optional[str] = ""
+    status_code: Optional[str] = ""
+    exception_string: Optional[str] = ""
+    startTime: Union[str, datetime, None]
+    endTime: Union[str, datetime, None]
diff --git a/litellm/models/tag.py b/litellm/models/tag.py
new file mode 100644
index 0000000000..02d8f58916
--- /dev/null
+++ b/litellm/models/tag.py
@@ -0,0 +1,36 @@
+"""
+Tag table model.
+
+Canonical definition for ``litellm_tagtable``. Re-exported from
+``litellm.proxy._types`` for backwards compatibility.
+""" + +from datetime import datetime +from typing import List, Optional + +from pydantic import model_validator + +from litellm.models.budget import LiteLLM_BudgetTable +from litellm.types.llms.base import LiteLLMPydanticObjectBase + + +class LiteLLM_TagTable(LiteLLMPydanticObjectBase): + tag_name: str + description: Optional[str] = None + models: List[str] = [] + model_info: Optional[dict] = None + spend: float = 0.0 + budget_id: Optional[str] = None + litellm_budget_table: Optional[LiteLLM_BudgetTable] = None + created_at: Optional[datetime] = None + created_by: Optional[str] = None + updated_at: Optional[datetime] = None + + @model_validator(mode="before") + @classmethod + def set_model_info(cls, values): + if values.get("spend") is None: + values.update({"spend": 0.0}) + if values.get("models") is None: + values.update({"models": []}) + return values diff --git a/litellm/models/team.py b/litellm/models/team.py new file mode 100644 index 0000000000..aa0798955f --- /dev/null +++ b/litellm/models/team.py @@ -0,0 +1,154 @@ +""" +Team table models. + +Canonical definitions for ``litellm_teamtable`` (plus the shared Member and +budget-window value types and the team-model alias table). Re-exported from +``litellm.proxy._types`` for backwards compatibility. +""" + +import json +from datetime import datetime +from typing import List, Literal, Optional, Union + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from litellm.models.object_permission import LiteLLM_ObjectPermissionTable +from litellm.types.llms.base import LiteLLMPydanticObjectBase + + +class MemberBase(LiteLLMPydanticObjectBase): + user_id: Optional[str] = Field( + default=None, + description="The unique ID of the user to add. Either user_id or user_email must be provided", + ) + user_email: Optional[str] = Field( + default=None, + description="The email address of the user to add. 
Either user_id or user_email must be provided", + ) + + @model_validator(mode="before") + @classmethod + def check_user_info(cls, values): + if not isinstance(values, dict): + raise ValueError("input needs to be a dictionary") + if values.get("user_id") is None and values.get("user_email") is None: + raise ValueError("Either user id or user email must be provided") + return values + + +class Member(MemberBase): + role: Literal["admin", "user"] = Field( + description="The role of the user within the team. 'admin' users can manage team settings and members, 'user' is a regular team member" + ) + + +class BudgetLimitEntry(LiteLLMPydanticObjectBase): + """A single budget window with its own limit and independent reset schedule.""" + + budget_duration: str + max_budget: float + reset_at: Optional[datetime] = None + + +class LiteLLM_ModelTable(LiteLLMPydanticObjectBase): + id: Optional[int] = None + model_aliases: Optional[Union[str, dict]] = None + created_by: str + updated_by: str + team: Optional["LiteLLM_TeamTable"] = None + + model_config = ConfigDict(protected_namespaces=()) + + +class TeamBase(LiteLLMPydanticObjectBase): + team_alias: Optional[str] = None + team_id: Optional[str] = None + organization_id: Optional[str] = None + admins: list = [] + members: list = [] + members_with_roles: List[Member] = [] + team_member_permissions: Optional[List[str]] = None + metadata: Optional[dict] = None + tpm_limit: Optional[int] = None + rpm_limit: Optional[int] = None + max_budget: Optional[float] = None + soft_budget: Optional[float] = None + budget_duration: Optional[str] = None + budget_limits: Optional[List[BudgetLimitEntry]] = None + models: list = [] + blocked: bool = False + router_settings: Optional[dict] = None + access_group_ids: Optional[List[str]] = None + default_team_member_models: Optional[List[str]] = None + + +class LiteLLM_TeamTable(TeamBase): + team_id: str # type: ignore + spend: Optional[float] = None + max_parallel_requests: Optional[int] = None + 
budget_duration: Optional[str] = None + budget_reset_at: Optional[datetime] = None + model_id: Optional[int] = None + model_spend: Optional[dict] = {} + model_max_budget: Optional[dict] = {} + policies: Optional[List[str]] = None + allow_team_guardrail_config: Optional[bool] = False + litellm_model_table: Optional[LiteLLM_ModelTable] = None + object_permission: Optional[LiteLLM_ObjectPermissionTable] = None + object_permission_id: Optional[str] = None + updated_at: Optional[datetime] = None + created_at: Optional[datetime] = None + + model_config = ConfigDict(protected_namespaces=()) + + @model_validator(mode="before") + @classmethod + def set_model_info(cls, values): + dict_fields = [ + "metadata", + "aliases", + "config", + "permissions", + "model_max_budget", + "model_aliases", + "router_settings", + "budget_limits", + ] + + if isinstance(values, BaseModel): + values = values.model_dump() + + if ( + isinstance(values.get("members_with_roles"), dict) + and not values["members_with_roles"] + ): + values["members_with_roles"] = [] + + for field in dict_fields: + value = values.get(field) + if value is not None and isinstance(value, str): + try: + values[field] = json.loads(value) + except json.JSONDecodeError: + raise ValueError(f"Field {field} should be a valid dictionary") + + return values + + +class LiteLLM_TeamTableCachedObj(LiteLLM_TeamTable): + last_refreshed_at: Optional[float] = None + + +class LiteLLM_DeletedTeamTable(LiteLLM_TeamTable): + """Audit record for deleted teams; mirrors the team plus deletion metadata.""" + + id: Optional[str] = None + deleted_at: Optional[datetime] = None + deleted_by: Optional[str] = None + deleted_by_api_key: Optional[str] = None + litellm_changed_by: Optional[str] = None + + model_config = ConfigDict(protected_namespaces=()) + + +LiteLLM_ModelTable.model_rebuild() diff --git a/litellm/models/team_membership.py b/litellm/models/team_membership.py new file mode 100644 index 0000000000..d0a1308ce7 --- /dev/null +++ 
b/litellm/models/team_membership.py @@ -0,0 +1,32 @@ +""" +Team membership table model. + +Canonical definition for ``litellm_teammembership``. Re-exported from +``litellm.proxy._types`` for backwards compatibility. +""" + +from typing import Optional, Union + +from litellm.models.budget import LiteLLM_BudgetTable, LiteLLM_BudgetTableFull +from litellm.types.llms.base import LiteLLMPydanticObjectBase + + +class LiteLLM_TeamMembership(LiteLLMPydanticObjectBase): + user_id: str + team_id: str + budget_id: Optional[str] = None + spend: Optional[float] = 0.0 + total_spend: Optional[float] = 0.0 + litellm_budget_table: Optional[ + Union[LiteLLM_BudgetTableFull, LiteLLM_BudgetTable] + ] = None + + def safe_get_team_member_rpm_limit(self) -> Optional[int]: + if self.litellm_budget_table is not None: + return self.litellm_budget_table.rpm_limit + return None + + def safe_get_team_member_tpm_limit(self) -> Optional[int]: + if self.litellm_budget_table is not None: + return self.litellm_budget_table.tpm_limit + return None diff --git a/litellm/models/user.py b/litellm/models/user.py new file mode 100644 index 0000000000..cd7e9db4ae --- /dev/null +++ b/litellm/models/user.py @@ -0,0 +1,70 @@ +""" +User table model. + +Canonical definition for ``litellm_usertable``. Re-exported from +``litellm.proxy._types`` for backwards compatibility. 
+""" + +from datetime import datetime +from typing import Dict, List, Optional + +from pydantic import ConfigDict, Field, model_validator + +from litellm.models.object_permission import LiteLLM_ObjectPermissionTable +from litellm.models.organization_membership import ( + LiteLLM_OrganizationMembershipTable, +) +from litellm.types.llms.base import LiteLLMPydanticObjectBase + + +class LiteLLM_UserTable(LiteLLMPydanticObjectBase): + user_id: str + user_alias: Optional[str] = None + team_id: Optional[str] = None + sso_user_id: Optional[str] = None + organization_id: Optional[str] = None + object_permission_id: Optional[str] = None + password: Optional[str] = Field(default=None, exclude=True) + teams: List[str] = [] + user_role: Optional[str] = None + max_budget: Optional[float] = None + spend: float = 0.0 + user_email: Optional[str] = None + models: list = [] + metadata: Optional[dict] = None + max_parallel_requests: Optional[int] = None + tpm_limit: Optional[int] = None + rpm_limit: Optional[int] = None + budget_duration: Optional[str] = None + budget_reset_at: Optional[datetime] = None + allowed_cache_controls: List[str] = [] + policies: List[str] = [] + model_spend: Optional[Dict] = {} + model_max_budget: Optional[Dict] = {} + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + organization_memberships: Optional[List[LiteLLM_OrganizationMembershipTable]] = None + object_permission: Optional[LiteLLM_ObjectPermissionTable] = None + + model_config = ConfigDict(protected_namespaces=()) + + @model_validator(mode="before") + @classmethod + def set_model_info(cls, values): + if values.get("spend") is None: + values.update({"spend": 0.0}) + if values.get("models") is None: + values.update({"models": []}) + if values.get("teams") is None: + values.update({"teams": []}) + return values + + def is_over_budget(self) -> bool: + if self.max_budget is None: + return False + return self.spend >= self.max_budget + + def has_model_access(self, model_name: 
str) -> bool: + if not self.models: + return True + return model_name in self.models diff --git a/litellm/models/verification_token.py b/litellm/models/verification_token.py new file mode 100644 index 0000000000..8bddd1c161 --- /dev/null +++ b/litellm/models/verification_token.py @@ -0,0 +1,74 @@ +""" +Verification token table model. + +Canonical definition for ``litellm_verificationtoken``. Re-exported from +``litellm.proxy._types`` for backwards compatibility. +""" + +from datetime import datetime +from typing import Dict, List, Optional, Union + +from pydantic import ConfigDict + +from litellm.models.object_permission import LiteLLM_ObjectPermissionTable +from litellm.types.llms.base import LiteLLMPydanticObjectBase + + +class LiteLLM_VerificationToken(LiteLLMPydanticObjectBase): + token: Optional[str] = None + key_name: Optional[str] = None + key_alias: Optional[str] = None + spend: float = 0.0 + max_budget: Optional[float] = None + expires: Optional[Union[str, datetime]] = None + models: List = [] + aliases: Dict = {} + config: Dict = {} + user_id: Optional[str] = None + team_id: Optional[str] = None + agent_id: Optional[str] = None + project_id: Optional[str] = None + max_parallel_requests: Optional[int] = None + metadata: Dict = {} + tpm_limit: Optional[int] = None + rpm_limit: Optional[int] = None + budget_duration: Optional[str] = None + budget_reset_at: Optional[datetime] = None + allowed_cache_controls: Optional[list] = [] + allowed_routes: Optional[list] = [] + permissions: Dict = {} + model_spend: Dict = {} + model_max_budget: Dict = {} + soft_budget_cooldown: bool = False + blocked: Optional[bool] = None + litellm_budget_table: Optional[dict] = None + budget_id: Optional[str] = None + org_id: Optional[str] = None # org id for a given key + created_at: Optional[datetime] = None + created_by: Optional[str] = None + updated_at: Optional[datetime] = None + updated_by: Optional[str] = None + last_active: Optional[datetime] = None + object_permission_id: 
Optional[str] = None + object_permission: Optional[LiteLLM_ObjectPermissionTable] = None + access_group_ids: Optional[List[str]] = None + rotation_count: Optional[int] = 0 + auto_rotate: Optional[bool] = False + rotation_interval: Optional[str] = None + last_rotation_at: Optional[datetime] = None + key_rotation_at: Optional[datetime] = None + router_settings: Optional[dict] = None + budget_limits: Optional[List[dict]] = None + model_config = ConfigDict(protected_namespaces=()) + + +class LiteLLM_DeletedVerificationToken(LiteLLM_VerificationToken): + """Audit record for deleted keys; mirrors the token plus deletion metadata.""" + + id: Optional[str] = None + deleted_at: Optional[datetime] = None + deleted_by: Optional[str] = None + deleted_by_api_key: Optional[str] = None + litellm_changed_by: Optional[str] = None + + model_config = ConfigDict(protected_namespaces=()) diff --git a/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py b/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py index 863e6acd41..dcf7660d00 100644 --- a/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py +++ b/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py @@ -14,8 +14,12 @@ from litellm.proxy._types import ( SpecialHeaders, UserAPIKeyAuth, ) -from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.ip_address_utils import IPAddressUtils +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.repositories.table_repositories import ( + AgentsRepository, + MCPServerRepository, +) def _parse_mcp_server_names_from_path( @@ -1445,7 +1449,7 @@ class MCPRequestHandler: return None if object_permission_id is None: - agent_row = await prisma_client.db.litellm_agentstable.find_unique( + agent_row = await AgentsRepository(prisma_client).table.find_unique( where={"agent_id": agent_id}, ) object_permission_id = ( @@ -1600,7 +1604,7 @@ class MCPRequestHandler: server_ids: 
Set[str] = set() if access_groups and prisma_client is not None: try: - mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many( + mcp_servers = await MCPServerRepository(prisma_client).table.find_many( where={"mcp_access_groups": {"hasSome": access_groups}} ) for server in mcp_servers: diff --git a/litellm/proxy/_experimental/mcp_server/db.py b/litellm/proxy/_experimental/mcp_server/db.py index 0ba0181200..c52752940c 100644 --- a/litellm/proxy/_experimental/mcp_server/db.py +++ b/litellm/proxy/_experimental/mcp_server/db.py @@ -8,6 +8,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast from litellm._logging import verbose_proxy_logger from litellm._uuid import uuid from litellm.constants import MCP_PER_USER_TOKEN_EXPIRY_BUFFER_SECONDS +from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.proxy._types import ( LiteLLM_MCPServerTable, LiteLLM_ObjectPermissionTable, @@ -25,8 +26,16 @@ from litellm.proxy.common_utils.encrypt_decrypt_utils import ( decrypt_value_helper, encrypt_value_helper, ) -from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.proxy.utils import PrismaClient +from litellm.repositories.object_permission_repository import ObjectPermissionRepository +from litellm.repositories.table_repositories import ( + MCPServerRepository, + MCPUserCredentialsRepository, +) +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.verification_token_repository import ( + VerificationTokenRepository, +) from litellm.types.llms.custom_http import httpxSpecialProvider from litellm.types.mcp import MCPCredentials @@ -354,7 +363,7 @@ async def get_all_mcp_servers( where: Dict[str, Any] = {} if approval_status is not None: where["approval_status"] = approval_status - mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many( + mcp_servers = await MCPServerRepository(prisma_client).table.find_many( where=where if where 
else {} ) @@ -380,7 +389,9 @@ async def get_mcp_server( """ Returns the matching mcp server from the db iff exists """ - mcp_server = await prisma_client.db.litellm_mcpservertable.find_unique( + mcp_server: Optional[LiteLLM_MCPServerTable] = await MCPServerRepository( + prisma_client + ).table.find_unique( where={ "server_id": server_id, } @@ -398,12 +409,12 @@ async def get_mcp_servers( """ Returns the matching mcp servers from the db with the server_ids """ - _mcp_servers: List[LiteLLM_MCPServerTable] = ( - await prisma_client.db.litellm_mcpservertable.find_many( - where={ - "server_id": {"in": server_ids}, - } - ) + _mcp_servers: List[LiteLLM_MCPServerTable] = await MCPServerRepository( + prisma_client + ).table.find_many( + where={ + "server_id": {"in": server_ids}, + } ) final_mcp_servers: List[LiteLLM_MCPServerTable] = [] for _mcp_server in _mcp_servers: @@ -420,15 +431,15 @@ async def get_mcp_servers_by_verificationtoken( """ Returns the mcp servers from the db for the verification token """ - verification_token_record: LiteLLM_TeamTable = ( - await prisma_client.db.litellm_verificationtoken.find_unique( - where={ - "token": token, - }, - include={ - "object_permission": True, - }, - ) + verification_token_record: LiteLLM_TeamTable = await VerificationTokenRepository( + prisma_client + ).table.find_unique( + where={ + "token": token, + }, + include={ + "object_permission": True, + }, ) mcp_servers: Optional[List[str]] = [] @@ -446,15 +457,15 @@ async def get_mcp_servers_by_team( """ Returns the mcp servers from the db for the team id """ - team_record: LiteLLM_TeamTable = ( - await prisma_client.db.litellm_teamtable.find_unique( - where={ - "team_id": team_id, - }, - include={ - "object_permission": True, - }, - ) + team_record: LiteLLM_TeamTable = await TeamRepository( + prisma_client + ).table.find_unique( + where={ + "team_id": team_id, + }, + include={ + "object_permission": True, + }, ) mcp_servers: Optional[List[str]] = [] @@ -505,16 +516,16 @@ async 
def get_objectpermissions_for_mcp_server( """ Get all the object permissions records and the associated team and verficiationtoken records that have access to the mcp server """ - object_permission_records = ( - await prisma_client.db.litellm_objectpermissiontable.find_many( - where={ - "mcp_servers": {"has": mcp_server_id}, - }, - include={ - "teams": True, - "verification_tokens": True, - }, - ) + object_permission_records = await ObjectPermissionRepository( + prisma_client + ).table.find_many( + where={ + "mcp_servers": {"has": mcp_server_id}, + }, + include={ + "teams": True, + "verification_tokens": True, + }, ) return object_permission_records @@ -526,7 +537,7 @@ async def get_virtualkeys_for_mcp_server( """ Get all the virtual keys that have access to the mcp server """ - virtual_keys = await prisma_client.db.litellm_verificationtoken.find_many( + virtual_keys = await VerificationTokenRepository(prisma_client).table.find_many( where={ "mcp_servers": {"has": server_id}, }, @@ -564,7 +575,7 @@ async def delete_mcp_server( Returns the deleted mcp server record if it exists, otherwise None """ - deleted_server = await prisma_client.db.litellm_mcpservertable.delete( + deleted_server = await MCPServerRepository(prisma_client).table.delete( where={ "server_id": server_id, }, @@ -600,7 +611,7 @@ async def create_mcp_server( data_dict["created_by"] = touched_by data_dict["updated_by"] = touched_by - new_mcp_server = await prisma_client.db.litellm_mcpservertable.create( + new_mcp_server = await MCPServerRepository(prisma_client).table.create( data=data_dict # type: ignore ) @@ -635,7 +646,7 @@ async def update_mcp_server( "credentials" in data_dict and data_dict["credentials"] is not None ) if data.auth_type or has_credentials: - existing = await prisma_client.db.litellm_mcpservertable.find_unique( + existing = await MCPServerRepository(prisma_client).table.find_unique( where={"server_id": data.server_id} ) @@ -678,7 +689,7 @@ async def update_mcp_server( # Add audit 
fields data_dict["updated_by"] = touched_by - updated_mcp_server = await prisma_client.db.litellm_mcpservertable.update( + updated_mcp_server = await MCPServerRepository(prisma_client).table.update( where={"server_id": data.server_id}, data=data_dict # type: ignore ) @@ -691,7 +702,7 @@ async def rotate_mcp_server_credentials_master_key( ): from litellm.litellm_core_utils.safe_json_dumps import safe_dumps - mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many() + mcp_servers = await MCPServerRepository(prisma_client).table.find_many() updated = 0 for mcp_server in mcp_servers: @@ -719,7 +730,7 @@ async def rotate_mcp_server_credentials_master_key( continue update_data["updated_by"] = touched_by - await prisma_client.db.litellm_mcpservertable.update( + await MCPServerRepository(prisma_client).table.update( where={"server_id": mcp_server.server_id}, data=update_data, ) @@ -781,7 +792,7 @@ async def rotate_mcp_user_credentials_master_key( under the new master key. Rows that are unreadable under both paths are logged and skipped so one corrupt row does not abort the rotation. 
""" - rows = await prisma_client.db.litellm_mcpusercredentials.find_many() + rows = await MCPUserCredentialsRepository(prisma_client).table.find_many() rotated = 0 skipped = 0 for row in rows: @@ -798,7 +809,7 @@ async def rotate_mcp_user_credentials_master_key( re_encrypted = encrypt_value_helper( plaintext, new_encryption_key=new_master_key ) - await prisma_client.db.litellm_mcpusercredentials.update( + await MCPUserCredentialsRepository(prisma_client).table.update( where={ "user_id_server_id": { "user_id": row.user_id, @@ -873,7 +884,7 @@ async def store_user_credential( """Store a user credential for a BYOK MCP server.""" encoded = encrypt_value_helper(credential) - await prisma_client.db.litellm_mcpusercredentials.upsert( + await MCPUserCredentialsRepository(prisma_client).table.upsert( where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}, data={ "create": { @@ -893,7 +904,7 @@ async def get_user_credential( ) -> Optional[str]: """Return credential for a user+server pair, or None.""" - row = await prisma_client.db.litellm_mcpusercredentials.find_unique( + row = await MCPUserCredentialsRepository(prisma_client).table.find_unique( where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}} ) if row is None: @@ -907,7 +918,7 @@ async def has_user_credential( server_id: str, ) -> bool: """Return True if the user has a stored credential for this server.""" - row = await prisma_client.db.litellm_mcpusercredentials.find_unique( + row = await MCPUserCredentialsRepository(prisma_client).table.find_unique( where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}} ) return row is not None @@ -919,7 +930,7 @@ async def delete_user_credential( server_id: str, ) -> None: """Delete the user's stored credential for a BYOK MCP server.""" - await prisma_client.db.litellm_mcpusercredentials.delete( + await MCPUserCredentialsRepository(prisma_client).table.delete( where={"user_id_server_id": {"user_id": user_id, "server_id": 
server_id}} ) @@ -966,7 +977,7 @@ async def store_user_oauth_credential( # Skip the guard when the caller knows the row is already an OAuth2 credential # (e.g. during token refresh), saving an extra DB round-trip. if not skip_byok_guard: - existing = await prisma_client.db.litellm_mcpusercredentials.find_unique( + existing = await MCPUserCredentialsRepository(prisma_client).table.find_unique( where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}} ) if ( @@ -984,7 +995,7 @@ async def store_user_oauth_credential( ) encoded = encrypt_value_helper(json.dumps(payload)) - await prisma_client.db.litellm_mcpusercredentials.upsert( + await MCPUserCredentialsRepository(prisma_client).table.upsert( where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}, data={ "create": { @@ -1025,7 +1036,7 @@ async def get_user_oauth_credential( ) -> Optional[Dict[str, Any]]: """Return the decoded OAuth2 payload dict for a user+server pair, or None.""" - row = await prisma_client.db.litellm_mcpusercredentials.find_unique( + row = await MCPUserCredentialsRepository(prisma_client).table.find_unique( where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}} ) if row is None: @@ -1039,7 +1050,7 @@ async def list_user_oauth_credentials( ) -> List[Dict[str, Any]]: """Return all OAuth2 credential payloads for a user, tagged with server_id.""" - rows = await prisma_client.db.litellm_mcpusercredentials.find_many( + rows = await MCPUserCredentialsRepository(prisma_client).table.find_many( where={"user_id": user_id} ) results: List[Dict[str, Any]] = [] @@ -1212,7 +1223,7 @@ async def approve_mcp_server( ) -> LiteLLM_MCPServerTable: """Set approval_status=active and record reviewed_at.""" now = datetime.now(timezone.utc) - updated = await prisma_client.db.litellm_mcpservertable.update( + updated = await MCPServerRepository(prisma_client).table.update( where={"server_id": server_id}, data={ "approval_status": MCPApprovalStatus.active, @@ -1240,7 
+1251,7 @@ async def reject_mcp_server( } if review_notes is not None: data["review_notes"] = review_notes - updated = await prisma_client.db.litellm_mcpservertable.update( + updated = await MCPServerRepository(prisma_client).table.update( where={"server_id": server_id}, data=data, ) @@ -1257,7 +1268,7 @@ async def get_mcp_submissions( along with a summary count breakdown by approval_status. Mirrors get_guardrail_submissions() from guardrail_endpoints.py. """ - rows = await prisma_client.db.litellm_mcpservertable.find_many( + rows = await MCPServerRepository(prisma_client).table.find_many( where={"submitted_at": {"not": None}}, order={"submitted_at": "desc"}, take=500, # safety cap; paginate if needed in a future iteration diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index 7048f5bf7c..85ac6b399f 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -42,8 +42,8 @@ from litellm.constants import ( MCP_TOOL_LISTING_TIMEOUT, ) from litellm.exceptions import BlockedPiiEntityError, GuardrailRaisedException -from litellm.litellm_core_utils.url_utils import SSRFError, async_safe_get from litellm.experimental_mcp_client.client import MCPClient, MCPSigV4Auth +from litellm.litellm_core_utils.url_utils import SSRFError, async_safe_get from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import ( MCPRequestHandler, @@ -85,6 +85,7 @@ from litellm.proxy._types import ( from litellm.proxy.auth.ip_address_utils import IPAddressUtils from litellm.proxy.common_utils.encrypt_decrypt_utils import decrypt_value_helper from litellm.proxy.utils import ProxyLogging +from litellm.repositories.table_repositories import MCPServerRepository from litellm.types.llms.custom_http import httpxSpecialProvider from litellm.types.mcp 
import MCPAuth, MCPStdioConfig from litellm.types.mcp_server.mcp_server_manager import ( @@ -3817,7 +3818,7 @@ class MCPServerManager: # Pending/rejected servers are excluded at the DB level so we never load them. from litellm.proxy._experimental.mcp_server.db import LiteLLM_MCPServerTable - raw_rows = await prisma_client.db.litellm_mcpservertable.find_many( + raw_rows = await MCPServerRepository(prisma_client).table.find_many( where={ "OR": [ {"approval_status": None}, diff --git a/litellm/proxy/_experimental/mcp_server/toolset_db.py b/litellm/proxy/_experimental/mcp_server/toolset_db.py index 08ac7dbd33..a996131653 100644 --- a/litellm/proxy/_experimental/mcp_server/toolset_db.py +++ b/litellm/proxy/_experimental/mcp_server/toolset_db.py @@ -4,6 +4,7 @@ from typing import List, Optional from litellm._logging import verbose_proxy_logger from litellm._uuid import uuid from litellm.proxy.utils import PrismaClient +from litellm.repositories.table_repositories import MCPToolsetRepository from litellm.types.mcp_server.mcp_toolset import ( MCPToolset, NewMCPToolsetRequest, @@ -30,7 +31,7 @@ async def create_mcp_toolset( data_dict["tools"] = json.dumps(data_dict.get("tools", [])) data_dict["created_by"] = touched_by data_dict["updated_by"] = touched_by - row = await prisma_client.db.litellm_mcptoolsettable.create(data=data_dict) + row = await MCPToolsetRepository(prisma_client).table.create(data=data_dict) return _toolset_from_row(row) @@ -38,7 +39,7 @@ async def get_mcp_toolset( prisma_client: PrismaClient, toolset_id: str, ) -> Optional[MCPToolset]: - row = await prisma_client.db.litellm_mcptoolsettable.find_unique( + row = await MCPToolsetRepository(prisma_client).table.find_unique( where={"toolset_id": toolset_id} ) if row is None: @@ -54,7 +55,7 @@ async def list_mcp_toolsets( where = {} if toolset_ids is not None: where = {"toolset_id": {"in": toolset_ids}} - rows = await prisma_client.db.litellm_mcptoolsettable.find_many(where=where) + rows = await 
MCPToolsetRepository(prisma_client).table.find_many(where=where) return [_toolset_from_row(r) for r in rows] except Exception as e: verbose_proxy_logger.warning( @@ -69,7 +70,7 @@ async def get_mcp_toolset_by_name( prisma_client: PrismaClient, toolset_name: str, ) -> Optional[MCPToolset]: - row = await prisma_client.db.litellm_mcptoolsettable.find_first( + row = await MCPToolsetRepository(prisma_client).table.find_first( where={"toolset_name": toolset_name} ) if row is None: @@ -87,7 +88,7 @@ async def update_mcp_toolset( data_dict["tools"] = json.dumps(data_dict["tools"]) data_dict["updated_by"] = touched_by try: - row = await prisma_client.db.litellm_mcptoolsettable.update( + row = await MCPToolsetRepository(prisma_client).table.update( where={"toolset_id": data.toolset_id}, data=data_dict, ) @@ -105,7 +106,7 @@ async def delete_mcp_toolset( toolset_id: str, ) -> Optional[MCPToolset]: try: - row = await prisma_client.db.litellm_mcptoolsettable.delete( + row = await MCPToolsetRepository(prisma_client).table.delete( where={"toolset_id": toolset_id} ) except Exception as e: diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 88be567e59..57a7d860ba 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -23,8 +23,6 @@ from litellm.litellm_core_utils.initialize_dynamic_callback_params import ( from litellm.types.integrations.slack_alerting import AlertType from litellm.types.llms.openai import ( AllMessageValues, - OpenAIFileObject, - ResponsesAPIResponse, ) from litellm.types.mcp import ( MCPAuthType, @@ -41,8 +39,6 @@ from litellm.types.utils import ( EmbeddingResponse, GenericBudgetConfigType, ImageResponse, - LiteLLMBatch, - LiteLLMFineTuningJob, LiteLLMPydanticObjectBase, ModelResponse, ProviderField, @@ -1014,12 +1010,7 @@ class LiteLLM_ObjectPermissionBase(LiteLLMPydanticObjectBase): search_tools: Optional[List[str]] = None -class BudgetLimitEntry(LiteLLMPydanticObjectBase): - """A single budget window with its own limit and 
independent reset schedule.""" - - budget_duration: str # e.g. "24h", "7d", "30d" - max_budget: float # max spend in USD for this window - reset_at: Optional[datetime] = None # populated at creation/reset time +from litellm.models.team import BudgetLimitEntry as BudgetLimitEntry # noqa: E402 class GenerateRequestBase(LiteLLMPydanticObjectBase): @@ -1217,40 +1208,10 @@ class KeyRequest(LiteLLMPydanticObjectBase): return values -class LiteLLM_ModelTable(LiteLLMPydanticObjectBase): - id: Optional[int] = None - model_aliases: Optional[Union[str, dict]] = None # json dump the dict - created_by: str - updated_by: str - team: Optional["LiteLLM_TeamTable"] = None - - model_config = ConfigDict(protected_namespaces=()) - - -class LiteLLM_ProxyModelTable(LiteLLMPydanticObjectBase): - model_id: str - model_name: str - litellm_params: dict - model_info: dict - created_at: Optional[datetime] = None - created_by: str - updated_at: Optional[datetime] = None - updated_by: str - - @model_validator(mode="before") - @classmethod - def check_potential_json_str(cls, values): - if isinstance(values.get("litellm_params"), str): - try: - values["litellm_params"] = json.loads(values["litellm_params"]) - except json.JSONDecodeError: - pass - if isinstance(values.get("model_info"), str): - try: - values["model_info"] = json.loads(values["model_info"]) - except json.JSONDecodeError: - pass - return values +from litellm.models.model import ( # noqa: E402 + LiteLLM_ProxyModelTable as LiteLLM_ProxyModelTable, +) +from litellm.models.team import LiteLLM_ModelTable as LiteLLM_ModelTable # noqa: E402 # MCP Types @@ -1265,32 +1226,12 @@ class MCPApprovalStatus(str, enum.Enum): rejected = "rejected" -class MCPEnvVarScope(str, enum.Enum): - """Scope for an MCP server environment variable. - - - ``global``: value is provided by the admin and used for all users. - - ``user``: each user must provide their own value via the per-user - env-var endpoint. 
The admin-supplied ``value`` is treated as a - placeholder/hint and is not used at request time. - """ - - global_ = "global" - user = "user" - - -class MCPEnvVar(LiteLLMPydanticObjectBase): - """One environment variable for an MCP server. - - Variables can be interpolated into ``static_headers`` using ``${NAME}`` - syntax. ``scope=global`` values are stored on the server. ``scope=user`` - values are stored per-user in ``LiteLLM_MCPUserEnvVars`` and supplied by - each user. - """ - - name: str - value: str = "" - scope: MCPEnvVarScope = MCPEnvVarScope.global_ - description: Optional[str] = None +from litellm.models.mcp_server import ( # noqa: E402 + MCPEnvVar as MCPEnvVar, +) +from litellm.models.mcp_server import ( # noqa: E402 + MCPEnvVarScope as MCPEnvVarScope, +) # MCP Proxy Request Types @@ -1443,66 +1384,9 @@ class UpdateMCPServerRequest(LiteLLMPydanticObjectBase): return values -class LiteLLM_MCPServerTable(LiteLLMPydanticObjectBase): - """Represents a LiteLLM_MCPServerTable record""" - - server_id: str - server_name: Optional[str] = None - alias: Optional[str] = None - description: Optional[str] = None - url: Optional[str] = None - spec_path: Optional[str] = None - transport: MCPTransportType - auth_type: Optional[MCPAuthType] = None - credentials: Optional[MCPCredentials] = None - instructions: Optional[str] = None - created_at: Optional[datetime] = None - created_by: Optional[str] = None - updated_at: Optional[datetime] = None - updated_by: Optional[str] = None - teams: List[Dict[str, Optional[str]]] = Field(default_factory=list) - mcp_access_groups: List[str] = Field(default_factory=list) - allowed_tools: List[str] = Field(default_factory=list) - tool_name_to_display_name: Optional[Dict[str, str]] = None - tool_name_to_description: Optional[Dict[str, str]] = None - extra_headers: List[str] = Field(default_factory=list) - mcp_info: Optional[MCPInfo] = None - static_headers: Optional[Dict[str, str]] = None - env_vars: Optional[List[MCPEnvVar]] = None - # 
Health check status - status: Optional[Literal["healthy", "unhealthy", "unknown"]] = Field( - default="unknown", - description="Health status: 'healthy', 'unhealthy', 'unknown'", - ) - last_health_check: Optional[datetime] = None - health_check_error: Optional[str] = None - # Stdio-specific fields - command: Optional[str] = None - args: List[str] = Field(default_factory=list) - env: Dict[str, str] = Field(default_factory=dict) - authorization_url: Optional[str] = None - token_url: Optional[str] = None - registration_url: Optional[str] = None - oauth2_flow: Optional[Literal["client_credentials", "authorization_code"]] = None - allow_all_keys: bool = False - available_on_public_internet: bool = True - delegate_auth_to_upstream: bool = False - oauth_passthrough: bool = False - is_byok: bool = False - byok_description: List[str] = Field(default_factory=list) - byok_api_key_help_url: Optional[str] = None - has_user_credential: Optional[bool] = None - source_url: Optional[str] = None - timeout: Optional[float] = None - # BYOM submission fields - approval_status: Optional[str] = Field( - default="active", - description="Approval status: 'pending_review', 'active', 'rejected'", - ) - submitted_by: Optional[str] = None - submitted_at: Optional[datetime] = None - reviewed_at: Optional[datetime] = None - review_notes: Optional[str] = None +from litellm.models.mcp_server import ( # noqa: E402 + LiteLLM_MCPServerTable as LiteLLM_MCPServerTable, +) class MakeMCPServersPublicRequest(LiteLLMPydanticObjectBase): @@ -1622,23 +1506,9 @@ class UpdateSkillRequest(LiteLLMPydanticObjectBase): metadata: Optional[Dict[str, Any]] = None -class LiteLLM_SkillsTable(LiteLLMPydanticObjectBase): - """Represents a LiteLLM_SkillsTable record""" - - skill_id: str - display_title: Optional[str] = None - description: Optional[str] = None - instructions: Optional[str] = None - source: str = "custom" - latest_version: Optional[str] = None - file_content: Optional[bytes] = None # Binary content of skill 
files (zip) - file_name: Optional[str] = None # Original filename - file_type: Optional[str] = None # MIME type - metadata: Optional[Dict[str, Any]] = None - created_at: Optional[datetime] = None - created_by: Optional[str] = None - updated_at: Optional[datetime] = None - updated_by: Optional[str] = None +from litellm.models.skills import ( # noqa: E402 + LiteLLM_SkillsTable as LiteLLM_SkillsTable, +) class ListSkillsRequest(LiteLLMPydanticObjectBase): @@ -1839,33 +1709,8 @@ class DeleteCustomerRequest(LiteLLMPydanticObjectBase): user_ids: List[str] -class MemberBase(LiteLLMPydanticObjectBase): - user_id: Optional[str] = Field( - default=None, - description="The unique ID of the user to add. Either user_id or user_email must be provided", - ) - user_email: Optional[str] = Field( - default=None, - description="The email address of the user to add. Either user_id or user_email must be provided", - ) - - @model_validator(mode="before") - @classmethod - def check_user_info(cls, values): - if not isinstance(values, dict): - raise ValueError("input needs to be a dictionary") - if values.get("user_id") is None and values.get("user_email") is None: - raise ValueError("Either user id or user email must be provided") - return values - - -class Member(MemberBase): - role: Literal[ - "admin", - "user", - ] = Field( - description="The role of the user within the team. 
'admin' users can manage team settings and members, 'user' is a regular team member" - ) +from litellm.models.team import Member as Member # noqa: E402 +from litellm.models.team import MemberBase as MemberBase # noqa: E402 class OrgMember(MemberBase): @@ -1876,33 +1721,7 @@ class OrgMember(MemberBase): ] -class TeamBase(LiteLLMPydanticObjectBase): - team_alias: Optional[str] = None - team_id: Optional[str] = None - organization_id: Optional[str] = None - admins: list = [] - members: list = [] - members_with_roles: List[Member] = [] - team_member_permissions: Optional[List[str]] = None - metadata: Optional[dict] = None - tpm_limit: Optional[int] = None - rpm_limit: Optional[int] = None - - # Budget fields - max_budget: Optional[float] = None - soft_budget: Optional[float] = None - budget_duration: Optional[str] = None - budget_limits: Optional[List[BudgetLimitEntry]] = ( - None # multiple concurrent budget windows - ) - - models: list = [] - blocked: bool = False - router_settings: Optional[dict] = None - access_group_ids: Optional[List[str]] = None - default_team_member_models: Optional[List[str]] = ( - None # default allowed_models seeded onto new team members - ) +from litellm.models.team import TeamBase as TeamBase # noqa: E402 class NewTeamRequest(TeamBase): @@ -2100,147 +1919,31 @@ class TeamCallbackMetadata(LiteLLMPydanticObjectBase): return values -class LiteLLM_ObjectPermissionTable(LiteLLMPydanticObjectBase): - """Represents a LiteLLM_ObjectPermissionTable record""" - - object_permission_id: str - mcp_servers: Optional[List[str]] = [] - mcp_access_groups: Optional[List[str]] = [] - mcp_tool_permissions: Optional[Dict[str, List[str]]] = None - """ - Mapping - server_id -> list of tools - - Enforces allowed tools for a specific key/team/organization - { - "1234567890": ["tool_name_1", "tool_name_2"] - } - """ - - vector_stores: Optional[List[str]] = [] - agents: Optional[List[str]] = [] - agent_access_groups: Optional[List[str]] = [] - mcp_toolsets: 
Optional[List[str]] = None - blocked_tools: Optional[List[str]] = [] - search_tools: Optional[List[str]] = [] - - -class LiteLLM_TeamTable(TeamBase): - team_id: str # type: ignore - spend: Optional[float] = None - max_parallel_requests: Optional[int] = None - budget_duration: Optional[str] = None - budget_reset_at: Optional[datetime] = None - model_id: Optional[int] = None - litellm_model_table: Optional[LiteLLM_ModelTable] = None - object_permission: Optional[LiteLLM_ObjectPermissionTable] = None - updated_at: Optional[datetime] = None - created_at: Optional[datetime] = None - - ######################################################### - # Object Permission - MCP, Vector Stores etc. - ######################################################### - object_permission_id: Optional[str] = None - - model_config = ConfigDict(protected_namespaces=()) - - @model_validator(mode="before") - @classmethod - def set_model_info(cls, values): - dict_fields = [ - "metadata", - "aliases", - "config", - "permissions", - "model_max_budget", - "model_aliases", - "router_settings", - "budget_limits", - ] - - if isinstance(values, BaseModel): - values = values.model_dump() - - if ( - isinstance(values.get("members_with_roles"), dict) - and not values["members_with_roles"] - ): - values["members_with_roles"] = [] - - for field in dict_fields: - value = values.get(field) - if value is not None and isinstance(value, str): - try: - values[field] = json.loads(value) - except json.JSONDecodeError: - raise ValueError(f"Field {field} should be a valid dictionary") - - return values - - -class LiteLLM_TeamTableCachedObj(LiteLLM_TeamTable): - last_refreshed_at: Optional[float] = None - - -class LiteLLM_DeletedTeamTable(LiteLLM_TeamTable): - """ - Recording of deleted teams for audit purposes. Mirrors LiteLLM_TeamTable - plus metadata captured at deletion time. 
- """ - - id: Optional[str] = None - deleted_at: Optional[datetime] = None - deleted_by: Optional[str] = None - deleted_by_api_key: Optional[str] = None - litellm_changed_by: Optional[str] = None - - model_config = ConfigDict(protected_namespaces=()) +from litellm.models.object_permission import ( # noqa: E402 + LiteLLM_ObjectPermissionTable as LiteLLM_ObjectPermissionTable, +) +from litellm.models.team import ( # noqa: E402 + LiteLLM_DeletedTeamTable as LiteLLM_DeletedTeamTable, +) +from litellm.models.team import LiteLLM_TeamTable as LiteLLM_TeamTable # noqa: E402 +from litellm.models.team import ( # noqa: E402 + LiteLLM_TeamTableCachedObj as LiteLLM_TeamTableCachedObj, +) class TeamRequest(LiteLLMPydanticObjectBase): teams: List[str] -class LiteLLM_BudgetTable(LiteLLMPydanticObjectBase): - """Represents user-controllable params for a LiteLLM_BudgetTable record. - - Budget-write paths use `model_fields.keys()` on this class as an allowlist - for user input. Keep server-managed fields (e.g. `budget_reset_at`) on - `LiteLLM_BudgetTableFull` so they aren't user-settable. 
- """ - - budget_id: Optional[str] = None - soft_budget: Optional[float] = None - max_budget: Optional[float] = None - max_parallel_requests: Optional[int] = None - tpm_limit: Optional[int] = None - rpm_limit: Optional[int] = None - model_max_budget: Optional[dict] = None - budget_duration: Optional[str] = None - allowed_models: Optional[List[str]] = ( - None # per-member model scope; empty = inherit team models - ) - - model_config = ConfigDict(protected_namespaces=()) - - -class LiteLLM_BudgetTableFull(LiteLLM_BudgetTable): - """LiteLLM_BudgetTable + server-managed fields returned on API responses.""" - - budget_reset_at: Optional[datetime] = None - created_at: datetime - - -class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable): - """ - Used to track spend of a user_id within a team_id - """ - - spend: Optional[float] = None - user_id: Optional[str] = None - team_id: Optional[str] = None - budget_id: Optional[str] = None - - model_config = ConfigDict(protected_namespaces=()) +from litellm.models.budget import ( # noqa: E402 + LiteLLM_BudgetTable as LiteLLM_BudgetTable, +) +from litellm.models.budget import ( # noqa: E402 + LiteLLM_BudgetTableFull as LiteLLM_BudgetTableFull, +) +from litellm.models.budget import ( # noqa: E402 + LiteLLM_TeamMemberTable as LiteLLM_TeamMemberTable, +) class NewOrganizationRequest(LiteLLM_BudgetTable): @@ -2637,66 +2340,12 @@ class ConfigYAML(LiteLLMPydanticObjectBase): model_config = ConfigDict(protected_namespaces=()) -class LiteLLM_VerificationToken(LiteLLMPydanticObjectBase): - token: Optional[str] = None - key_name: Optional[str] = None - key_alias: Optional[str] = None - spend: float = 0.0 - max_budget: Optional[float] = None - expires: Optional[Union[str, datetime]] = None - models: List = [] - aliases: Dict = {} - config: Dict = {} - user_id: Optional[str] = None - team_id: Optional[str] = None - agent_id: Optional[str] = None - project_id: Optional[str] = None - max_parallel_requests: Optional[int] = None - metadata: Dict = {} 
- tpm_limit: Optional[int] = None - rpm_limit: Optional[int] = None - budget_duration: Optional[str] = None - budget_reset_at: Optional[datetime] = None - allowed_cache_controls: Optional[list] = [] - allowed_routes: Optional[list] = [] - permissions: Dict = {} - model_spend: Dict = {} - model_max_budget: Dict = {} - soft_budget_cooldown: bool = False - blocked: Optional[bool] = None - litellm_budget_table: Optional[dict] = None - org_id: Optional[str] = None # org id for a given key - created_at: Optional[datetime] = None - created_by: Optional[str] = None - updated_at: Optional[datetime] = None - updated_by: Optional[str] = None - last_active: Optional[datetime] = None - object_permission_id: Optional[str] = None - object_permission: Optional[LiteLLM_ObjectPermissionTable] = None - access_group_ids: Optional[List[str]] = None - rotation_count: Optional[int] = 0 # Number of times key has been rotated - auto_rotate: Optional[bool] = False # Whether this key should be auto-rotated - rotation_interval: Optional[str] = None # How often to rotate (e.g., "30d", "90d") - last_rotation_at: Optional[datetime] = None # When this key was last rotated - key_rotation_at: Optional[datetime] = None # When this key should next be rotated - router_settings: Optional[dict] = None - budget_limits: Optional[List[dict]] = None # multiple concurrent budget windows - model_config = ConfigDict(protected_namespaces=()) - - -class LiteLLM_DeletedVerificationToken(LiteLLM_VerificationToken): - """ - Recording of deleted keys for audit purposes. Mirrors LiteLLM_VerificationToken - plus metadata captured at deletion time. 
- """ - - id: Optional[str] = None - deleted_at: Optional[datetime] = None - deleted_by: Optional[str] = None - deleted_by_api_key: Optional[str] = None - litellm_changed_by: Optional[str] = None - - model_config = ConfigDict(protected_namespaces=()) +from litellm.models.verification_token import ( # noqa: E402 + LiteLLM_DeletedVerificationToken as LiteLLM_DeletedVerificationToken, +) +from litellm.models.verification_token import ( # noqa: E402 + LiteLLM_VerificationToken as LiteLLM_VerificationToken, +) class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken): @@ -2935,39 +2584,10 @@ class UserInfoV2Response(LiteLLMPydanticObjectBase): teams: List[str] = [] # Just team IDs, not full team objects -class LiteLLM_Config(LiteLLMPydanticObjectBase): - param_name: str - param_value: Dict - - -class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase): - """ - This is the table that track what organizations a user belongs to and users spend within the organization - """ - - user_id: str - organization_id: str - user_role: Optional[str] = None - spend: float = 0.0 - budget_id: Optional[str] = None - created_at: datetime - updated_at: datetime - user: Optional[Any] = ( - None # You might want to replace 'Any' with a more specific type if available - ) - litellm_budget_table: Optional[LiteLLM_BudgetTable] = None - user_email: Optional[str] = None - - model_config = ConfigDict(protected_namespaces=()) - - @model_validator(mode="after") - def populate_user_email(self) -> "LiteLLM_OrganizationMembershipTable": - if self.user_email is None and self.user is not None: - if isinstance(self.user, dict): - self.user_email = self.user.get("user_email") - else: - self.user_email = getattr(self.user, "user_email", None) - return self +from litellm.models.config import LiteLLM_Config as LiteLLM_Config # noqa: E402 +from litellm.models.organization_membership import ( # noqa: E402 + LiteLLM_OrganizationMembershipTable as LiteLLM_OrganizationMembershipTable, +) class 
LiteLLM_OrganizationTableUpdate(LiteLLM_BudgetTable): @@ -2997,61 +2617,10 @@ class LiteLLM_OrganizationTableUpdate(LiteLLM_BudgetTable): return values -class LiteLLM_UserTable(LiteLLMPydanticObjectBase): - user_id: str - max_budget: Optional[float] = None - spend: float = 0.0 - model_max_budget: Optional[Dict] = {} - model_spend: Optional[Dict] = {} - user_email: Optional[str] = None - user_alias: Optional[str] = None - models: list = [] - tpm_limit: Optional[int] = None - rpm_limit: Optional[int] = None - user_role: Optional[str] = None - organization_memberships: Optional[List[LiteLLM_OrganizationMembershipTable]] = None - teams: List[str] = [] - sso_user_id: Optional[str] = None - budget_duration: Optional[str] = None - budget_reset_at: Optional[datetime] = None - metadata: Optional[dict] = None - created_at: Optional[datetime] = None - updated_at: Optional[datetime] = None - object_permission: Optional[LiteLLM_ObjectPermissionTable] = None - - @model_validator(mode="before") - @classmethod - def set_model_info(cls, values): - if values.get("spend") is None: - values.update({"spend": 0.0}) - if values.get("models") is None: - values.update({"models": []}) - if values.get("teams") is None: - values.update({"teams": []}) - return values - - model_config = ConfigDict(protected_namespaces=()) - - -class LiteLLM_OrganizationTable(LiteLLMPydanticObjectBase): - """Represents user-controllable params for a LiteLLM_OrganizationTable record""" - - organization_id: Optional[str] = None - organization_alias: Optional[str] = None - budget_id: str - spend: float = 0.0 - metadata: Optional[dict] = None - models: List[str] - created_by: str - updated_by: str - users: Optional[List[LiteLLM_UserTable]] = None - litellm_budget_table: Optional[LiteLLM_BudgetTable] = None - - ######################################################### - # Object Permission - MCP, Vector Stores etc. 
- ######################################################### - object_permission: Optional[LiteLLM_ObjectPermissionTable] = None - object_permission_id: Optional[str] = None +from litellm.models.organization import ( # noqa: E402 + LiteLLM_OrganizationTable as LiteLLM_OrganizationTable, +) +from litellm.models.user import LiteLLM_UserTable as LiteLLM_UserTable # noqa: E402 class LiteLLM_OrganizationTableWithMembers(LiteLLM_OrganizationTable): @@ -3160,28 +2729,9 @@ class DeleteProjectRequest(LiteLLMPydanticObjectBase): project_ids: List[str] -class LiteLLM_ProjectTable(LiteLLMPydanticObjectBase): - """Database model representation for project""" - - project_id: str - project_alias: Optional[str] = None - description: Optional[str] = None - team_id: Optional[str] = None - budget_id: Optional[str] = None - metadata: Optional[dict] = None - models: List[str] = [] - spend: float = 0.0 - model_spend: Optional[dict] = None - model_rpm_limit: Optional[dict] = None - model_tpm_limit: Optional[dict] = None - blocked: bool = False - object_permission_id: Optional[str] = None - created_by: str - updated_by: str - created_at: Optional[datetime] = None - updated_at: Optional[datetime] = None - litellm_budget_table: Optional[LiteLLM_BudgetTable] = None - object_permission: Optional[LiteLLM_ObjectPermissionTable] = None +from litellm.models.project import ( # noqa: E402 + LiteLLM_ProjectTable as LiteLLM_ProjectTable, +) class NewProjectResponse(LiteLLM_ProjectTable): @@ -3207,101 +2757,19 @@ class LiteLLM_UserTableWithKeyCount(LiteLLM_UserTable): key_count: int = 0 -class LiteLLM_EndUserTable(LiteLLMPydanticObjectBase): - user_id: str - blocked: bool - alias: Optional[str] = None - spend: float = 0.0 - allowed_model_region: Optional[AllowedModelRegion] = None - default_model: Optional[str] = None - litellm_budget_table: Optional[LiteLLM_BudgetTable] = None - object_permission_id: Optional[str] = None - object_permission: Optional[LiteLLM_ObjectPermissionTable] = None - - 
@model_validator(mode="before") - @classmethod - def set_model_info(cls, values): - if values.get("spend") is None: - values.update({"spend": 0.0}) - return values - - model_config = ConfigDict(protected_namespaces=()) - - -class LiteLLM_TagTable(LiteLLMPydanticObjectBase): - tag_name: str - description: Optional[str] = None - models: List[str] = [] - model_info: Optional[dict] = None - spend: float = 0.0 - budget_id: Optional[str] = None - litellm_budget_table: Optional[LiteLLM_BudgetTable] = None - created_at: Optional[datetime] = None - created_by: Optional[str] = None - updated_at: Optional[datetime] = None - - @model_validator(mode="before") - @classmethod - def set_model_info(cls, values): - if values.get("spend") is None: - values.update({"spend": 0.0}) - if values.get("models") is None: - values.update({"models": []}) - return values - - model_config = ConfigDict(protected_namespaces=()) - - -class LiteLLM_AccessGroupTable(LiteLLMPydanticObjectBase): - access_group_id: str - access_group_name: str - description: Optional[str] = None - access_model_names: List[str] = [] - access_mcp_server_ids: List[str] = [] - access_agent_ids: List[str] = [] - assigned_team_ids: List[str] = [] - assigned_key_ids: List[str] = [] - created_at: Optional[datetime] = None - created_by: Optional[str] = None - updated_at: Optional[datetime] = None - updated_by: Optional[str] = None - - -class LiteLLM_SpendLogs(LiteLLMPydanticObjectBase): - request_id: str - api_key: str - model: Optional[str] = "" - api_base: Optional[str] = "" - call_type: str - spend: Optional[float] = 0.0 - total_tokens: Optional[int] = 0 - prompt_tokens: Optional[int] = 0 - completion_tokens: Optional[int] = 0 - startTime: Union[str, datetime, None] - endTime: Union[str, datetime, None] - user: Optional[str] = "" - metadata: Optional[Json] = {} - cache_hit: Optional[str] = "False" - cache_key: Optional[str] = None - request_tags: Optional[Json] = None - requester_ip_address: Optional[str] = None - messages: 
Optional[Union[str, list, dict]] - response: Optional[Union[str, list, dict]] - - -class LiteLLM_ErrorLogs(LiteLLMPydanticObjectBase): - request_id: Optional[str] = str(uuid.uuid4()) - api_base: Optional[str] = "" - model_group: Optional[str] = "" - litellm_model_name: Optional[str] = "" - model_id: Optional[str] = "" - request_kwargs: Optional[dict] = {} - exception_type: Optional[str] = "" - status_code: Optional[str] = "" - exception_string: Optional[str] = "" - startTime: Union[str, datetime, None] - endTime: Union[str, datetime, None] - +from litellm.models.access_group import ( # noqa: E402 + LiteLLM_AccessGroupTable as LiteLLM_AccessGroupTable, +) +from litellm.models.end_user import ( # noqa: E402 + LiteLLM_EndUserTable as LiteLLM_EndUserTable, +) +from litellm.models.spend_logs import ( # noqa: E402 + LiteLLM_ErrorLogs as LiteLLM_ErrorLogs, +) +from litellm.models.spend_logs import ( # noqa: E402 + LiteLLM_SpendLogs as LiteLLM_SpendLogs, +) +from litellm.models.tag import LiteLLM_TagTable as LiteLLM_TagTable # noqa: E402 AUDIT_ACTIONS = Literal[ "created", "updated", "deleted", "blocked", "unblocked", "rotated" @@ -3982,29 +3450,9 @@ class CreatePassThroughEndpoint(LiteLLMPydanticObjectBase): headers: dict -class LiteLLM_TeamMembership(LiteLLMPydanticObjectBase): - user_id: str - team_id: str - budget_id: Optional[str] = None - spend: Optional[float] = 0.0 - total_spend: Optional[float] = 0.0 - # Union so Pydantic picks Full when data has server-managed fields - # (/team/info) and Base when callers/tests construct with only - # user-settable fields. 
- litellm_budget_table: Optional[ - Union[LiteLLM_BudgetTableFull, LiteLLM_BudgetTable] - ] = None - - def safe_get_team_member_rpm_limit(self) -> Optional[int]: - if self.litellm_budget_table is not None: - return self.litellm_budget_table.rpm_limit - return None - - def safe_get_team_member_tpm_limit(self) -> Optional[int]: - if self.litellm_budget_table is not None: - return self.litellm_budget_table.tpm_limit - return None - +from litellm.models.team_membership import ( # noqa: E402 + LiteLLM_TeamMembership as LiteLLM_TeamMembership, +) #### Organization / Team Member Requests #### @@ -4922,39 +4370,18 @@ class ToolDiscoveryQueueItem(TypedDict, total=False): user_agent: Optional[str] # HTTP User-Agent of the caller -class LiteLLM_ManagedFileTable(LiteLLMPydanticObjectBase): - unified_file_id: str - file_object: Optional[OpenAIFileObject] = None - model_mappings: Dict[str, str] - flat_model_file_ids: List[str] - created_by: Optional[str] = None - team_id: Optional[str] = None - updated_by: Optional[str] = None - storage_backend: Optional[str] = None - storage_url: Optional[str] = None - - -class LiteLLM_ManagedObjectTable(LiteLLMPydanticObjectBase): - unified_object_id: str - model_object_id: str - file_purpose: Literal["batch", "fine-tune", "response", "container"] - file_object: Union[LiteLLMBatch, LiteLLMFineTuningJob, ResponsesAPIResponse] - created_by: Optional[str] = None - team_id: Optional[str] = None - - -class LiteLLM_ManagedVectorStoreTable(LiteLLMPydanticObjectBase): - """Table for managing vector stores with target_model_names support.""" - - unified_resource_id: str - resource_object: Optional[Any] = None # VectorStoreCreateResponse - model_mappings: Dict[str, str] - flat_model_resource_ids: List[str] - created_by: Optional[str] = None - team_id: Optional[str] = None - updated_by: Optional[str] = None - storage_backend: Optional[str] = None - storage_url: Optional[str] = None +from litellm.models.managed_files import ( # noqa: E402 + 
LiteLLM_ManagedFileTable as LiteLLM_ManagedFileTable, +) +from litellm.models.managed_files import ( # noqa: E402 + LiteLLM_ManagedObjectTable as LiteLLM_ManagedObjectTable, +) +from litellm.models.managed_files import ( # noqa: E402 + LiteLLM_ManagedVectorStoresTable as LiteLLM_ManagedVectorStoresTable, +) +from litellm.models.managed_files import ( # noqa: E402 + LiteLLM_ManagedVectorStoreTable as LiteLLM_ManagedVectorStoreTable, +) class EnterpriseLicenseData(TypedDict, total=False): @@ -4965,20 +4392,6 @@ class EnterpriseLicenseData(TypedDict, total=False): max_teams: int -class LiteLLM_ManagedVectorStoresTable(LiteLLMPydanticObjectBase): - vector_store_id: str - custom_llm_provider: str - vector_store_name: Optional[str] - vector_store_description: Optional[str] - vector_store_metadata: Optional[Dict[str, Any]] - created_at: Optional[datetime] - updated_at: Optional[datetime] - litellm_credential_name: Optional[str] - litellm_params: Optional[Dict[str, Any]] - team_id: Optional[str] - user_id: Optional[str] - - class ResponseLiteLLM_ManagedVectorStore(TypedDict, total=False): vector_store: LiteLLM_ManagedVectorStoresTable diff --git a/litellm/proxy/agent_endpoints/agent_registry.py b/litellm/proxy/agent_endpoints/agent_registry.py index 13a2dd9f04..11fd01e236 100644 --- a/litellm/proxy/agent_endpoints/agent_registry.py +++ b/litellm/proxy/agent_endpoints/agent_registry.py @@ -9,6 +9,7 @@ from litellm.proxy.management_helpers.object_permission_utils import ( handle_update_object_permission_common, ) from litellm.proxy.utils import PrismaClient +from litellm.repositories.table_repositories import AgentsRepository from litellm.types.agents import AgentConfig, AgentResponse, PatchAgentRequest @@ -174,7 +175,7 @@ class AgentRegistry: create_data[rate_field] = _val # Create agent in DB - created_agent = await prisma_client.db.litellm_agentstable.create( + created_agent = await AgentsRepository(prisma_client).table.create( data=create_data, 
include={"object_permission": True}, ) @@ -200,7 +201,7 @@ class AgentRegistry: Delete an agent from the database """ try: - deleted_agent = await prisma_client.db.litellm_agentstable.delete( + deleted_agent = await AgentsRepository(prisma_client).table.delete( where={"agent_id": agent_id} ) return dict(deleted_agent) @@ -229,7 +230,7 @@ class AgentRegistry: The patched agent """ try: - existing_agent = await prisma_client.db.litellm_agentstable.find_unique( + existing_agent = await AgentsRepository(prisma_client).table.find_unique( where={"agent_id": agent_id} ) if existing_agent is not None: @@ -282,7 +283,7 @@ class AgentRegistry: if object_permission_id is not None: update_data["object_permission_id"] = object_permission_id # Patch agent in DB - patched_agent = await prisma_client.db.litellm_agentstable.update( + patched_agent = await AgentsRepository(prisma_client).table.update( where={"agent_id": agent_id}, data={ **update_data, @@ -368,9 +369,9 @@ class AgentRegistry: update_data[rate_field] = _val if agent.get("object_permission") is not None: - existing_agent = await prisma_client.db.litellm_agentstable.find_unique( - where={"agent_id": agent_id} - ) + existing_agent = await AgentsRepository( + prisma_client + ).table.find_unique(where={"agent_id": agent_id}) existing_object_permission_id = ( existing_agent.object_permission_id if existing_agent is not None @@ -386,7 +387,7 @@ class AgentRegistry: update_data["object_permission_id"] = object_permission_id # Update agent in DB - updated_agent = await prisma_client.db.litellm_agentstable.update( + updated_agent = await AgentsRepository(prisma_client).table.update( where={"agent_id": agent_id}, data=update_data, include={"object_permission": True}, @@ -414,7 +415,7 @@ class AgentRegistry: Get all agents from the database """ try: - agents_from_db = await prisma_client.db.litellm_agentstable.find_many( + agents_from_db = await AgentsRepository(prisma_client).table.find_many( order={"created_at": "desc"}, 
include={"object_permission": True}, ) diff --git a/litellm/proxy/agent_endpoints/auth/agent_permission_handler.py b/litellm/proxy/agent_endpoints/auth/agent_permission_handler.py index 42cf31e1e2..2577615fc8 100644 --- a/litellm/proxy/agent_endpoints/auth/agent_permission_handler.py +++ b/litellm/proxy/agent_endpoints/auth/agent_permission_handler.py @@ -9,11 +9,12 @@ from typing import List, Optional, Set from litellm._logging import verbose_logger from litellm.proxy._types import ( + UI_TEAM_ID, LiteLLM_ObjectPermissionTable, LiteLLM_TeamTable, - UI_TEAM_ID, UserAPIKeyAuth, ) +from litellm.repositories.table_repositories import AgentsRepository class AgentRequestHandler: @@ -298,7 +299,7 @@ class AgentRequestHandler: agent_ids: Set[str] = set() if access_groups and prisma_client is not None: try: - agents = await prisma_client.db.litellm_agentstable.find_many( + agents = await AgentsRepository(prisma_client).table.find_many( where={"agent_access_groups": {"hasSome": access_groups}} ) for agent in agents: diff --git a/litellm/proxy/agent_endpoints/endpoints.py b/litellm/proxy/agent_endpoints/endpoints.py index 19dbfe33d3..d19008856b 100644 --- a/litellm/proxy/agent_endpoints/endpoints.py +++ b/litellm/proxy/agent_endpoints/endpoints.py @@ -17,6 +17,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query, Request import litellm from litellm._logging import verbose_proxy_logger +from litellm.litellm_core_utils.litellm_logging import _get_masked_values from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth from litellm.proxy.a2a.agent_card import merge_agent_card @@ -30,7 +31,6 @@ from litellm.types.agents import ( MakeAgentsPublicRequest, PatchAgentRequest, ) -from litellm.litellm_core_utils.litellm_logging import _get_masked_values from litellm.types.llms.custom_http import httpxSpecialProvider from 
litellm.types.proxy.management_endpoints.common_daily_activity import ( DailySpendMetadata, @@ -219,7 +219,7 @@ async def get_agents( if prisma_client is not None: agent_ids = [agent.agent_id for agent in returned_agents] if agent_ids: - db_agents = await prisma_client.db.litellm_agentstable.find_many( + db_agents = await AgentsRepository(prisma_client).table.find_many( where={"agent_id": {"in": agent_ids}}, ) spend_map = {a.agent_id: a.spend for a in db_agents} @@ -301,6 +301,7 @@ async def get_agents( from litellm.proxy.agent_endpoints.agent_registry import ( global_agent_registry as AGENT_REGISTRY, ) +from litellm.repositories.table_repositories import AgentsRepository @router.post( @@ -471,7 +472,7 @@ async def get_agent_by_id( try: agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id) if agent is None: - agent_row = await prisma_client.db.litellm_agentstable.find_unique( + agent_row = await AgentsRepository(prisma_client).table.find_unique( where={"agent_id": agent_id}, include={"object_permission": True}, ) @@ -489,7 +490,7 @@ async def get_agent_by_id( agent = AgentResponse(**agent_dict) # type: ignore else: # Agent found in memory — refresh spend from DB - db_row = await prisma_client.db.litellm_agentstable.find_unique( + db_row = await AgentsRepository(prisma_client).table.find_unique( where={"agent_id": agent_id} ) if db_row is not None: @@ -570,7 +571,7 @@ async def update_agent( try: # Check if agent exists - existing_agent = await prisma_client.db.litellm_agentstable.find_unique( + existing_agent = await AgentsRepository(prisma_client).table.find_unique( where={"agent_id": agent_id} ) if existing_agent is not None: @@ -678,7 +679,7 @@ async def patch_agent( try: # Check if agent exists - existing_agent = await prisma_client.db.litellm_agentstable.find_unique( + existing_agent = await AgentsRepository(prisma_client).table.find_unique( where={"agent_id": agent_id} ) if existing_agent is not None: @@ -769,7 +770,7 @@ async def delete_agent( try: # 
Check if agent exists - existing_agent = await prisma_client.db.litellm_agentstable.find_unique( + existing_agent = await AgentsRepository(prisma_client).table.find_unique( where={"agent_id": agent_id} ) if existing_agent is not None: @@ -859,7 +860,7 @@ async def make_agent_public( agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id) if agent is None: # check if agent exists in DB - agent = await prisma_client.db.litellm_agentstable.find_unique( + agent = await AgentsRepository(prisma_client).table.find_unique( where={"agent_id": agent_id} ) if agent is not None: @@ -982,7 +983,7 @@ async def make_agents_public( agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id) if agent is None: # check if agent exists in DB - agent = await prisma_client.db.litellm_agentstable.find_unique( + agent = await AgentsRepository(prisma_client).table.find_unique( where={"agent_id": agent_id} ) if agent is not None: @@ -1082,7 +1083,7 @@ async def get_agent_daily_activity( if user_api_key_dict.user_id is None: permitted_agent_ids = [] else: - owned_records = await prisma_client.db.litellm_agentstable.find_many( + owned_records = await AgentsRepository(prisma_client).table.find_many( where={"created_by": user_api_key_dict.user_id} ) permitted_agent_ids = [a.agent_id for a in owned_records] @@ -1118,7 +1119,7 @@ async def get_agent_daily_activity( if agent_ids_list: where_condition["agent_id"] = {"in": list(agent_ids_list)} - agent_records = await prisma_client.db.litellm_agentstable.find_many( + agent_records = await AgentsRepository(prisma_client).table.find_many( where=where_condition ) agent_metadata = { diff --git a/litellm/proxy/anthropic_endpoints/claude_code_endpoints/claude_code_marketplace.py b/litellm/proxy/anthropic_endpoints/claude_code_endpoints/claude_code_marketplace.py index 20b1659fa1..dd7350e13c 100644 --- a/litellm/proxy/anthropic_endpoints/claude_code_endpoints/claude_code_marketplace.py +++ 
b/litellm/proxy/anthropic_endpoints/claude_code_endpoints/claude_code_marketplace.py @@ -26,6 +26,7 @@ from fastapi.responses import JSONResponse from litellm._logging import verbose_proxy_logger from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.repositories.table_repositories import ClaudeCodePluginRepository from litellm.types.proxy.claude_code_endpoints import ( ListPluginsResponse, PluginListItem, @@ -71,7 +72,7 @@ async def get_marketplace(): try: prisma_client = await _get_prisma_client() - plugins = await prisma_client.db.litellm_claudecodeplugintable.find_many( + plugins = await ClaudeCodePluginRepository(prisma_client).table.find_many( where={"enabled": True} ) @@ -268,12 +269,12 @@ async def register_plugin( manifest["namespace"] = request.namespace # Check if plugin exists - existing = await prisma_client.db.litellm_claudecodeplugintable.find_unique( + existing = await ClaudeCodePluginRepository(prisma_client).table.find_unique( where={"name": request.name} ) if existing: - plugin = await prisma_client.db.litellm_claudecodeplugintable.update( + plugin = await ClaudeCodePluginRepository(prisma_client).table.update( where={"name": request.name}, data={ "version": request.version, @@ -285,7 +286,7 @@ async def register_plugin( ) action = "updated" else: - plugin = await prisma_client.db.litellm_claudecodeplugintable.create( + plugin = await ClaudeCodePluginRepository(prisma_client).table.create( data={ "name": request.name, "version": request.version, @@ -348,7 +349,7 @@ async def list_plugins( prisma_client = await _get_prisma_client() where = {"enabled": True} if enabled_only else {} - plugins = await prisma_client.db.litellm_claudecodeplugintable.find_many( + plugins = await ClaudeCodePluginRepository(prisma_client).table.find_many( where=where ) @@ -415,7 +416,7 @@ async def get_plugin( try: prisma_client = await _get_prisma_client() - plugin = await 
prisma_client.db.litellm_claudecodeplugintable.find_unique( + plugin = await ClaudeCodePluginRepository(prisma_client).table.find_unique( where={"name": plugin_name} ) @@ -471,7 +472,7 @@ async def enable_plugin( try: prisma_client = await _get_prisma_client() - plugin = await prisma_client.db.litellm_claudecodeplugintable.find_unique( + plugin = await ClaudeCodePluginRepository(prisma_client).table.find_unique( where={"name": plugin_name} ) if not plugin: @@ -480,7 +481,7 @@ async def enable_plugin( detail={"error": f"Plugin '{plugin_name}' not found"}, ) - await prisma_client.db.litellm_claudecodeplugintable.update( + await ClaudeCodePluginRepository(prisma_client).table.update( where={"name": plugin_name}, data={"enabled": True, "updated_at": datetime.now(timezone.utc)}, ) @@ -516,7 +517,7 @@ async def disable_plugin( try: prisma_client = await _get_prisma_client() - plugin = await prisma_client.db.litellm_claudecodeplugintable.find_unique( + plugin = await ClaudeCodePluginRepository(prisma_client).table.find_unique( where={"name": plugin_name} ) if not plugin: @@ -525,7 +526,7 @@ async def disable_plugin( detail={"error": f"Plugin '{plugin_name}' not found"}, ) - await prisma_client.db.litellm_claudecodeplugintable.update( + await ClaudeCodePluginRepository(prisma_client).table.update( where={"name": plugin_name}, data={"enabled": False, "updated_at": datetime.now(timezone.utc)}, ) @@ -561,7 +562,7 @@ async def delete_plugin( try: prisma_client = await _get_prisma_client() - plugin = await prisma_client.db.litellm_claudecodeplugintable.find_unique( + plugin = await ClaudeCodePluginRepository(prisma_client).table.find_unique( where={"name": plugin_name} ) if not plugin: @@ -570,7 +571,7 @@ async def delete_plugin( detail={"error": f"Plugin '{plugin_name}' not found"}, ) - await prisma_client.db.litellm_claudecodeplugintable.delete( + await ClaudeCodePluginRepository(prisma_client).table.delete( where={"name": plugin_name} ) diff --git 
a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 94ae3f5eac..45007861d5 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -61,19 +61,33 @@ from litellm.proxy._types import ( UserAPIKeyAuth, ) from litellm.proxy.auth.route_checks import RouteChecks +from litellm.proxy.common_utils.cache_pydantic_utils import CacheCodec from litellm.proxy.common_utils.http_parsing_utils import ( _safe_get_request_headers, _safe_get_request_query_params, ) +from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler from litellm.proxy.guardrails.tool_name_extraction import ( TOOL_CAPABLE_CALL_TYPES, extract_request_tool_names, ) -from litellm.proxy.common_utils.cache_pydantic_utils import CacheCodec -from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache from litellm.proxy.route_llm_request import route_request from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics +from litellm.repositories.budget_repository import BudgetRepository +from litellm.repositories.object_permission_repository import ObjectPermissionRepository +from litellm.repositories.organization_repository import OrganizationRepository +from litellm.repositories.project_repository import ProjectRepository +from litellm.repositories.table_repositories import ( + AccessGroupRepository, + EndUserRepository, + JWTKeyMappingRepository, + ManagedVectorStoresRepository, + TagRepository, + TeamMembershipRepository, +) +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.user_repository import UserRepository from litellm.router import Router from litellm.utils import get_utc_datetime @@ -957,7 +971,7 @@ async def get_default_end_user_budget( # Fetch from database try: - budget_record = await prisma_client.db.litellm_budgettable.find_unique( + budget_record = await 
BudgetRepository(prisma_client).table.find_unique( where={"budget_id": litellm.max_end_user_budget_id} ) @@ -1016,7 +1030,7 @@ async def get_team_member_default_budget( return LiteLLM_BudgetTable(**cached_budget) try: - budget_record = await prisma_client.db.litellm_budgettable.find_unique( + budget_record = await BudgetRepository(prisma_client).table.find_unique( where={"budget_id": budget_id} ) @@ -1175,7 +1189,7 @@ async def get_end_user_object( # Fetch from database try: - response = await prisma_client.db.litellm_endusertable.find_unique( + response = await EndUserRepository(prisma_client).table.find_unique( where={"user_id": end_user_id}, include={"litellm_budget_table": True, "object_permission": True}, ) @@ -1375,7 +1389,7 @@ async def get_tag_objects_batch( # Batch fetch uncached tags from DB in one query if uncached_tags: try: - db_tags = await prisma_client.db.litellm_tagtable.find_many( + db_tags = await TagRepository(prisma_client).table.find_many( where={"tag_name": {"in": uncached_tags}}, include={"litellm_budget_table": True}, ) @@ -1469,7 +1483,7 @@ async def get_team_membership( # else, check db try: - response = await prisma_client.db.litellm_teammembership.find_unique( + response = await TeamMembershipRepository(prisma_client).table.find_unique( where={"user_id_team_id": {"user_id": user_id, "team_id": team_id}}, include={"litellm_budget_table": True}, ) @@ -1622,7 +1636,7 @@ async def _get_fuzzy_user_object( response = None if sso_user_id is not None: - response = await prisma_client.db.litellm_usertable.find_unique( + response = await UserRepository(prisma_client).table.find_unique( where={"sso_user_id": sso_user_id}, include={"organization_memberships": True}, ) @@ -1630,14 +1644,14 @@ async def _get_fuzzy_user_object( if response is None and user_email is not None: # Use case-insensitive query to handle emails with different casing # This matches the pattern used in _check_duplicate_user_email - response = await 
prisma_client.db.litellm_usertable.find_first( + response = await UserRepository(prisma_client).table.find_first( where={"user_email": {"equals": user_email, "mode": "insensitive"}}, include={"organization_memberships": True}, ) if response is not None and sso_user_id is not None: # update sso_user_id asyncio.create_task( # background task to update user with sso id - prisma_client.db.litellm_usertable.update( + UserRepository(prisma_client).table.update( where={"user_id": response.user_id}, data={"sso_user_id": sso_user_id}, ) @@ -1687,7 +1701,7 @@ async def get_user_object( ) if should_check_db: - response = await prisma_client.db.litellm_usertable.find_unique( + response = await UserRepository(prisma_client).table.find_unique( where={"user_id": user_id}, include={"organization_memberships": True} ) @@ -1711,7 +1725,7 @@ async def get_user_object( if litellm.default_internal_user_params is not None: new_user_params.update(litellm.default_internal_user_params) - response = await prisma_client.db.litellm_usertable.create( + response = await UserRepository(prisma_client).table.create( data=new_user_params, include={"organization_memberships": True}, ) @@ -1860,7 +1874,7 @@ async def _delete_cache_key_object( async def _get_team_db_check( team_id: str, prisma_client: PrismaClient, team_id_upsert: Optional[bool] = None ): - response = await prisma_client.db.litellm_teamtable.find_unique( + response = await TeamRepository(prisma_client).table.find_unique( where={"team_id": team_id} ) @@ -1882,7 +1896,7 @@ async def _get_team_db_check( async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient): - return await prisma_client.db.litellm_teamtable.find_unique( + return await TeamRepository(prisma_client).table.find_unique( where={"team_id": team_id} ) @@ -2111,7 +2125,7 @@ async def get_access_object( # Not in cache - fetch from DB try: - response = await prisma_client.db.litellm_accessgrouptable.find_unique( + response = await 
AccessGroupRepository(prisma_client).table.find_unique( where={"access_group_id": access_group_id} ) @@ -2193,7 +2207,7 @@ async def get_team_object_by_alias( # Query database by team_alias try: - teams = await prisma_client.db.litellm_teamtable.find_many( + teams = await TeamRepository(prisma_client).table.find_many( where={"team_alias": team_alias} ) @@ -2301,7 +2315,7 @@ async def get_org_object_by_alias( # Query database by organization_alias try: - orgs = await prisma_client.db.litellm_organizationtable.find_many( + orgs = await OrganizationRepository(prisma_client).table.find_many( where={"organization_alias": org_alias} ) @@ -2526,7 +2540,7 @@ async def get_jwt_key_mapping_object( Returns the hashed token (str) if a matching active mapping is found, else None. """ - mapping = await prisma_client.db.litellm_jwtkeymapping.find_first( + mapping = await JWTKeyMappingRepository(prisma_client).table.find_first( where={ "jwt_claim_name": jwt_claim_name, "jwt_claim_value": jwt_claim_value, @@ -2659,7 +2673,7 @@ async def get_object_permission( # else, check db try: - response = await prisma_client.db.litellm_objectpermissiontable.find_unique( + response = await ObjectPermissionRepository(prisma_client).table.find_unique( where={"object_permission_id": object_permission_id} ) @@ -2715,7 +2729,7 @@ async def get_managed_vector_store_rows_by_uuids( if not cache_misses: return result - rows = await prisma_client.db.litellm_managedvectorstorestable.find_many( + rows = await ManagedVectorStoresRepository(prisma_client).table.find_many( where={"vector_store_id": {"in": cache_misses}}, take=len(cache_misses), ) @@ -2790,7 +2804,7 @@ async def get_org_object( if include_budget_table: query_kwargs["include"] = {"litellm_budget_table": True} - response = await prisma_client.db.litellm_organizationtable.find_unique( + response = await OrganizationRepository(prisma_client).table.find_unique( **query_kwargs ) @@ -4180,7 +4194,7 @@ async def get_project_object( return 
deserialized_project # Fetch from DB - project_row = await prisma_client.db.litellm_projecttable.find_unique( + project_row = await ProjectRepository(prisma_client).table.find_unique( where={"project_id": project_id}, include={"litellm_budget_table": True}, ) @@ -4480,10 +4494,10 @@ async def vector_store_access_check( ######################################################### # Check if the key can access the vector store if valid_token is not None and valid_token.object_permission_id is not None: - key_object_permission = ( - await prisma_client.db.litellm_objectpermissiontable.find_unique( - where={"object_permission_id": valid_token.object_permission_id}, - ) + key_object_permission = await ObjectPermissionRepository( + prisma_client + ).table.find_unique( + where={"object_permission_id": valid_token.object_permission_id}, ) if key_object_permission is not None: _can_object_call_vector_stores( @@ -4494,10 +4508,10 @@ async def vector_store_access_check( # Check if the team can access the vector store if team_object is not None and team_object.object_permission_id is not None: - team_object_permission = ( - await prisma_client.db.litellm_objectpermissiontable.find_unique( - where={"object_permission_id": team_object.object_permission_id}, - ) + team_object_permission = await ObjectPermissionRepository( + prisma_client + ).table.find_unique( + where={"object_permission_id": team_object.object_permission_id}, ) if team_object_permission is not None: _can_object_call_vector_stores( diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 6d3d49b71e..536d286785 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -14,11 +14,11 @@ import os import re from typing import Any, List, Literal, Optional, Set, Tuple, Union, cast +import jwt from cryptography import x509 from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from fastapi import 
HTTPException, status -import jwt from jwt.api_jwk import PyJWK from litellm._logging import verbose_proxy_logger @@ -50,6 +50,7 @@ from litellm.proxy.auth.auth_checks import can_team_access_model from litellm.proxy.auth.route_checks import RouteChecks from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache from litellm.proxy.utils import PrismaClient, ProxyLogging +from litellm.repositories.user_repository import UserRepository from .auth_checks import ( _allowed_routes_check, @@ -1790,7 +1791,7 @@ class JWTAuthManager: # Update user role new_role = jwt_handler.map_jwt_role_to_litellm_role(jwt_valid_token) if new_role and user_object.user_role != new_role.value: - await prisma_client.db.litellm_usertable.update( + await UserRepository(prisma_client).table.update( where={"user_id": user_object.user_id}, data={"user_role": new_role.value}, ) diff --git a/litellm/proxy/auth/login_utils.py b/litellm/proxy/auth/login_utils.py index 34085d5685..d0818b9536 100644 --- a/litellm/proxy/auth/login_utils.py +++ b/litellm/proxy/auth/login_utils.py @@ -34,6 +34,7 @@ from litellm.proxy.utils import ( hash_password, verify_password, ) +from litellm.repositories.user_repository import UserRepository from litellm.secret_managers.main import get_secret_bool from litellm.types.proxy.ui_sso import ReturnedUITokenObject @@ -45,7 +46,7 @@ async def _rehash_password_if_needed(user_id: str, password: str, stored: str) - from litellm.proxy.proxy_server import prisma_client if prisma_client is not None: - await prisma_client.db.litellm_usertable.update( + await UserRepository(prisma_client).table.update( where={"user_id": user_id}, data={"password": hash_password(password)}, ) @@ -151,7 +152,7 @@ async def authenticate_user( # noqa: PLR0915 if prisma_client is not None: _user_row = cast( Optional[LiteLLM_UserTable], - await prisma_client.db.litellm_usertable.find_first( + await UserRepository(prisma_client).table.find_first( where={"user_email": {"equals": username, 
"mode": "insensitive"}} ), ) diff --git a/litellm/proxy/auth/model_checks.py b/litellm/proxy/auth/model_checks.py index d9600e1a4b..00f276dc97 100644 --- a/litellm/proxy/auth/model_checks.py +++ b/litellm/proxy/auth/model_checks.py @@ -6,6 +6,7 @@ import litellm from litellm._logging import verbose_proxy_logger from litellm.litellm_core_utils.credential_accessor import CredentialAccessor from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth +from litellm.repositories.object_permission_repository import ObjectPermissionRepository from litellm.router import Router from litellm.router_utils.fallback_event_handlers import get_fallback_model_group from litellm.types.router import CredentialLiteLLMParams, LiteLLM_Params @@ -86,7 +87,7 @@ async def get_mcp_server_ids( # Make a direct SQL query to get just the mcp_servers try: - result = await prisma_client.db.litellm_objectpermissiontable.find_unique( + result = await ObjectPermissionRepository(prisma_client).table.find_unique( where={"object_permission_id": user_api_key_dict.object_permission_id}, ) if result and result.mcp_servers: diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index a5501fefa4..a970e0ddee 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -21,9 +21,9 @@ from fastapi.security.api_key import APIKeyHeader import litellm from litellm._logging import verbose_logger, verbose_proxy_logger from litellm._service_logger import ServiceLogging +from litellm.constants import LITELLM_PROXY_MASTER_KEY_ALIAS from litellm.integrations.otel.model.config import is_otel_v2_enabled from litellm.integrations.otel.runtime import phase_span, seed_request_identity -from litellm.constants import LITELLM_PROXY_MASTER_KEY_ALIAS from litellm.litellm_core_utils.dd_tracing import tracer from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value from litellm.proxy._types import * @@ -65,7 +65,6 @@ from 
litellm.proxy.auth.oauth2_check import Oauth2Handler from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request from litellm.proxy.auth.route_checks import RouteChecks from litellm.proxy.common_utils.cache_coordinator import EventDrivenCacheCoordinator -from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache from litellm.proxy.common_utils.http_parsing_utils import ( _read_request_body, _safe_get_request_headers, @@ -73,12 +72,14 @@ from litellm.proxy.common_utils.http_parsing_utils import ( populate_request_with_path_params, ) from litellm.proxy.common_utils.realtime_utils import _realtime_request_body +from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup from litellm.proxy.utils import ( PrismaClient, ProxyLogging, normalize_route_for_root_path, ) +from litellm.repositories.table_repositories import TeamMembershipRepository from litellm.secret_managers.main import get_secret_bool from litellm.types.services import ServiceTypes @@ -1797,7 +1798,9 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 _team_id = valid_token.team_id if _user_id is not None and _team_id is not None: - _db_member = await prisma_client.db.litellm_teammembership.find_first( + _db_member = await TeamMembershipRepository( + prisma_client + ).table.find_first( where={ "user_id": _user_id, "team_id": _team_id, diff --git a/litellm/proxy/common_utils/expired_ui_session_key_cleanup_manager.py b/litellm/proxy/common_utils/expired_ui_session_key_cleanup_manager.py index 67a2456746..4f3e26ab5f 100644 --- a/litellm/proxy/common_utils/expired_ui_session_key_cleanup_manager.py +++ b/litellm/proxy/common_utils/expired_ui_session_key_cleanup_manager.py @@ -8,7 +8,6 @@ from datetime import datetime, timezone from typing import Any, Dict, List, Optional from litellm._logging import verbose_proxy_logger -from litellm.proxy.common_utils.user_api_key_cache import 
UserApiKeyCache from litellm.constants import ( EXPIRED_UI_SESSION_KEY_CLEANUP_JOB_NAME, LITELLM_EXPIRED_UI_SESSION_KEY_CLEANUP_BATCH_SIZE, @@ -16,11 +15,15 @@ from litellm.constants import ( UI_SESSION_TOKEN_TEAM_ID, ) from litellm.proxy._types import KeyRequest, LiteLLM_VerificationToken, UserAPIKeyAuth +from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks from litellm.proxy.management_endpoints.key_management_endpoints import ( delete_verification_tokens, ) from litellm.proxy.utils import PrismaClient +from litellm.repositories.verification_token_repository import ( + VerificationTokenRepository, +) class ExpiredUISessionKeyCleanupManager: @@ -147,7 +150,7 @@ class ExpiredUISessionKeyCleanupManager: Find expired LiteLLM dashboard session keys. """ now = datetime.now(timezone.utc) - return await self.prisma_client.db.litellm_verificationtoken.find_many( + return await VerificationTokenRepository(self.prisma_client).table.find_many( where={ "team_id": UI_SESSION_TOKEN_TEAM_ID, "expires": {"lt": now}, diff --git a/litellm/proxy/common_utils/key_rotation_manager.py b/litellm/proxy/common_utils/key_rotation_manager.py index aaf39a7a19..d622f61249 100644 --- a/litellm/proxy/common_utils/key_rotation_manager.py +++ b/litellm/proxy/common_utils/key_rotation_manager.py @@ -24,6 +24,12 @@ from litellm.proxy.management_endpoints.key_management_endpoints import ( regenerate_key_fn, ) from litellm.proxy.utils import PrismaClient +from litellm.repositories.table_repositories import ( + DeprecatedVerificationTokenRepository, +) +from litellm.repositories.verification_token_repository import ( + VerificationTokenRepository, +) class KeyRotationManager: @@ -124,20 +130,20 @@ class KeyRotationManager: """ now = datetime.now(timezone.utc) - keys_with_rotation = ( - await self.prisma_client.db.litellm_verificationtoken.find_many( - where={ - "auto_rotate": True, # Only keys 
marked for auto rotation - "OR": [ - { - "key_rotation_at": None - }, # Keys that need initial rotation time setup - { - "key_rotation_at": {"lte": now} - }, # Keys where rotation time has passed - ], - } - ) + keys_with_rotation = await VerificationTokenRepository( + self.prisma_client + ).table.find_many( + where={ + "auto_rotate": True, # Only keys marked for auto rotation + "OR": [ + { + "key_rotation_at": None + }, # Keys that need initial rotation time setup + { + "key_rotation_at": {"lte": now} + }, # Keys where rotation time has passed + ], + } ) return keys_with_rotation @@ -148,9 +154,9 @@ class KeyRotationManager: """ try: now = datetime.now(timezone.utc) - result = await self.prisma_client.db.litellm_deprecatedverificationtoken.delete_many( - where={"revoke_at": {"lt": now}} - ) + result = await DeprecatedVerificationTokenRepository( + self.prisma_client + ).table.delete_many(where={"revoke_at": {"lt": now}}) if result > 0: verbose_proxy_logger.debug( "Cleaned up %s expired deprecated key(s)", result @@ -206,7 +212,7 @@ class KeyRotationManager: # Calculate next rotation time using helper function now = datetime.now(timezone.utc) next_rotation_time = _calculate_key_rotation_time(key.rotation_interval) - await self.prisma_client.db.litellm_verificationtoken.update( + await VerificationTokenRepository(self.prisma_client).table.update( where={"token": response.token_id}, data={ "rotation_count": (key.rotation_count or 0) + 1, diff --git a/litellm/proxy/common_utils/reset_budget_job.py b/litellm/proxy/common_utils/reset_budget_job.py index 40c8caa49e..7c1dfe8dc9 100644 --- a/litellm/proxy/common_utils/reset_budget_job.py +++ b/litellm/proxy/common_utils/reset_budget_job.py @@ -14,6 +14,16 @@ from litellm.proxy._types import ( LiteLLM_VerificationToken, ) from litellm.proxy.utils import PrismaClient, ProxyLogging +from litellm.repositories.organization_repository import OrganizationRepository +from litellm.repositories.table_repositories import ( + 
EndUserRepository, + TagRepository, + TeamMembershipRepository, +) +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.verification_token_repository import ( + VerificationTokenRepository, +) from litellm.types.services import ServiceTypes @@ -159,7 +169,7 @@ class ResetBudgetJob: """ return await self._cascade_reset_spend_for_budget_link( budgets_to_reset=budgets_to_reset, - table=self.prisma_client.db.litellm_teammembership, + table=TeamMembershipRepository(self.prisma_client).table, counter_key_fn=lambda m: f"spend:team_member:{m.user_id}:{m.team_id}", log_subject="team memberships", cache_key_fn=lambda m: f"{m.team_id}_{m.user_id}", @@ -176,7 +186,7 @@ class ResetBudgetJob: """ return await self._cascade_reset_spend_for_budget_link( budgets_to_reset=budgets_to_reset, - table=self.prisma_client.db.litellm_verificationtoken, + table=VerificationTokenRepository(self.prisma_client).table, counter_key_fn=lambda k: f"spend:key:{k.token}", log_subject="keys", extra_where={"budget_duration": None, "spend": {"gt": 0}}, @@ -191,7 +201,7 @@ class ResetBudgetJob: """ return await self._cascade_reset_spend_for_budget_link( budgets_to_reset=budgets_to_reset, - table=self.prisma_client.db.litellm_organizationtable, + table=OrganizationRepository(self.prisma_client).table, counter_key_fn=lambda o: f"spend:org:{o.organization_id}", log_subject="orgs", extra_where={"spend": {"gt": 0}}, @@ -217,7 +227,7 @@ class ResetBudgetJob: """ return await self._cascade_reset_spend_for_budget_link( budgets_to_reset=budgets_to_reset, - table=self.prisma_client.db.litellm_tagtable, + table=TagRepository(self.prisma_client).table, counter_key_fn=lambda t: f"spend:tag:{t.tag_name}", log_subject="tags", extra_where={"spend": {"gt": 0}}, @@ -406,7 +416,7 @@ class ResetBudgetJob: rely on the default budget (litellm.max_end_user_budget_id) applied in-memory during auth checks. 
""" - rows = await self.prisma_client.db.litellm_endusertable.find_many( + rows = await EndUserRepository(self.prisma_client).table.find_many( where={ "budget_id": None, "spend": {"gt": 0}, @@ -824,7 +834,7 @@ class ResetBudgetJob: ): changed = True if changed: - await self.prisma_client.db.litellm_verificationtoken.update( + await VerificationTokenRepository(self.prisma_client).table.update( where={"token": row["token"]}, data={"budget_limits": json.dumps(windows)}, # type: ignore[arg-type] ) @@ -852,7 +862,7 @@ class ResetBudgetJob: ): changed = True if changed: - await self.prisma_client.db.litellm_teamtable.update( + await TeamRepository(self.prisma_client).table.update( where={"team_id": row["team_id"]}, data={"budget_limits": json.dumps(windows)}, # type: ignore[arg-type] ) diff --git a/litellm/proxy/container_endpoints/ownership.py b/litellm/proxy/container_endpoints/ownership.py index e0015e112e..8118d53b9f 100644 --- a/litellm/proxy/container_endpoints/ownership.py +++ b/litellm/proxy/container_endpoints/ownership.py @@ -12,6 +12,7 @@ from litellm.proxy.common_utils.resource_ownership import ( is_proxy_admin, user_can_access_resource_owner, ) +from litellm.repositories.table_repositories import ManagedObjectRepository from litellm.responses.utils import ResponsesAPIRequestUtils CONTAINER_OBJECT_PURPOSE = "container" @@ -213,7 +214,7 @@ async def record_container_owner( ) return response - table = prisma_client.db.litellm_managedobjecttable + table = ManagedObjectRepository(prisma_client).table existing = await table.find_unique(where={"model_object_id": model_object_id}) if existing is not None: if getattr(existing, "file_purpose", None) != CONTAINER_OBJECT_PURPOSE: @@ -273,7 +274,7 @@ async def _get_container_owner( if prisma_client is None: return None - row = await prisma_client.db.litellm_managedobjecttable.find_first( + row = await ManagedObjectRepository(prisma_client).table.find_first( where={ "model_object_id": model_object_id, "file_purpose": 
CONTAINER_OBJECT_PURPOSE, @@ -319,7 +320,7 @@ async def _get_stored_container_id( if prisma_client is None: return None - row = await prisma_client.db.litellm_managedobjecttable.find_first( + row = await ManagedObjectRepository(prisma_client).table.find_first( where={ "model_object_id": model_object_id, "file_purpose": CONTAINER_OBJECT_PURPOSE, @@ -411,7 +412,7 @@ async def _get_allowed_container_ids( if prisma_client is None: return set() - rows = await prisma_client.db.litellm_managedobjecttable.find_many( + rows = await ManagedObjectRepository(prisma_client).table.find_many( where={ "file_purpose": CONTAINER_OBJECT_PURPOSE, "created_by": {"in": owner_scopes}, diff --git a/litellm/proxy/credential_endpoints/endpoints.py b/litellm/proxy/credential_endpoints/endpoints.py index 2d05270e2e..a716857111 100644 --- a/litellm/proxy/credential_endpoints/endpoints.py +++ b/litellm/proxy/credential_endpoints/endpoints.py @@ -14,6 +14,7 @@ from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper from litellm.proxy.utils import handle_exception_on_proxy, jsonify_object +from litellm.repositories.credentials_repository import CredentialsRepository from litellm.types.utils import CreateCredentialItem, CredentialItem router = APIRouter() @@ -96,7 +97,7 @@ async def create_credential( ) credentials_dict = encrypted_credential.model_dump() credentials_dict_jsonified = jsonify_object(credentials_dict) - await prisma_client.db.litellm_credentialstable.create( + await CredentialsRepository(prisma_client).create( data={ **credentials_dict_jsonified, "created_by": user_api_key_dict.user_id, @@ -245,9 +246,7 @@ async def delete_credential( status_code=500, detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - await prisma_client.db.litellm_credentialstable.delete( - where={"credential_name": credential_name} - ) + 
await CredentialsRepository(prisma_client).delete_by_name(credential_name) ## DELETE FROM LITELLM ## litellm.credential_list = [ @@ -326,15 +325,14 @@ async def update_credential( status_code=500, detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - db_credential = await prisma_client.db.litellm_credentialstable.find_unique( - where={"credential_name": credential_name}, - ) + credentials_repository = CredentialsRepository(prisma_client) + db_credential = await credentials_repository.find_by_name(credential_name) if db_credential is None: raise HTTPException(status_code=404, detail="Credential not found in DB.") merged_credential = update_db_credential(db_credential, credential) credential_object_jsonified = jsonify_object(merged_credential.model_dump()) - await prisma_client.db.litellm_credentialstable.update( - where={"credential_name": credential_name}, + await credentials_repository.update_by_name( + credential_name, data={ **credential_object_jsonified, "updated_by": user_api_key_dict.user_id, diff --git a/litellm/proxy/db/spend_counter_reseed.py b/litellm/proxy/db/spend_counter_reseed.py index e7c5fa3f72..2226aeb4b0 100644 --- a/litellm/proxy/db/spend_counter_reseed.py +++ b/litellm/proxy/db/spend_counter_reseed.py @@ -20,6 +20,16 @@ from typing import TYPE_CHECKING, ClassVar, Optional from litellm._logging import verbose_proxy_logger from litellm.constants import SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE from litellm.litellm_core_utils.duration_parser import duration_in_seconds +from litellm.repositories.organization_repository import OrganizationRepository +from litellm.repositories.table_repositories import ( + SpendLogsRepository, + TeamMembershipRepository, +) +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.user_repository import UserRepository +from litellm.repositories.verification_token_repository import ( + VerificationTokenRepository, +) if TYPE_CHECKING: from litellm.caching.dual_cache import 
DualCache @@ -83,25 +93,25 @@ class SpendCounterReseed: try: if counter_key.startswith("spend:key:"): token = counter_key[len("spend:key:") :] - row = await prisma_client.db.litellm_verificationtoken.find_unique( - where={"token": token} - ) + row = await VerificationTokenRepository( + prisma_client + ).table.find_unique(where={"token": token}) elif counter_key.startswith("spend:team_member:"): suffix = counter_key[len("spend:team_member:") :] if ":" not in suffix: return None user_id, team_id = suffix.rsplit(":", 1) - row = await prisma_client.db.litellm_teammembership.find_unique( + row = await TeamMembershipRepository(prisma_client).table.find_unique( where={"user_id_team_id": {"user_id": user_id, "team_id": team_id}} ) elif counter_key.startswith("spend:team:"): team_id = counter_key[len("spend:team:") :] - row = await prisma_client.db.litellm_teamtable.find_unique( + row = await TeamRepository(prisma_client).table.find_unique( where={"team_id": team_id} ) elif counter_key.startswith("spend:user:"): user_id = counter_key[len("spend:user:") :] - row = await prisma_client.db.litellm_usertable.find_unique( + row = await UserRepository(prisma_client).table.find_unique( where={"user_id": user_id} ) elif counter_key.startswith("spend:end_user:"): @@ -110,7 +120,7 @@ class SpendCounterReseed: return None elif counter_key.startswith("spend:org:"): org_id = counter_key[len("spend:org:") :] - row = await prisma_client.db.litellm_organizationtable.find_unique( + row = await OrganizationRepository(prisma_client).table.find_unique( where={"organization_id": org_id} ) else: @@ -243,7 +253,7 @@ class SpendCounterReseed: return None try: - response = await prisma_client.db.litellm_spendlogs.group_by( + response = await SpendLogsRepository(prisma_client).table.group_by( by=[group_field], where=where, # type: ignore[arg-type] sum={"spend": True}, diff --git a/litellm/proxy/db/spend_log_tool_index.py b/litellm/proxy/db/spend_log_tool_index.py index 835d76e0ee..77c06a465f 100644 
--- a/litellm/proxy/db/spend_log_tool_index.py +++ b/litellm/proxy/db/spend_log_tool_index.py @@ -10,6 +10,7 @@ from typing import Any, Dict, List, Set from litellm._logging import verbose_proxy_logger from litellm.litellm_core_utils.safe_json_loads import safe_json_loads from litellm.proxy.utils import PrismaClient +from litellm.repositories.table_repositories import SpendLogToolIndexRepository def _add_tool_calls_to_set(tool_calls: Any, out: Set[str]) -> None: @@ -141,7 +142,7 @@ async def process_spend_logs_tool_usage( } ) if index_data: - await prisma_client.db.litellm_spendlogtoolindex.create_many( + await SpendLogToolIndexRepository(prisma_client).table.create_many( data=index_data, skip_duplicates=True, ) diff --git a/litellm/proxy/db/tool_registry_writer.py b/litellm/proxy/db/tool_registry_writer.py index 6b34c974cf..bbcc7396d6 100644 --- a/litellm/proxy/db/tool_registry_writer.py +++ b/litellm/proxy/db/tool_registry_writer.py @@ -11,6 +11,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from litellm._logging import verbose_proxy_logger from litellm.proxy._types import ToolDiscoveryQueueItem +from litellm.repositories.object_permission_repository import ObjectPermissionRepository +from litellm.repositories.table_repositories import ToolRepository from litellm.types.tool_management import ( LiteLLM_ToolTableRow, ToolPolicyOverrideRow, @@ -84,7 +86,7 @@ async def batch_upsert_tools( if not data: return now = datetime.now(timezone.utc) - table = prisma_client.db.litellm_tooltable + table = ToolRepository(prisma_client).table for item in data: tool_name = item.get("tool_name", "") origin = item.get("origin") or "user_defined" @@ -134,7 +136,7 @@ async def list_tools( """Return all tools, optionally filtered by input_policy.""" try: where = {"input_policy": input_policy} if input_policy is not None else {} - rows = await prisma_client.db.litellm_tooltable.find_many( + rows = await ToolRepository(prisma_client).table.find_many( 
where=where, order={"created_at": "desc"}, ) @@ -150,7 +152,7 @@ async def get_tool( ) -> Optional[LiteLLM_ToolTableRow]: """Return a single tool row by tool_name.""" try: - row = await prisma_client.db.litellm_tooltable.find_unique( + row = await ToolRepository(prisma_client).table.find_unique( where={"tool_name": tool_name}, ) if row is None: @@ -192,7 +194,7 @@ async def update_tool_policy( if output_policy is not None: update_data["output_policy"] = output_policy - await prisma_client.db.litellm_tooltable.upsert( + await ToolRepository(prisma_client).table.upsert( where={"tool_name": tool_name}, data={ "create": create_data, @@ -217,7 +219,7 @@ async def get_tools_by_names( if not tool_names: return {} try: - rows = await prisma_client.db.litellm_tooltable.find_many( + rows = await ToolRepository(prisma_client).table.find_many( where={"tool_name": {"in": tool_names}}, ) return { @@ -244,7 +246,7 @@ async def list_overrides_for_tool( """ out: List[ToolPolicyOverrideRow] = [] try: - perms = await prisma_client.db.litellm_objectpermissiontable.find_many( + perms = await ObjectPermissionRepository(prisma_client).table.find_many( where={"blocked_tools": {"has": tool_name}}, include={ "verification_tokens": True, @@ -307,7 +309,7 @@ class ToolPolicyRegistry: async def sync_tool_policy_from_db(self, prisma_client: "PrismaClient") -> None: """Load all tool policies and object-permission blocked_tools from DB.""" try: - tools = await prisma_client.db.litellm_tooltable.find_many() + tools = await ToolRepository(prisma_client).table.find_many() self._tool_input_policies = { row.tool_name: getattr(row, "input_policy", "untrusted") or "untrusted" for row in tools @@ -317,7 +319,7 @@ class ToolPolicyRegistry: for row in tools } - perms = await prisma_client.db.litellm_objectpermissiontable.find_many() + perms = await ObjectPermissionRepository(prisma_client).table.find_many() self._blocked_tools_by_op_id = {} for row in perms: op_id = getattr(row, "object_permission_id", 
None) @@ -388,7 +390,7 @@ async def add_tool_to_object_permission_blocked( if not object_permission_id or not tool_name: return False try: - row = await prisma_client.db.litellm_objectpermissiontable.find_unique( + row = await ObjectPermissionRepository(prisma_client).table.find_unique( where={"object_permission_id": object_permission_id}, ) if row is None: @@ -397,7 +399,7 @@ async def add_tool_to_object_permission_blocked( if tool_name in current: return True current.append(tool_name) - await prisma_client.db.litellm_objectpermissiontable.update( + await ObjectPermissionRepository(prisma_client).table.update( where={"object_permission_id": object_permission_id}, data={"blocked_tools": current}, ) @@ -418,7 +420,7 @@ async def remove_tool_from_object_permission_blocked( if not object_permission_id or not tool_name: return False try: - row = await prisma_client.db.litellm_objectpermissiontable.find_unique( + row = await ObjectPermissionRepository(prisma_client).table.find_unique( where={"object_permission_id": object_permission_id}, ) if row is None: @@ -427,7 +429,7 @@ async def remove_tool_from_object_permission_blocked( if tool_name not in current: return False current = [t for t in current if t != tool_name] - await prisma_client.db.litellm_objectpermissiontable.update( + await ObjectPermissionRepository(prisma_client).table.update( where={"object_permission_id": object_permission_id}, data={"blocked_tools": current}, ) diff --git a/litellm/proxy/guardrails/guardrail_endpoints.py b/litellm/proxy/guardrails/guardrail_endpoints.py index e0e4bdcf4a..9f8ea58410 100644 --- a/litellm/proxy/guardrails/guardrail_endpoints.py +++ b/litellm/proxy/guardrails/guardrail_endpoints.py @@ -13,21 +13,21 @@ from urllib.parse import urlparse from fastapi import APIRouter, Depends, HTTPException, Request from pydantic import BaseModel -from litellm.proxy.common_utils.path_utils import safe_join - from litellm._logging import verbose_proxy_logger from litellm.constants import 
DEFAULT_MAX_RECURSE_DEPTH from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth -from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view +from litellm.proxy.common_utils.path_utils import safe_join from litellm.proxy.guardrails.guardrail_hooks.custom_code.sandbox import ( build_sandbox_globals, compile_sandboxed, ) from litellm.proxy.guardrails.guardrail_registry import GuardrailRegistry from litellm.proxy.guardrails.usage_endpoints import router as guardrails_usage_router +from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view +from litellm.repositories.table_repositories import GuardrailsRepository from litellm.types.guardrails import ( PII_ENTITY_CATEGORIES_MAP, ApplyGuardrailRequest, @@ -373,7 +373,7 @@ async def create_guardrail( # Configuration error — roll back the DB write so the guardrail isn't orphaned if prisma_client is not None: try: - await prisma_client.db.litellm_guardrailstable.delete( + await GuardrailsRepository(prisma_client).table.delete( where={"guardrail_id": guardrail_id} ) except Exception as rollback_err: @@ -705,7 +705,7 @@ async def register_guardrail( ) try: - existing = await prisma_client.db.litellm_guardrailstable.find_unique( + existing = await GuardrailsRepository(prisma_client).table.find_unique( where={"guardrail_name": request.guardrail_name} ) if existing is not None: @@ -732,7 +732,7 @@ async def register_guardrail( guardrail_info_str = safe_dumps(guardrail_info) try: - created = await prisma_client.db.litellm_guardrailstable.create( + created = await GuardrailsRepository(prisma_client).table.create( data={ "guardrail_name": request.guardrail_name, "litellm_params": litellm_params_str, @@ -874,7 +874,7 @@ async def list_guardrail_submissions( where_clause["team_id"] = 
{"in": visible_team_ids} # Single query: fetch team guardrails visible to the caller - all_team_rows = await prisma_client.db.litellm_guardrailstable.find_many( + all_team_rows = await GuardrailsRepository(prisma_client).table.find_many( where=where_clause, order={"created_at": "desc"}, ) @@ -945,7 +945,7 @@ async def get_guardrail_submission( is_admin = user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN try: - row = await prisma_client.db.litellm_guardrailstable.find_unique( + row = await GuardrailsRepository(prisma_client).table.find_unique( where={"guardrail_id": guardrail_id} ) if row is None: @@ -986,7 +986,7 @@ async def approve_guardrail_submission( raise HTTPException(status_code=500, detail="Prisma client not initialized") try: - row = await prisma_client.db.litellm_guardrailstable.find_unique( + row = await GuardrailsRepository(prisma_client).table.find_unique( where={"guardrail_id": guardrail_id} ) if row is None: @@ -1000,7 +1000,7 @@ async def approve_guardrail_submission( ) now = datetime.now(timezone.utc) - await prisma_client.db.litellm_guardrailstable.update( + await GuardrailsRepository(prisma_client).table.update( where={"guardrail_id": guardrail_id}, data={"status": "active", "reviewed_at": now, "updated_at": now}, ) @@ -1072,7 +1072,7 @@ async def reject_guardrail_submission( raise HTTPException(status_code=500, detail="Prisma client not initialized") try: - row = await prisma_client.db.litellm_guardrailstable.find_unique( + row = await GuardrailsRepository(prisma_client).table.find_unique( where={"guardrail_id": guardrail_id} ) if row is None: @@ -1086,7 +1086,7 @@ async def reject_guardrail_submission( ) now = datetime.now(timezone.utc) - await prisma_client.db.litellm_guardrailstable.update( + await GuardrailsRepository(prisma_client).table.update( where={"guardrail_id": guardrail_id}, data={"status": "rejected", "reviewed_at": now, "updated_at": now}, ) @@ -2288,10 +2288,10 @@ async def apply_guardrail( """ import traceback - from 
litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing from litellm.litellm_core_utils.thread_pool_executor import ( executor as thread_pool_executor, ) + from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing from litellm.proxy.proxy_server import ( general_settings, proxy_config, diff --git a/litellm/proxy/guardrails/guardrail_registry.py b/litellm/proxy/guardrails/guardrail_registry.py index aafcc5f181..a80bb81789 100644 --- a/litellm/proxy/guardrails/guardrail_registry.py +++ b/litellm/proxy/guardrails/guardrail_registry.py @@ -11,12 +11,15 @@ from litellm._logging import verbose_proxy_logger from litellm._uuid import uuid from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.litellm_core_utils.safe_json_dumps import safe_dumps -from litellm.proxy.guardrails.guardrail_hooks.grayswan import GraySwanGuardrail +from litellm.proxy.guardrails.guardrail_hooks.grayswan import ( + GraySwanGuardrail, +) from litellm.proxy.guardrails.guardrail_hooks.grayswan import ( initialize_guardrail as initialize_grayswan, ) from litellm.proxy.types_utils.utils import get_instance_fn from litellm.proxy.utils import PrismaClient +from litellm.repositories.table_repositories import GuardrailsRepository from litellm.secret_managers.main import get_secret from litellm.types.guardrails import ( Guardrail, @@ -26,6 +29,9 @@ from litellm.types.guardrails import ( SupportedGuardrailIntegrations, ) +from .guardrail_hooks.llm_as_a_judge import ( + initialize_guardrail as initialize_llm_as_a_judge, +) from .guardrail_initializers import ( initialize_bedrock, initialize_hide_secrets, @@ -34,9 +40,6 @@ from .guardrail_initializers import ( initialize_presidio, initialize_tool_permission, ) -from .guardrail_hooks.llm_as_a_judge import ( - initialize_guardrail as initialize_llm_as_a_judge, -) guardrail_initializer_registry = { SupportedGuardrailIntegrations.BEDROCK.value: initialize_bedrock, @@ -257,7 +260,7 @@ class 
GuardrailRegistry: guardrail_info: str = safe_dumps(guardrail.get("guardrail_info", {})) # Create guardrail in DB - created_guardrail = await prisma_client.db.litellm_guardrailstable.create( + created_guardrail = await GuardrailsRepository(prisma_client).table.create( data={ "guardrail_name": guardrail_name, "litellm_params": litellm_params, @@ -283,7 +286,7 @@ class GuardrailRegistry: """ try: # Delete from DB - await prisma_client.db.litellm_guardrailstable.delete( + await GuardrailsRepository(prisma_client).table.delete( where={"guardrail_id": guardrail_id} ) @@ -311,7 +314,7 @@ class GuardrailRegistry: guardrail_info: str = safe_dumps(guardrail.get("guardrail_info", {})) # Update in DB - updated_guardrail = await prisma_client.db.litellm_guardrailstable.update( + updated_guardrail = await GuardrailsRepository(prisma_client).table.update( where={"guardrail_id": guardrail_id}, data={ "guardrail_name": guardrail_name, @@ -335,11 +338,11 @@ class GuardrailRegistry: Only rows with status == "active" are returned (pending_review and rejected are excluded). 
""" try: - guardrails_from_db = ( - await prisma_client.db.litellm_guardrailstable.find_many( - where={"status": "active"}, - order={"created_at": "desc"}, - ) + guardrails_from_db = await GuardrailsRepository( + prisma_client + ).table.find_many( + where={"status": "active"}, + order={"created_at": "desc"}, ) guardrails: List[Guardrail] = [] @@ -357,7 +360,7 @@ class GuardrailRegistry: Get a guardrail by its ID from the database """ try: - guardrail = await prisma_client.db.litellm_guardrailstable.find_unique( + guardrail = await GuardrailsRepository(prisma_client).table.find_unique( where={"guardrail_id": guardrail_id} ) @@ -375,7 +378,7 @@ class GuardrailRegistry: Get a guardrail by its name from the database """ try: - guardrail = await prisma_client.db.litellm_guardrailstable.find_unique( + guardrail = await GuardrailsRepository(prisma_client).table.find_unique( where={"guardrail_name": guardrail_name} ) diff --git a/litellm/proxy/guardrails/usage_endpoints.py b/litellm/proxy/guardrails/usage_endpoints.py index 529949c6dd..d8457cf9c8 100644 --- a/litellm/proxy/guardrails/usage_endpoints.py +++ b/litellm/proxy/guardrails/usage_endpoints.py @@ -12,6 +12,14 @@ from pydantic import BaseModel from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.repositories.table_repositories import ( + DailyGuardrailMetricsRepository, + DailyPolicyMetricsRepository, + GuardrailsRepository, + PolicyRepository, + SpendLogGuardrailIndexRepository, + SpendLogsRepository, +) router = APIRouter() @@ -272,10 +280,10 @@ async def guardrails_usage_overview( try: # Guardrails from DB - guardrails = await prisma_client.db.litellm_guardrailstable.find_many() + guardrails = await GuardrailsRepository(prisma_client).table.find_many() # Daily metrics in range - metrics = await prisma_client.db.litellm_dailyguardrailmetrics.find_many( + metrics = await DailyGuardrailMetricsRepository(prisma_client).table.find_many( 
where={"date": {"gte": start, "lte": end}} ) @@ -283,9 +291,9 @@ async def guardrails_usage_overview( start_prev = ( datetime.strptime(start, "%Y-%m-%d") - timedelta(days=7) ).strftime("%Y-%m-%d") - metrics_prev = await prisma_client.db.litellm_dailyguardrailmetrics.find_many( - where={"date": {"gte": start_prev, "lt": start}} - ) + metrics_prev = await DailyGuardrailMetricsRepository( + prisma_client + ).table.find_many(where={"date": {"gte": start_prev, "lt": start}}) agg = _aggregate_daily_metrics(metrics, "guardrail_id") prev_agg = _prev_fail_rates(metrics_prev, "guardrail_id") @@ -335,7 +343,7 @@ async def guardrails_usage_detail( end = end_date or now.strftime("%Y-%m-%d") start = start_date or (now - timedelta(days=7)).strftime("%Y-%m-%d") - guardrail = await prisma_client.db.litellm_guardrailstable.find_unique( + guardrail = await GuardrailsRepository(prisma_client).table.find_unique( where={"guardrail_id": guardrail_id} ) if not guardrail: @@ -349,13 +357,13 @@ async def guardrails_usage_detail( ) metric_ids = [i for i in (logical_id, guardrail_id) if i] - metrics = await prisma_client.db.litellm_dailyguardrailmetrics.find_many( + metrics = await DailyGuardrailMetricsRepository(prisma_client).table.find_many( where={ "guardrail_id": {"in": metric_ids}, "date": {"gte": start, "lte": end}, } ) - metrics_prev = await prisma_client.db.litellm_dailyguardrailmetrics.find_many( + metrics_prev = await DailyGuardrailMetricsRepository(prisma_client).table.find_many( where={ "guardrail_id": {"in": metric_ids}, "date": {"lt": start}, @@ -574,7 +582,7 @@ async def guardrails_usage_logs( # Query by both so we match regardless of which was written. 
effective_guardrail_ids: List[str] = [guardrail_id] if guardrail_id else [] if guardrail_id: - guardrail = await prisma_client.db.litellm_guardrailstable.find_unique( + guardrail = await GuardrailsRepository(prisma_client).table.find_unique( where={"guardrail_id": guardrail_id} ) if guardrail: @@ -585,19 +593,23 @@ async def guardrails_usage_logs( where = _build_usage_logs_where( effective_guardrail_ids or None, policy_id, start_date, end_date ) - index_rows = await prisma_client.db.litellm_spendlogguardrailindex.find_many( + index_rows = await SpendLogGuardrailIndexRepository( + prisma_client + ).table.find_many( where=where, order={"start_time": "desc"}, skip=(page - 1) * page_size, take=page_size + 1, ) - total = await prisma_client.db.litellm_spendlogguardrailindex.count(where=where) + total = await SpendLogGuardrailIndexRepository(prisma_client).table.count( + where=where + ) request_ids = [r.request_id for r in index_rows[:page_size]] if not request_ids: return UsageLogsResponse( logs=[], total=total, page=page, page_size=page_size ) - spend_logs = await prisma_client.db.litellm_spendlogs.find_many( + spend_logs = await SpendLogsRepository(prisma_client).table.find_many( where={"request_id": {"in": request_ids}} ) log_by_id = {s.request_id: s for s in spend_logs} @@ -645,11 +657,13 @@ async def policies_usage_overview( start = start_date or (now - timedelta(days=7)).strftime("%Y-%m-%d") try: - policies = await prisma_client.db.litellm_policytable.find_many() - metrics = await prisma_client.db.litellm_dailypolicymetrics.find_many( + policies = await PolicyRepository(prisma_client).table.find_many() + metrics = await DailyPolicyMetricsRepository(prisma_client).table.find_many( where={"date": {"gte": start, "lte": end}} ) - metrics_prev = await prisma_client.db.litellm_dailypolicymetrics.find_many( + metrics_prev = await DailyPolicyMetricsRepository( + prisma_client + ).table.find_many( where={ "date": { "gte": ( diff --git 
a/litellm/proxy/guardrails/usage_tracking.py b/litellm/proxy/guardrails/usage_tracking.py index 8907c9201a..c55c47ca77 100644 --- a/litellm/proxy/guardrails/usage_tracking.py +++ b/litellm/proxy/guardrails/usage_tracking.py @@ -10,6 +10,10 @@ from typing import Any, Dict, List, Optional from litellm._logging import verbose_proxy_logger from litellm.proxy.utils import PrismaClient +from litellm.repositories.table_repositories import ( + DailyGuardrailMetricsRepository, + SpendLogGuardrailIndexRepository, +) def _guardrail_status_to_action(status: Optional[str]) -> str: @@ -132,7 +136,7 @@ async def process_spend_logs_guardrail_usage( } ) try: - await prisma_client.db.litellm_spendlogguardrailindex.create_many( + await SpendLogGuardrailIndexRepository(prisma_client).table.create_many( data=index_data, skip_duplicates=True, ) @@ -146,7 +150,7 @@ async def process_spend_logs_guardrail_usage( n = int(agg["requests_evaluated"]) if n == 0: continue - await prisma_client.db.litellm_dailyguardrailmetrics.upsert( + await DailyGuardrailMetricsRepository(prisma_client).table.upsert( where={ "guardrail_id_date": { "guardrail_id": guardrail_id, diff --git a/litellm/proxy/hooks/user_management_event_hooks.py b/litellm/proxy/hooks/user_management_event_hooks.py index 08fa8d4dfa..c22fd1d657 100644 --- a/litellm/proxy/hooks/user_management_event_hooks.py +++ b/litellm/proxy/hooks/user_management_event_hooks.py @@ -3,7 +3,6 @@ Hooks that are triggered when a litellm user event occurs """ import asyncio -from litellm._uuid import uuid from datetime import datetime, timezone from typing import Optional @@ -11,6 +10,7 @@ from pydantic import BaseModel import litellm from litellm._logging import verbose_proxy_logger +from litellm._uuid import uuid from litellm.proxy._types import ( AUDIT_ACTIONS, CommonProxyErrors, @@ -24,6 +24,7 @@ from litellm.proxy._types import ( WebhookEvent, ) from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update +from 
litellm.repositories.user_repository import UserRepository class UserManagementEventHooks: @@ -57,7 +58,7 @@ class UserManagementEventHooks: try: if prisma_client is None: raise Exception(CommonProxyErrors.db_not_connected_error.value) - user_row: BaseModel = await prisma_client.db.litellm_usertable.find_first( + user_row: BaseModel = await UserRepository(prisma_client).table.find_first( where={"user_id": response.user_id} ) diff --git a/litellm/proxy/management_endpoints/access_group_endpoints.py b/litellm/proxy/management_endpoints/access_group_endpoints.py index 62a770f46a..65f7ffc908 100644 --- a/litellm/proxy/management_endpoints/access_group_endpoints.py +++ b/litellm/proxy/management_endpoints/access_group_endpoints.py @@ -19,6 +19,7 @@ from litellm.proxy.auth.auth_checks import ( from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler from litellm.proxy.utils import get_prisma_client_or_throw +from litellm.repositories.table_repositories import AccessGroupRepository from litellm.types.access_group import ( AccessGroupCreateRequest, AccessGroupResponse, @@ -386,7 +387,7 @@ async def list_access_groups( CommonProxyErrors.db_not_connected_error.value ) - records = await prisma_client.db.litellm_accessgrouptable.find_many( + records = await AccessGroupRepository(prisma_client).table.find_many( order={"created_at": "desc"} ) return [_record_to_response(r) for r in records] @@ -405,7 +406,7 @@ async def get_access_group( CommonProxyErrors.db_not_connected_error.value ) - record = await prisma_client.db.litellm_accessgrouptable.find_unique( + record = await AccessGroupRepository(prisma_client).table.find_unique( where={"access_group_id": access_group_id} ) if record is None: diff --git a/litellm/proxy/management_endpoints/budget_management_endpoints.py b/litellm/proxy/management_endpoints/budget_management_endpoints.py index 2eda1b30c5..698155a5c2 100644 --- 
a/litellm/proxy/management_endpoints/budget_management_endpoints.py +++ b/litellm/proxy/management_endpoints/budget_management_endpoints.py @@ -16,11 +16,12 @@ import math from fastapi import APIRouter, Depends, HTTPException -from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view from litellm.proxy.utils import jsonify_object +from litellm.repositories.budget_repository import BudgetRepository router = APIRouter() @@ -98,7 +99,7 @@ async def new_budget( budget_obj_json = budget_obj.model_dump(exclude_none=True) budget_obj_jsonified = jsonify_object(budget_obj_json) # json dump any dictionaries try: - response = await prisma_client.db.litellm_budgettable.create( + response = await BudgetRepository(prisma_client).table.create( data={ **budget_obj_jsonified, # type: ignore "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, @@ -182,7 +183,7 @@ async def update_budget( except ValueError as e: raise HTTPException(status_code=400, detail={"error": str(e)}) - response = await prisma_client.db.litellm_budgettable.update( + response = await BudgetRepository(prisma_client).table.update( where={"budget_id": budget_obj.budget_id}, data={ **budget_obj.model_dump(exclude_unset=True), # type: ignore @@ -217,7 +218,7 @@ async def info_budget(data: BudgetRequest): "error": f"Specify list of budget id's to query. 
Passed in={data.budgets}" }, ) - response = await prisma_client.db.litellm_budgettable.find_many( + response = await BudgetRepository(prisma_client).table.find_many( where={"budget_id": {"in": data.budgets}}, ) @@ -261,7 +262,7 @@ async def budget_settings( ) ## get budget item from db - db_budget_row = await prisma_client.db.litellm_budgettable.find_first( + db_budget_row = await BudgetRepository(prisma_client).table.find_first( where={"budget_id": budget_id} ) @@ -327,7 +328,7 @@ async def list_budget( }, ) - response = await prisma_client.db.litellm_budgettable.find_many() + response = await BudgetRepository(prisma_client).table.find_many() return response @@ -366,7 +367,7 @@ async def delete_budget( }, ) - response = await prisma_client.db.litellm_budgettable.delete( + response = await BudgetRepository(prisma_client).table.delete( where={"budget_id": data.id} ) diff --git a/litellm/proxy/management_endpoints/cache_settings_endpoints.py b/litellm/proxy/management_endpoints/cache_settings_endpoints.py index 0a26b23bef..d8eb5dfee9 100644 --- a/litellm/proxy/management_endpoints/cache_settings_endpoints.py +++ b/litellm/proxy/management_endpoints/cache_settings_endpoints.py @@ -27,6 +27,7 @@ from litellm.proxy._types import ( UserAPIKeyAuth, ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.repositories.table_repositories import CacheConfigRepository from litellm.types.management_endpoints import ( CACHE_SETTINGS_FIELDS, REDIS_TYPE_DESCRIPTIONS, @@ -159,7 +160,7 @@ class CacheSettingsManager: import json try: - cache_config = await prisma_client.db.litellm_cacheconfig.find_unique( + cache_config = await CacheConfigRepository(prisma_client).table.find_unique( where={"id": "cache_config"} ) if cache_config is not None and cache_config.cache_settings: @@ -274,7 +275,7 @@ async def get_cache_settings( # Try to get cache settings from database current_values = {} if prisma_client is not None: - cache_config = await 
prisma_client.db.litellm_cacheconfig.find_unique( + cache_config = await CacheConfigRepository(prisma_client).table.find_unique( where={"id": "cache_config"} ) if cache_config is not None and cache_config.cache_settings: @@ -417,7 +418,7 @@ async def update_cache_settings( # Snapshot the prior settings (key set only — values get redacted in # the audit row) so the audit-log entry shows which fields changed. - existing_row = await prisma_client.db.litellm_cacheconfig.find_unique( + existing_row = await CacheConfigRepository(prisma_client).table.find_unique( where={"id": "cache_config"} ) before_settings: Optional[Dict[str, Any]] = None @@ -434,7 +435,7 @@ async def update_cache_settings( ) # Save to database - await prisma_client.db.litellm_cacheconfig.upsert( + await CacheConfigRepository(prisma_client).table.upsert( where={"id": "cache_config"}, data={ "create": { diff --git a/litellm/proxy/management_endpoints/common_daily_activity.py b/litellm/proxy/management_endpoints/common_daily_activity.py index d173cd745b..92cc2008c7 100644 --- a/litellm/proxy/management_endpoints/common_daily_activity.py +++ b/litellm/proxy/management_endpoints/common_daily_activity.py @@ -8,6 +8,10 @@ from fastapi import HTTPException, status from litellm._logging import verbose_proxy_logger from litellm.proxy._types import CommonProxyErrors from litellm.proxy.utils import PrismaClient +from litellm.repositories.table_repositories import DeletedVerificationTokenRepository +from litellm.repositories.verification_token_repository import ( + VerificationTokenRepository, +) from litellm.types.proxy.management_endpoints.common_daily_activity import ( BreakdownMetrics, DailySpendData, @@ -346,7 +350,7 @@ async def get_api_key_metadata( This ensures that key_alias and team_id are preserved in historical activity logs even after a key is deleted or regenerated. 
""" - key_records = await prisma_client.db.litellm_verificationtoken.find_many( + key_records = await VerificationTokenRepository(prisma_client).table.find_many( where={"token": {"in": list(api_keys)}} ) result = { @@ -357,11 +361,11 @@ async def get_api_key_metadata( missing_keys = api_keys - set(result.keys()) if missing_keys: try: - deleted_key_records = ( - await prisma_client.db.litellm_deletedverificationtoken.find_many( - where={"token": {"in": list(missing_keys)}}, - order={"deleted_at": "desc"}, - ) + deleted_key_records = await DeletedVerificationTokenRepository( + prisma_client + ).table.find_many( + where={"token": {"in": list(missing_keys)}}, + order={"deleted_at": "desc"}, ) # Use the most recent deleted record for each token (ordered by deleted_at desc) for k in deleted_key_records: diff --git a/litellm/proxy/management_endpoints/common_utils.py b/litellm/proxy/management_endpoints/common_utils.py index 31d831d773..458cba686e 100644 --- a/litellm/proxy/management_endpoints/common_utils.py +++ b/litellm/proxy/management_endpoints/common_utils.py @@ -17,10 +17,13 @@ from litellm.proxy._types import ( NewProjectRequest, UpdateProjectRequest, UserAPIKeyAuth, - user_api_key_has_admin_view as _user_has_admin_view, # noqa: F401 re-exported +) +from litellm.proxy._types import ( # noqa: F401 re-exported + user_api_key_has_admin_view as _user_has_admin_view, ) from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time from litellm.proxy.utils import _premium_user_check +from litellm.repositories.team_repository import TeamRepository if TYPE_CHECKING: from litellm.proxy._types import NewProjectRequest, UpdateProjectRequest @@ -205,7 +208,7 @@ async def _user_has_admin_privileges( # Check if user is team admin for any team if user_obj.teams is not None and len(user_obj.teams) > 0: # Get all teams user is in - teams = await prisma_client.db.litellm_teamtable.find_many( + teams = await TeamRepository(prisma_client).table.find_many( 
where={"team_id": {"in": user_obj.teams}} ) @@ -282,7 +285,7 @@ async def _team_admin_can_invite_user( if not target_user_obj.teams or len(target_user_obj.teams) == 0: return False - teams = await prisma_client.db.litellm_teamtable.find_many( + teams = await TeamRepository(prisma_client).table.find_many( where={"team_id": {"in": admin_user_obj.teams}} ) admin_team_ids = [ diff --git a/litellm/proxy/management_endpoints/config_override_endpoints.py b/litellm/proxy/management_endpoints/config_override_endpoints.py index 7f7aa485fb..97cb5eeddc 100644 --- a/litellm/proxy/management_endpoints/config_override_endpoints.py +++ b/litellm/proxy/management_endpoints/config_override_endpoints.py @@ -30,6 +30,7 @@ from litellm.proxy._types import ( UserAPIKeyAuth, ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.repositories.table_repositories import ConfigOverridesRepository from litellm.types.llms.custom_http import httpxSpecialProvider from litellm.types.proxy.management_endpoints.config_overrides import ( ConfigOverrideSettingsResponse, @@ -254,7 +255,7 @@ async def update_hashicorp_vault_config( # Merge ALL fields the user didn't send: try DB first, fall back to env vars. # Omitted field = keep existing; empty string = clear/remove the field. 
- existing_record = await prisma_client.db.litellm_configoverrides.find_unique( + existing_record = await ConfigOverridesRepository(prisma_client).table.find_unique( where={"config_type": "hashicorp_vault"} ) existing_decrypted: Optional[Dict[str, Any]] = None @@ -321,7 +322,7 @@ async def update_hashicorp_vault_config( # Only persist to DB after successful init encrypted_data = proxy_config._encrypt_env_variables(config_data) config_value = safe_dumps(encrypted_data) - await prisma_client.db.litellm_configoverrides.upsert( + await ConfigOverridesRepository(prisma_client).table.upsert( where={"config_type": "hashicorp_vault"}, data={ "create": { @@ -391,7 +392,7 @@ async def get_hashicorp_vault_config( field_schema = _build_field_schema(HashicorpVaultConfig) # Try to load from DB - db_record = await prisma_client.db.litellm_configoverrides.find_unique( + db_record = await ConfigOverridesRepository(prisma_client).table.find_unique( where={"config_type": "hashicorp_vault"} ) @@ -448,7 +449,7 @@ async def delete_hashicorp_vault_config( # Capture the prior config before delete so the audit-log row can # show *what* was removed (keys only — values get redacted). 
- existing_record = await prisma_client.db.litellm_configoverrides.find_unique( + existing_record = await ConfigOverridesRepository(prisma_client).table.find_unique( where={"config_type": "hashicorp_vault"} ) before_config: Optional[Dict[str, Any]] = None @@ -463,7 +464,7 @@ async def delete_hashicorp_vault_config( # Delete DB record if it exists — ignore if not found deleted = False try: - await prisma_client.db.litellm_configoverrides.delete( + await ConfigOverridesRepository(prisma_client).table.delete( where={"config_type": "hashicorp_vault"} ) deleted = True diff --git a/litellm/proxy/management_endpoints/customer_endpoints.py b/litellm/proxy/management_endpoints/customer_endpoints.py index 1fd8320db2..f1a34bb0ed 100644 --- a/litellm/proxy/management_endpoints/customer_endpoints.py +++ b/litellm/proxy/management_endpoints/customer_endpoints.py @@ -17,8 +17,8 @@ import fastapi from fastapi import APIRouter, Depends, HTTPException, Request import litellm -from litellm.litellm_core_utils.duration_parser import duration_in_seconds from litellm._logging import verbose_proxy_logger +from litellm.litellm_core_utils.duration_parser import duration_in_seconds from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.management_endpoints.common_daily_activity import get_daily_activity @@ -27,6 +27,8 @@ from litellm.proxy.management_helpers.object_permission_utils import ( handle_update_object_permission_common, ) from litellm.proxy.utils import handle_exception_on_proxy +from litellm.repositories.budget_repository import BudgetRepository +from litellm.repositories.table_repositories import EndUserRepository from litellm.types.proxy.management_endpoints.common_daily_activity import ( SpendAnalyticsPaginatedResponse, ) @@ -68,7 +70,7 @@ async def block_user(data: BlockUsers): records = [] if prisma_client is not None: for id in data.user_ids: - record = await prisma_client.db.litellm_endusertable.upsert( + 
record = await EndUserRepository(prisma_client).table.upsert( where={"user_id": id}, # type: ignore data={ "create": {"user_id": id, "blocked": True}, # type: ignore @@ -337,7 +339,7 @@ async def new_end_user( _new_budget = new_budget_request(data) if _new_budget is not None: try: - budget_record = await prisma_client.db.litellm_budgettable.create( + budget_record = await BudgetRepository(prisma_client).table.create( data={ **_new_budget.model_dump(exclude_unset=True), "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, # type: ignore @@ -373,7 +375,7 @@ async def new_end_user( new_end_user_obj.pop("object_permission", None) ## WRITE TO DB ## - end_user_record = await prisma_client.db.litellm_endusertable.create( + end_user_record = await EndUserRepository(prisma_client).table.create( data=new_end_user_obj, # type: ignore include={"litellm_budget_table": True, "object_permission": True}, ) @@ -446,7 +448,7 @@ async def end_user_info( detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - user_info = await prisma_client.db.litellm_endusertable.find_first( + user_info = await EndUserRepository(prisma_client).table.find_first( where={"user_id": end_user_id}, include={"litellm_budget_table": True, "object_permission": True}, ) @@ -569,7 +571,7 @@ async def update_end_user( non_default_values[k] = v ## Get end user table data ## - end_user_table_data = await prisma_client.db.litellm_endusertable.find_first( + end_user_table_data = await EndUserRepository(prisma_client).table.find_first( where={"user_id": data.user_id}, include={"litellm_budget_table": True} ) @@ -613,17 +615,17 @@ async def update_end_user( if budget_table_data: if end_user_budget_table is None: ## Create new budget ## - budget_table_data_record = ( - await prisma_client.db.litellm_budgettable.create( - data={ - **budget_table_data, - "created_by": user_api_key_dict.user_id - or litellm_proxy_admin_name, - "updated_by": user_api_key_dict.user_id - or 
litellm_proxy_admin_name, - }, - include={"end_users": True}, - ) + budget_table_data_record = await BudgetRepository( + prisma_client + ).table.create( + data={ + **budget_table_data, + "created_by": user_api_key_dict.user_id + or litellm_proxy_admin_name, + "updated_by": user_api_key_dict.user_id + or litellm_proxy_admin_name, + }, + include={"end_users": True}, ) update_end_user_table_data["budget_id"] = ( @@ -631,11 +633,11 @@ async def update_end_user( ) else: ## Update existing budget ## - budget_table_data_record = ( - await prisma_client.db.litellm_budgettable.update( - where={"budget_id": end_user_budget_table.budget_id}, - data=budget_table_data, - ) + budget_table_data_record = await BudgetRepository( + prisma_client + ).table.update( + where={"budget_id": end_user_budget_table.budget_id}, + data=budget_table_data, ) ## Update user table, with update params + new budget id (if set) ## @@ -652,7 +654,7 @@ async def update_end_user( if data.user_id is not None and len(data.user_id) > 0: update_end_user_table_data["user_id"] = data.user_id # type: ignore verbose_proxy_logger.debug("In update customer, user_id condition block.") - response = await prisma_client.db.litellm_endusertable.update( + response = await EndUserRepository(prisma_client).table.update( where={"user_id": data.user_id}, data=update_end_user_table_data, include={"litellm_budget_table": True, "object_permission": True} # type: ignore ) if response is None: @@ -737,7 +739,7 @@ async def delete_end_user( and len(data.user_ids) > 0 ): # First check if all users exist - existing_users = await prisma_client.db.litellm_endusertable.find_many( + existing_users = await EndUserRepository(prisma_client).table.find_many( where={"user_id": {"in": data.user_ids}} ) existing_user_ids = {user.user_id for user in existing_users} @@ -756,7 +758,7 @@ async def delete_end_user( ) # All users exist, proceed with deletion - response = await prisma_client.db.litellm_endusertable.delete_many( + response = await 
EndUserRepository(prisma_client).table.delete_many( where={"user_id": {"in": data.user_ids}} ) verbose_proxy_logger.debug( @@ -828,7 +830,7 @@ async def list_end_user( detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - response = await prisma_client.db.litellm_endusertable.find_many( + response = await EndUserRepository(prisma_client).table.find_many( include={"litellm_budget_table": True, "object_permission": True} ) @@ -903,7 +905,7 @@ async def get_customer_daily_activity( where_condition = {} if end_user_ids_list: where_condition["user_id"] = {"in": list(end_user_ids_list)} - end_user_aliases = await prisma_client.db.litellm_endusertable.find_many( + end_user_aliases = await EndUserRepository(prisma_client).table.find_many( where=where_condition ) end_user_alias_metadata = {e.user_id: {"alias": e.alias} for e in end_user_aliases} diff --git a/litellm/proxy/management_endpoints/fallback_management_endpoints.py b/litellm/proxy/management_endpoints/fallback_management_endpoints.py index ffb12111d8..1333122c87 100644 --- a/litellm/proxy/management_endpoints/fallback_management_endpoints.py +++ b/litellm/proxy/management_endpoints/fallback_management_endpoints.py @@ -27,6 +27,7 @@ else: # fastapi is only required for proxy, not for SDK usage pass +from litellm.repositories.config_repository import ConfigRepository from litellm.types.management_endpoints.router_settings_endpoints import ( FallbackCreateRequest, FallbackDeleteResponse, @@ -157,7 +158,7 @@ async def create_fallback( # Save to database - convert router_settings to JSON string router_settings_json = json.dumps(router_settings) - await prisma_client.db.litellm_config.upsert( + await ConfigRepository(prisma_client).table.upsert( where={"param_name": "router_settings"}, data={ "create": { @@ -336,7 +337,7 @@ async def delete_fallback( # Save to database - convert router_settings to JSON string router_settings_json = json.dumps(router_settings) - await 
prisma_client.db.litellm_config.upsert( + await ConfigRepository(prisma_client).table.upsert( where={"param_name": "router_settings"}, data={ "create": { diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index 7b8f0f72e1..b3a5c66e9e 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -43,6 +43,17 @@ from litellm.proxy.management_endpoints.key_management_endpoints import ( ) from litellm.proxy.management_helpers.utils import management_endpoint_wrapper from litellm.proxy.utils import handle_exception_on_proxy, hash_password +from litellm.repositories.organization_repository import OrganizationRepository +from litellm.repositories.table_repositories import ( + InvitationLinkRepository, + OrganizationMembershipRepository, + TeamMembershipRepository, +) +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.user_repository import UserRepository +from litellm.repositories.verification_token_repository import ( + VerificationTokenRepository, +) from litellm.types.proxy.management_endpoints.common_daily_activity import ( SpendAnalyticsPaginatedResponse, ) @@ -154,7 +165,7 @@ async def _check_duplicate_user_field( if case_insensitive: where_clause[field_name]["mode"] = "insensitive" - existing_user = await prisma_client.db.litellm_usertable.find_first( + existing_user = await UserRepository(prisma_client).table.find_first( where=where_clause ) @@ -434,7 +445,7 @@ async def new_user( await _check_duplicate_user_email(data.user_email, prisma_client) # Check if license is over limit - total_users = await prisma_client.db.litellm_usertable.count() + total_users = await UserRepository(prisma_client).table.count() if total_users and _license_check.is_over_limit(total_users=total_users): raise HTTPException( status_code=403, @@ -851,7 +862,7 @@ async def 
_check_user_info_v2_access( # Helper: fetch the target user row (reused across branches) async def _fetch_target_user(): - return await prisma_client.db.litellm_usertable.find_unique( + return await UserRepository(prisma_client).table.find_unique( where={"user_id": target_user_id} ) @@ -866,7 +877,7 @@ async def _check_user_info_v2_access( # Rule 3: Team admins can look up users in their teams if user_api_key_dict.user_id is not None: # Get caller's teams - caller_user = await prisma_client.db.litellm_usertable.find_unique( + caller_user = await UserRepository(prisma_client).table.find_unique( where={"user_id": user_api_key_dict.user_id} ) if caller_user is not None and caller_user.teams: @@ -876,7 +887,7 @@ async def _check_user_info_v2_access( return None # Get all teams the caller belongs to - teams = await prisma_client.db.litellm_teamtable.find_many( + teams = await TeamRepository(prisma_client).table.find_many( where={"team_id": {"in": caller_user.teams}} ) for team in teams: @@ -1165,7 +1176,7 @@ async def _schedule_user_update_audit_log( if prisma_client is None: return try: - updated_user_row = await prisma_client.db.litellm_usertable.find_first( + updated_user_row = await UserRepository(prisma_client).table.find_first( where={"user_id": response["user_id"]} ) if updated_user_row: @@ -1255,11 +1266,11 @@ async def _update_single_user_helper( existing_user_row: Optional[BaseModel] = None if user_request.user_id: - existing_user_row = await prisma_client.db.litellm_usertable.find_first( + existing_user_row = await UserRepository(prisma_client).table.find_first( where={"user_id": user_request.user_id} ) elif user_request.user_email: - existing_user_row = await prisma_client.db.litellm_usertable.find_first( + existing_user_row = await UserRepository(prisma_client).table.find_first( where={"user_email": user_request.user_email} ) @@ -1640,7 +1651,7 @@ async def bulk_user_update( detail="Only proxy admins can update all users at once.", ) # Optimized path for 
updating all users directly in database - all_users_in_db = await prisma_client.db.litellm_usertable.find_many( + all_users_in_db = await UserRepository(prisma_client).table.find_many( order={"created_at": "desc"} ) @@ -1676,7 +1687,7 @@ async def bulk_user_update( try: # Perform bulk database update - await prisma_client.db.litellm_usertable.update_many( + await UserRepository(prisma_client).table.update_many( where={}, data=non_default_values # Update all users ) @@ -1783,7 +1794,7 @@ async def get_user_key_counts( # Get count for each user_id individually for user_id in user_ids: - count = await prisma_client.db.litellm_verificationtoken.count( + count = await VerificationTokenRepository(prisma_client).table.count( where={ "user_id": user_id, "OR": [ @@ -2056,7 +2067,7 @@ async def get_users( else None ) - users = await prisma_client.db.litellm_usertable.find_many( + users = await UserRepository(prisma_client).table.find_many( where=where_conditions, skip=skip, take=page_size, @@ -2066,7 +2077,9 @@ async def get_users( ) # Get total count of user rows - total_count = await prisma_client.db.litellm_usertable.count(where=where_conditions) + total_count = await UserRepository(prisma_client).table.count( + where=where_conditions + ) # Get key count for each user if users is not None: @@ -2137,14 +2150,14 @@ async def delete_user( from litellm.proxy.management_endpoints.team_endpoints import ( _cleanup_members_with_roles, ) + from litellm.proxy.management_helpers.audit_logs import ( + get_audit_log_changed_by, + ) from litellm.proxy.proxy_server import ( create_audit_log_for_update, litellm_proxy_admin_name, prisma_client, ) - from litellm.proxy.management_helpers.audit_logs import ( - get_audit_log_changed_by, - ) if prisma_client is None: raise HTTPException(status_code=500, detail={"error": "No db connected"}) @@ -2164,7 +2177,7 @@ async def delete_user( caller_admin_org_ids: set = set() if not caller_is_proxy_admin: caller_memberships = ( - await 
prisma_client.db.litellm_organizationmembership.find_many( + await OrganizationMembershipRepository(prisma_client).table.find_many( where={ "user_id": user_api_key_dict.user_id, "user_role": LitellmUserRoles.ORG_ADMIN.value, @@ -2188,11 +2201,9 @@ async def delete_user( # an N+1 DB call when delete_user is called with a large user_ids list. target_org_ids_by_user: Dict[str, set] = {} if not caller_is_proxy_admin: - all_target_memberships = ( - await prisma_client.db.litellm_organizationmembership.find_many( - where={"user_id": {"in": data.user_ids}} - ) - ) + all_target_memberships = await OrganizationMembershipRepository( + prisma_client + ).table.find_many(where={"user_id": {"in": data.user_ids}}) for m in all_target_memberships: if not m.organization_id: continue @@ -2200,7 +2211,7 @@ async def delete_user( # check that all teams passed exist for user_id in data.user_ids: - user_row = await prisma_client.db.litellm_usertable.find_unique( + user_row = await UserRepository(prisma_client).table.find_unique( where={"user_id": user_id} ) @@ -2254,7 +2265,7 @@ async def delete_user( ) ## CLEANUP MEMBERS_WITH_ROLES - fetch_all_teams = await prisma_client.db.litellm_teamtable.find_many( + fetch_all_teams = await TeamRepository(prisma_client).table.find_many( where={"team_id": {"in": user_row.teams}} ) teams_to_update = [] @@ -2277,19 +2288,19 @@ async def delete_user( ## update teams for team in teams_to_update: - await prisma_client.db.litellm_teamtable.update( + await TeamRepository(prisma_client).table.update( where={"team_id": team.team_id}, data={"members_with_roles": team.members_with_roles}, ) # End of Audit logging ## DELETE ASSOCIATED KEYS - await prisma_client.db.litellm_verificationtoken.delete_many( + await VerificationTokenRepository(prisma_client).table.delete_many( where={"user_id": {"in": data.user_ids}} ) ## DELETE ASSOCIATED INVITATION LINKS - await prisma_client.db.litellm_invitationlink.delete_many( + await 
InvitationLinkRepository(prisma_client).table.delete_many( where={ "OR": [ {"user_id": {"in": data.user_ids}}, @@ -2300,17 +2311,17 @@ async def delete_user( ) ## DELETE ASSOCIATED ORGANIZATION MEMBERSHIPS - await prisma_client.db.litellm_organizationmembership.delete_many( + await OrganizationMembershipRepository(prisma_client).table.delete_many( where={"user_id": {"in": data.user_ids}} ) ## DELETE ASSOCIATED TEAM MEMBERSHIPS - await prisma_client.db.litellm_teammembership.delete_many( + await TeamMembershipRepository(prisma_client).table.delete_many( where={"user_id": {"in": data.user_ids}} ) ## DELETE USERS - deleted_users = await prisma_client.db.litellm_usertable.delete_many( + deleted_users = await UserRepository(prisma_client).table.delete_many( where={"user_id": {"in": data.user_ids}} ) @@ -2340,16 +2351,18 @@ async def add_internal_user_to_organization( try: # Check if organization_id exists - organization_row = await prisma_client.db.litellm_organizationtable.find_unique( - where={"organization_id": organization_id} - ) + organization_row = await OrganizationRepository( + prisma_client + ).table.find_unique(where={"organization_id": organization_id}) if organization_row is None: raise Exception( f"Organization not found, passed organization_id={organization_id}" ) # Create a new organization membership entry - new_membership = await prisma_client.db.litellm_organizationmembership.create( + new_membership = await OrganizationMembershipRepository( + prisma_client + ).table.create( data={ "user_id": user_id, "organization_id": organization_id, @@ -2559,13 +2572,13 @@ async def ui_view_users( } # Query users with pagination and filters - users: Optional[List[BaseModel]] = ( - await prisma_client.db.litellm_usertable.find_many( - where=where_conditions, - skip=skip, - take=page_size, - order={"created_at": "desc"}, - ) + users: Optional[List[BaseModel]] = await UserRepository( + prisma_client + ).table.find_many( + where=where_conditions, + skip=skip, + 
take=page_size, + order={"created_at": "desc"}, ) if not users: diff --git a/litellm/proxy/management_endpoints/jwt_key_mapping_endpoints.py b/litellm/proxy/management_endpoints/jwt_key_mapping_endpoints.py index 1ee5bfb022..a5a364c367 100644 --- a/litellm/proxy/management_endpoints/jwt_key_mapping_endpoints.py +++ b/litellm/proxy/management_endpoints/jwt_key_mapping_endpoints.py @@ -11,6 +11,7 @@ from litellm.proxy._types import ( ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view +from litellm.repositories.table_repositories import JWTKeyMappingRepository router = APIRouter() @@ -61,7 +62,7 @@ async def create_jwt_key_mapping( if data.description is not None: create_data["description"] = data.description - new_mapping = await prisma_client.db.litellm_jwtkeymapping.create( + new_mapping = await JWTKeyMappingRepository(prisma_client).table.create( data=create_data ) @@ -113,7 +114,7 @@ async def update_jwt_key_mapping( try: # Get old mapping for cache invalidation - old_mapping = await prisma_client.db.litellm_jwtkeymapping.find_unique( + old_mapping = await JWTKeyMappingRepository(prisma_client).table.find_unique( where={"id": data.id} ) @@ -123,7 +124,7 @@ async def update_jwt_key_mapping( cache_key = f"jwt_key_mapping:{old_mapping.jwt_claim_name}:{old_mapping.jwt_claim_value}" await user_api_key_cache.async_delete_cache(cache_key) - updated_mapping = await prisma_client.db.litellm_jwtkeymapping.update( + updated_mapping = await JWTKeyMappingRepository(prisma_client).table.update( where={"id": data.id}, data=update_data ) @@ -166,7 +167,7 @@ async def delete_jwt_key_mapping( try: # Get old mapping for cache invalidation - old_mapping = await prisma_client.db.litellm_jwtkeymapping.find_unique( + old_mapping = await JWTKeyMappingRepository(prisma_client).table.find_unique( where={"id": data.id} ) @@ -176,7 +177,7 @@ async def delete_jwt_key_mapping( cache_key = 
f"jwt_key_mapping:{old_mapping.jwt_claim_name}:{old_mapping.jwt_claim_value}" await user_api_key_cache.async_delete_cache(cache_key) - await prisma_client.db.litellm_jwtkeymapping.delete(where={"id": data.id}) + await JWTKeyMappingRepository(prisma_client).table.delete(where={"id": data.id}) return {"status": "success"} except HTTPException: raise @@ -206,12 +207,12 @@ async def list_jwt_key_mappings( try: skip = (page - 1) * size - mappings = await prisma_client.db.litellm_jwtkeymapping.find_many( + mappings = await JWTKeyMappingRepository(prisma_client).table.find_many( skip=skip, take=size, order={"created_at": "desc"}, ) - total_count = await prisma_client.db.litellm_jwtkeymapping.count() + total_count = await JWTKeyMappingRepository(prisma_client).table.count() return { "mappings": [_to_response(m) for m in mappings], "total_count": total_count, @@ -245,7 +246,7 @@ async def info_jwt_key_mapping( raise HTTPException(status_code=500, detail="Database not connected") try: - mapping = await prisma_client.db.litellm_jwtkeymapping.find_unique( + mapping = await JWTKeyMappingRepository(prisma_client).table.find_unique( where={"id": id} ) if mapping is None: diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 99d0bac88a..8f606fdf90 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -28,7 +28,6 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, s import litellm from litellm._logging import verbose_proxy_logger from litellm._uuid import uuid -from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache from litellm.constants import ( LENGTH_OF_LITELLM_GENERATED_KEY, LITELLM_PROXY_ADMIN_NAME, @@ -58,6 +57,7 @@ from litellm.proxy.common_utils.callback_utils import ( ) from litellm.proxy.common_utils.rbac_utils import 
check_org_admin_can_generate_keys from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time +from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks from litellm.proxy.management_endpoints.common_utils import ( _check_passthrough_routes_caller_permission, @@ -91,6 +91,19 @@ from litellm.proxy.utils import ( handle_exception_on_proxy, is_valid_api_key, ) +from litellm.repositories.budget_repository import BudgetRepository +from litellm.repositories.config_repository import ConfigRepository +from litellm.repositories.credentials_repository import CredentialsRepository +from litellm.repositories.model_repository import ModelRepository +from litellm.repositories.table_repositories import ( + DeletedVerificationTokenRepository, + DeprecatedVerificationTokenRepository, +) +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.user_repository import UserRepository +from litellm.repositories.verification_token_repository import ( + VerificationTokenRepository, +) from litellm.router import Router from litellm.secret_managers.main import get_secret from litellm.types.proxy.management_endpoints.key_management_endpoints import ( @@ -582,7 +595,7 @@ async def validate_team_id_used_in_service_account_request( ) # check if team_id exists in the database - team = await prisma_client.db.litellm_teamtable.find_unique( + team = await TeamRepository(prisma_client).table.find_unique( where={"team_id": team_id}, ) if team is None: @@ -774,7 +787,7 @@ async def _common_key_generation_helper( # noqa: PLR0915 ) new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True)) - _budget = await prisma_client.db.litellm_budgettable.create( + _budget = await BudgetRepository(prisma_client).table.create( data={ **new_budget, # type: ignore "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, @@ -1144,7 +1157,7 
@@ async def _check_team_key_limits( # calculate allocated tpm/rpm limit # check if specified tpm/rpm limit is greater than allocated tpm/rpm limit - keys = await prisma_client.db.litellm_verificationtoken.find_many( + keys = await VerificationTokenRepository(prisma_client).table.find_many( where={"team_id": team_table.team_id}, ) # Exclude the key being updated to avoid double-counting its limits. @@ -1294,7 +1307,7 @@ async def _validate_caller_can_assign_key_org( detail="Cannot assign a key to an organization without a user_id on the caller's token", ) - user_row = await prisma_client.db.litellm_usertable.find_unique( + user_row = await UserRepository(prisma_client).table.find_unique( where={"user_id": user_api_key_dict.user_id}, include={"organization_memberships": True}, ) @@ -1339,7 +1352,7 @@ async def _check_org_key_limits( # get all organization keys # calculate allocated tpm/rpm limit # check if specified tpm/rpm limit is greater than allocated tpm/rpm limit - keys = await prisma_client.db.litellm_verificationtoken.find_many( + keys = await VerificationTokenRepository(prisma_client).table.find_many( where={"organization_id": org_table.organization_id}, ) # Exclude the key being updated to avoid double-counting its limits. @@ -1968,9 +1981,9 @@ async def _get_and_validate_existing_key( hashed_token = _hash_token_if_needed(token=token) - existing_key_row = await prisma_client.db.litellm_verificationtoken.find_unique( - where={"token": hashed_token} - ) + existing_key_row = await VerificationTokenRepository( + prisma_client + ).table.find_unique(where={"token": hashed_token}) if existing_key_row is None: raise ProxyException( @@ -2869,7 +2882,9 @@ async def bulk_update_team_keys( # `blocked` is Boolean? with no default; `/key/generate` writes NULL. Prisma's `NOT` # excludes NULLs, so explicitly OR `false` with `null` to include them. 
now = datetime.now(timezone.utc) - existing_keys = await prisma_client.db.litellm_verificationtoken.find_many( + existing_keys = await VerificationTokenRepository( + prisma_client + ).table.find_many( where={ "team_id": data.team_id, "AND": [ @@ -2907,7 +2922,9 @@ async def bulk_update_team_keys( seen_hashes.add(h) requested_tokens.append(k) hashed_key_ids.append(h) - existing_keys = await prisma_client.db.litellm_verificationtoken.find_many( + existing_keys = await VerificationTokenRepository( + prisma_client + ).table.find_many( where={"team_id": data.team_id, "token": {"in": hashed_key_ids}} ) @@ -3232,7 +3249,9 @@ async def info_key_fn_v2( # Resolve key_aliases to tokens so we never pass token=None (unbounded query) tokens_to_query = list(data.keys) if data.keys else [] if data.key_aliases: - alias_rows = await prisma_client.db.litellm_verificationtoken.find_many( + alias_rows = await VerificationTokenRepository( + prisma_client + ).table.find_many( where={"key_alias": {"in": data.key_aliases}}, include={"litellm_budget_table": True}, ) @@ -3311,7 +3330,7 @@ async def info_key_fn( hashed_key: Optional[str] = key if key is not None: hashed_key = _hash_token_if_needed(token=key) - key_info = await prisma_client.db.litellm_verificationtoken.find_unique( + key_info = await VerificationTokenRepository(prisma_client).table.find_unique( where={"token": hashed_key}, # type: ignore include={"litellm_budget_table": True}, ) @@ -3851,7 +3870,7 @@ async def delete_verification_tokens( if prisma_client: tokens = [_hash_token_if_needed(token=key) for key in tokens] _keys_being_deleted: List[LiteLLM_VerificationToken] = ( - await prisma_client.db.litellm_verificationtoken.find_many( + await VerificationTokenRepository(prisma_client).table.find_many( where={"token": {"in": tokens}} ) ) @@ -3989,7 +4008,9 @@ async def _save_deleted_verification_token_records( """Save deleted verification token records to the database.""" if not records: return - await 
prisma_client.db.litellm_deletedverificationtoken.create_many(data=records) + await DeletedVerificationTokenRepository(prisma_client).table.create_many( + data=records + ) async def _persist_deleted_verification_tokens( @@ -4017,9 +4038,9 @@ async def delete_key_aliases( user_api_key_dict: UserAPIKeyAuth, litellm_changed_by: Optional[str] = None, ) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]: - _keys_being_deleted = await prisma_client.db.litellm_verificationtoken.find_many( - where={"key_alias": {"in": key_aliases}} - ) + _keys_being_deleted = await VerificationTokenRepository( + prisma_client + ).table.find_many(where={"key_alias": {"in": key_aliases}}) tokens = [key.token for key in _keys_being_deleted] return await delete_verification_tokens( @@ -4054,9 +4075,7 @@ async def _rotate_master_key( # noqa: PLR0915 from litellm.proxy.proxy_server import proxy_config try: - models: Optional[List] = ( - await prisma_client.db.litellm_proxymodeltable.find_many() - ) + models: Optional[List] = await ModelRepository(prisma_client).table.find_many() except Exception: models = None # 2. process model table @@ -4088,7 +4107,7 @@ async def _rotate_master_key( # noqa: PLR0915 ) # 3. process config table try: - config = await prisma_client.db.litellm_config.find_many() + config = await ConfigRepository(prisma_client).table.find_many() except Exception: config = None @@ -4109,7 +4128,7 @@ async def _rotate_master_key( # noqa: PLR0915 ) if encrypted_env_vars: - await prisma_client.db.litellm_config.update( + await ConfigRepository(prisma_client).table.update( where={"param_name": "environment_variables"}, data={"param_value": prisma.Json(encrypted_env_vars)}, # type: ignore[attr-defined] ) @@ -4148,7 +4167,7 @@ async def _rotate_master_key( # noqa: PLR0915 # 5. 
process credentials table try: - credentials = await prisma_client.db.litellm_credentialstable.find_many() + credentials = await CredentialsRepository(prisma_client).table.find_many() except Exception: credentials = None if credentials: @@ -4171,7 +4190,7 @@ async def _rotate_master_key( # noqa: PLR0915 _cred_data["credential_info"] = prisma.Json( # type: ignore[attr-defined] _cred_data["credential_info"] ) - await prisma_client.db.litellm_credentialstable.update( + await CredentialsRepository(prisma_client).table.update( where={"credential_name": cred.credential_name}, data={ **_cred_data, @@ -4243,7 +4262,7 @@ async def _insert_deprecated_key( try: revoke_at = datetime.now(timezone.utc) + timedelta(seconds=grace_seconds) - await prisma_client.db.litellm_deprecatedverificationtoken.upsert( + await DeprecatedVerificationTokenRepository(prisma_client).table.upsert( where={"token": old_token_hash}, data={ "create": { @@ -4335,7 +4354,7 @@ async def _execute_virtual_key_regeneration( grace_period=data.grace_period if data else None, ) - updated_token = await prisma_client.db.litellm_verificationtoken.update( + updated_token = await VerificationTokenRepository(prisma_client).table.update( where={"token": hashed_api_key}, data=update_data, # type: ignore ) @@ -4530,7 +4549,7 @@ async def regenerate_key_fn( # noqa: PLR0915 else: hashed_api_key = hash_token(key) - _key_in_db = await prisma_client.db.litellm_verificationtoken.find_unique( + _key_in_db = await VerificationTokenRepository(prisma_client).table.find_unique( where={"token": hashed_api_key}, ) if _key_in_db is None: @@ -4719,7 +4738,7 @@ async def reset_key_spend_fn( else: hashed_api_key = hash_token(key) - _key_in_db = await prisma_client.db.litellm_verificationtoken.find_unique( + _key_in_db = await VerificationTokenRepository(prisma_client).table.find_unique( where={"token": hashed_api_key}, include={"litellm_budget_table": True}, ) @@ -4739,7 +4758,7 @@ async def reset_key_spend_fn( 
user_api_key_cache=user_api_key_cache, ) - updated_key = await prisma_client.db.litellm_verificationtoken.update( + updated_key = await VerificationTokenRepository(prisma_client).table.update( where={"token": hashed_api_key}, data={"spend": reset_to}, ) @@ -4792,11 +4811,11 @@ async def validate_key_list_check( param="user_id", code=status.HTTP_403_FORBIDDEN, ) - complete_user_info_db_obj: Optional[BaseModel] = ( - await prisma_client.db.litellm_usertable.find_unique( - where={"user_id": user_api_key_dict.user_id}, - include={"organization_memberships": True}, - ) + complete_user_info_db_obj: Optional[BaseModel] = await UserRepository( + prisma_client + ).table.find_unique( + where={"user_id": user_api_key_dict.user_id}, + include={"organization_memberships": True}, ) if complete_user_info_db_obj is None: @@ -4846,7 +4865,9 @@ async def validate_key_list_check( if key_hash: try: - key_info = await prisma_client.db.litellm_verificationtoken.find_unique( + key_info = await VerificationTokenRepository( + prisma_client + ).table.find_unique( where={"token": key_hash}, ) except Exception: @@ -4879,11 +4900,9 @@ async def _fetch_user_team_objects( if complete_user_info is None or not complete_user_info.teams: return [] - teams: Optional[List[BaseModel]] = ( - await prisma_client.db.litellm_teamtable.find_many( - where={"team_id": {"in": complete_user_info.teams}} - ) - ) + teams: Optional[List[BaseModel]] = await TeamRepository( + prisma_client + ).table.find_many(where={"team_id": {"in": complete_user_info.teams}}) if teams is None: return [] @@ -5160,7 +5179,7 @@ async def _apply_non_admin_alias_scope( # Look up the user's teams from the user table user_teams: List[str] = [] if user_api_key_dict.user_id: - user_row = await prisma_client.db.litellm_usertable.find_unique( + user_row = await UserRepository(prisma_client).table.find_unique( where={"user_id": user_api_key_dict.user_id} ) if user_row is not None: @@ -5548,7 +5567,7 @@ async def _list_key_helper( # Fetch keys 
with pagination if use_deleted_table: - keys = await prisma_client.db.litellm_deletedverificationtoken.find_many( + keys = await DeletedVerificationTokenRepository(prisma_client).table.find_many( where=where, # type: ignore skip=skip, # type: ignore take=size, # type: ignore @@ -5562,7 +5581,7 @@ async def _list_key_helper( ), ) else: - keys = await prisma_client.db.litellm_verificationtoken.find_many( + keys = await VerificationTokenRepository(prisma_client).table.find_many( where=where, # type: ignore skip=skip, # type: ignore take=size, # type: ignore @@ -5581,11 +5600,13 @@ async def _list_key_helper( # Get total count of keys if use_deleted_table: - total_count = await prisma_client.db.litellm_deletedverificationtoken.count( + total_count = await DeletedVerificationTokenRepository( + prisma_client + ).table.count( where=where # type: ignore ) else: - total_count = await prisma_client.db.litellm_verificationtoken.count( + total_count = await VerificationTokenRepository(prisma_client).table.count( where=where # type: ignore ) @@ -5601,7 +5622,7 @@ async def _list_key_helper( created_by_ids = [key.created_by for key in keys if key.created_by] all_ids = list(set(user_ids + created_by_ids)) # Remove duplicates if all_ids: - users = await prisma_client.db.litellm_usertable.find_many( + users = await UserRepository(prisma_client).table.find_many( where={"user_id": {"in": all_ids}} ) user_map = {user.user_id: user for user in users} @@ -5688,7 +5709,7 @@ async def _check_key_admin_access( return # Look up the target key to find its team - target_key_row = await prisma_client.db.litellm_verificationtoken.find_unique( + target_key_row = await VerificationTokenRepository(prisma_client).table.find_unique( where={"token": hashed_token} ) if target_key_row is None: @@ -5755,6 +5776,9 @@ async def block_key( Note: This is an admin-only endpoint. Only proxy admins, team admins, or org admins can block keys. 
""" + from litellm.proxy.management_helpers.audit_logs import ( + get_audit_log_changed_by, + ) from litellm.proxy.proxy_server import ( create_audit_log_for_update, hash_token, @@ -5763,9 +5787,6 @@ async def block_key( proxy_logging_obj, user_api_key_cache, ) - from litellm.proxy.management_helpers.audit_logs import ( - get_audit_log_changed_by, - ) if prisma_client is None: raise Exception("{}".format(CommonProxyErrors.db_not_connected_error.value)) @@ -5792,9 +5813,9 @@ async def block_key( ) # Check if the key exists before trying to block it - existing_record = await prisma_client.db.litellm_verificationtoken.find_unique( - where={"token": hashed_token} - ) + existing_record = await VerificationTokenRepository( + prisma_client + ).table.find_unique(where={"token": hashed_token}) if existing_record is None: raise ProxyException( message="Key not found.", @@ -5824,7 +5845,7 @@ async def block_key( ) ) - record = await prisma_client.db.litellm_verificationtoken.update( + record = await VerificationTokenRepository(prisma_client).table.update( where={"token": hashed_token}, data={"blocked": True} # type: ignore ) @@ -5869,6 +5890,9 @@ async def unblock_key( Note: This is an admin-only endpoint. Only proxy admins, team admins, or org admins can unblock keys. 
""" + from litellm.proxy.management_helpers.audit_logs import ( + get_audit_log_changed_by, + ) from litellm.proxy.proxy_server import ( create_audit_log_for_update, hash_token, @@ -5877,9 +5901,6 @@ async def unblock_key( proxy_logging_obj, user_api_key_cache, ) - from litellm.proxy.management_helpers.audit_logs import ( - get_audit_log_changed_by, - ) if prisma_client is None: raise Exception("{}".format(CommonProxyErrors.db_not_connected_error.value)) @@ -5906,9 +5927,9 @@ async def unblock_key( ) # Check if the key exists before trying to unblock it - existing_record = await prisma_client.db.litellm_verificationtoken.find_unique( - where={"token": hashed_token} - ) + existing_record = await VerificationTokenRepository( + prisma_client + ).table.find_unique(where={"token": hashed_token}) if existing_record is None: raise ProxyException( message="Key not found.", @@ -5938,7 +5959,7 @@ async def unblock_key( ) ) - record = await prisma_client.db.litellm_verificationtoken.update( + record = await VerificationTokenRepository(prisma_client).table.update( where={"token": hashed_token}, data={"blocked": False} # type: ignore ) @@ -6202,9 +6223,9 @@ async def _enforce_unique_key_alias( # Exclude the current key from the uniqueness check where_clause["NOT"] = {"token": existing_key_token} - existing_key = await prisma_client.db.litellm_verificationtoken.find_first( - where=where_clause - ) + existing_key = await VerificationTokenRepository( + prisma_client + ).table.find_first(where=where_clause) if existing_key is not None: raise ProxyException( message=f"Key with alias '{key_alias}' already exists. 
Unique key aliases across all keys are required.", diff --git a/litellm/proxy/management_endpoints/mcp_management_endpoints.py b/litellm/proxy/management_endpoints/mcp_management_endpoints.py index 64e89ee40f..c6c14c7a3e 100644 --- a/litellm/proxy/management_endpoints/mcp_management_endpoints.py +++ b/litellm/proxy/management_endpoints/mcp_management_endpoints.py @@ -60,6 +60,10 @@ from litellm.proxy.common_utils.encrypt_decrypt_utils import ( encrypt_value_helper, ) from litellm.proxy.management_helpers.audit_logs import get_audit_log_changed_by +from litellm.repositories.table_repositories import ( + MCPServerRepository, + MCPUserCredentialsRepository, +) router = APIRouter(prefix="/v1/mcp", tags=["mcp"]) @@ -761,7 +765,7 @@ if MCP_AVAILABLE: # Get from DB if prisma_client is not None: try: - mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many() + mcp_servers = await MCPServerRepository(prisma_client).table.find_many() for server in mcp_servers: if ( hasattr(server, "mcp_access_groups") @@ -998,10 +1002,10 @@ if MCP_AVAILABLE: if getattr(s, "is_byok", False) ] if byok_server_ids: - cred_rows = ( - await _byok_prisma_client.db.litellm_mcpusercredentials.find_many( - where={"user_id": user_id, "server_id": {"in": byok_server_ids}} - ) + cred_rows = await MCPUserCredentialsRepository( + _byok_prisma_client + ).table.find_many( + where={"user_id": user_id, "server_id": {"in": byok_server_ids}} ) cred_set = {r.server_id for r in cred_rows} for server in redacted_mcp_servers: diff --git a/litellm/proxy/management_endpoints/model_access_group_management_endpoints.py b/litellm/proxy/management_endpoints/model_access_group_management_endpoints.py index b05cfef576..a8551f6333 100644 --- a/litellm/proxy/management_endpoints/model_access_group_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_access_group_management_endpoints.py @@ -19,6 +19,7 @@ from litellm.proxy.management_endpoints.model_management_endpoints import ( clear_cache, ) 
from litellm.proxy.utils import PrismaClient +from litellm.repositories.model_repository import ModelRepository from litellm.types.proxy.management_endpoints.model_management_endpoints import ( AccessGroupInfo, DeleteModelGroupResponse, @@ -95,7 +96,7 @@ async def update_deployments_with_access_group( verbose_proxy_logger.debug(f"Updating deployments for model_name: {model_name}") # Get all deployments with this model_name - deployments = await prisma_client.db.litellm_proxymodeltable.find_many( + deployments = await ModelRepository(prisma_client).table.find_many( where={"model_name": model_name} ) @@ -124,7 +125,7 @@ async def update_deployments_with_access_group( # Only update in DB if modified if was_modified: - await prisma_client.db.litellm_proxymodeltable.update( + await ModelRepository(prisma_client).table.update( where={"model_id": deployment.model_id}, data={"model_info": json.dumps(updated_model_info)}, ) @@ -152,7 +153,7 @@ async def update_specific_deployments_with_access_group( models_updated = 0 for model_id in model_ids: verbose_proxy_logger.debug(f"Updating specific deployment model_id: {model_id}") - deployment = await prisma_client.db.litellm_proxymodeltable.find_unique( + deployment = await ModelRepository(prisma_client).table.find_unique( where={"model_id": model_id} ) if deployment is None: @@ -168,7 +169,7 @@ async def update_specific_deployments_with_access_group( access_group=access_group, ) if was_modified: - await prisma_client.db.litellm_proxymodeltable.update( + await ModelRepository(prisma_client).table.update( where={"model_id": model_id}, data={"model_info": json.dumps(updated_model_info)}, ) @@ -215,7 +216,7 @@ async def get_all_access_groups_from_db( Dict[str, AccessGroupInfo]: Dictionary mapping access_group name to info """ # Get all deployments - deployments = await prisma_client.db.litellm_proxymodeltable.find_many() + deployments = await ModelRepository(prisma_client).table.find_many() # Build access group map access_group_map: 
Dict[str, Dict[str, Any]] = {} @@ -604,7 +605,7 @@ async def update_access_group( try: # Step 1: Remove access group from ALL DB deployments (skip config models) - all_deployments = await prisma_client.db.litellm_proxymodeltable.find_many() + all_deployments = await ModelRepository(prisma_client).table.find_many() for deployment in all_deployments: model_info = deployment.model_info or {} @@ -615,7 +616,7 @@ async def update_access_group( ) if was_modified: - await prisma_client.db.litellm_proxymodeltable.update( + await ModelRepository(prisma_client).table.update( where={"model_id": deployment.model_id}, data={"model_info": json.dumps(updated_model_info)}, ) @@ -722,7 +723,7 @@ async def delete_access_group( try: # Remove access group from all DB deployments (skip config models) - all_deployments = await prisma_client.db.litellm_proxymodeltable.find_many() + all_deployments = await ModelRepository(prisma_client).table.find_many() models_updated = 0 for deployment in all_deployments: @@ -734,7 +735,7 @@ async def delete_access_group( ) if was_modified: - await prisma_client.db.litellm_proxymodeltable.update( + await ModelRepository(prisma_client).table.update( where={"model_id": deployment.model_id}, data={"model_info": json.dumps(updated_model_info)}, ) diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index e4ecda3fe3..566ef84533 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -48,6 +48,9 @@ from litellm.proxy.management_endpoints.team_endpoints import ( ) from litellm.proxy.management_helpers.audit_logs import create_object_audit_log from litellm.proxy.utils import PrismaClient +from litellm.repositories.model_repository import ModelRepository +from litellm.repositories.table_repositories import ModelTableRepository +from litellm.repositories.team_repository import 
TeamRepository from litellm.types.proxy.management_endpoints.model_management_endpoints import ( UpdateUsefulLinksRequest, ) @@ -86,7 +89,7 @@ async def get_db_model( ) -> Optional[Deployment]: db_model = cast( Optional[BaseModel], - await prisma_client.db.litellm_proxymodeltable.find_unique( + await ModelRepository(prisma_client).table.find_unique( where={"model_id": model_id} ), ) @@ -290,7 +293,7 @@ async def patch_model( update_data["updated_at"] = cast(str, get_utc_datetime()) # Perform partial update - updated_model = await prisma_client.db.litellm_proxymodeltable.update( + updated_model = await ModelRepository(prisma_client).table.update( where={"model_id": model_id}, data=update_data, ) @@ -362,7 +365,7 @@ async def _add_model_to_db( if model_params.model_info.id is not None: _data["model_id"] = model_params.model_info.id if should_create_model_in_db: - model_response = await prisma_client.db.litellm_proxymodeltable.create( + model_response = await ModelRepository(prisma_client).table.create( data=_data # type: ignore ) else: @@ -571,7 +574,7 @@ async def _get_team_deployments( team_id in model_info with Python-side filtering. 
""" prefix = f"model_name_{team_id}_" - response = await prisma_client.db.litellm_proxymodeltable.find_many( + response = await ModelRepository(prisma_client).table.find_many( where={ "model_name": {"startswith": prefix}, } @@ -828,7 +831,7 @@ class ModelManagementAuthChecks: detail={"error": CommonProxyErrors.not_premium_user.value}, ) - _existing_team_row = await prisma_client.db.litellm_teamtable.find_unique( + _existing_team_row = await TeamRepository(prisma_client).table.find_unique( where={"team_id": model_params.model_info.team_id} ) @@ -863,7 +866,7 @@ class ModelManagementAuthChecks: model_params.model_info is not None and model_params.model_info.team_id is not None ): - team_obj_row = await prisma_client.db.litellm_teamtable.find_unique( + team_obj_row = await TeamRepository(prisma_client).table.find_unique( where={"team_id": model_params.model_info.team_id} ) if team_obj_row is None: @@ -937,7 +940,7 @@ async def delete_model( }, ) - model_in_db = await prisma_client.db.litellm_proxymodeltable.find_unique( + model_in_db = await ModelRepository(prisma_client).table.find_unique( where={"model_id": model_info.id} ) if model_in_db is None: @@ -961,7 +964,7 @@ async def delete_model( - store keys separately """ # encrypt litellm params # - result = await prisma_client.db.litellm_proxymodeltable.delete( + result = await ModelRepository(prisma_client).table.delete( where={"model_id": model_info.id} ) @@ -1039,7 +1042,7 @@ async def delete_team_model_alias( Returns: - List of team id + model alias pairs that were removed """ - team_model_aliases = await prisma_client.db.litellm_modeltable.find_many( + team_model_aliases = await ModelTableRepository(prisma_client).table.find_many( include={"team": True} ) tasks = [] @@ -1056,7 +1059,7 @@ async def delete_team_model_alias( removed_model_aliases.append((team_model_alias.team.team_id, key)) del model_aliases[key] tasks.append( - prisma_client.db.litellm_modeltable.update( + 
ModelTableRepository(prisma_client).table.update( where={"id": id}, data={"model_aliases": json.dumps(model_aliases)}, ) @@ -1275,11 +1278,9 @@ async def update_model( if _model_id is None: raise Exception("model_info.id not provided") - _existing_litellm_params = ( - await prisma_client.db.litellm_proxymodeltable.find_unique( - where={"model_id": _model_id} - ) - ) + _existing_litellm_params = await ModelRepository( + prisma_client + ).table.find_unique(where={"model_id": _model_id}) if _existing_litellm_params is None: if ( @@ -1340,7 +1341,7 @@ async def update_model( "litellm_params": json.dumps(merged_dictionary), # type: ignore "updated_by": user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME, } - model_response = await prisma_client.db.litellm_proxymodeltable.update( + model_response = await ModelRepository(prisma_client).table.update( where={"model_id": _model_id}, data=_data, # type: ignore ) diff --git a/litellm/proxy/management_endpoints/organization_endpoints.py b/litellm/proxy/management_endpoints/organization_endpoints.py index 4d4ed53aaa..99659121b2 100644 --- a/litellm/proxy/management_endpoints/organization_endpoints.py +++ b/litellm/proxy/management_endpoints/organization_endpoints.py @@ -40,6 +40,15 @@ from litellm.proxy.management_helpers.utils import ( management_endpoint_wrapper, ) from litellm.proxy.utils import PrismaClient +from litellm.repositories.budget_repository import BudgetRepository +from litellm.repositories.object_permission_repository import ObjectPermissionRepository +from litellm.repositories.organization_repository import OrganizationRepository +from litellm.repositories.table_repositories import OrganizationMembershipRepository +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.user_repository import UserRepository +from litellm.repositories.verification_token_repository import ( + VerificationTokenRepository, +) from litellm.types.proxy.management_endpoints.common_daily_activity 
import ( SpendAnalyticsPaginatedResponse, ) @@ -245,7 +254,7 @@ async def new_organization( if user_api_key_dict.user_id is not None: try: - user_object = await prisma_client.db.litellm_usertable.find_unique( + user_object = await UserRepository(prisma_client).table.find_unique( where={"user_id": user_api_key_dict.user_id} ) user_object_correct_type = LiteLLM_UserTable(**user_object.model_dump()) @@ -267,7 +276,7 @@ async def new_organization( new_budget = prisma_client.jsonify_object(budget_row.json(exclude_none=True)) - _budget = await prisma_client.db.litellm_budgettable.create( + _budget = await BudgetRepository(prisma_client).table.create( data={ **new_budget, # type: ignore "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, @@ -323,7 +332,7 @@ async def new_organization( verbose_proxy_logger.info( f"new_organization_row: {json.dumps(new_organization_row, indent=2)}" ) - response = await prisma_client.db.litellm_organizationtable.create( + response = await OrganizationRepository(prisma_client).table.create( data={ **new_organization_row, # type: ignore }, @@ -372,9 +381,9 @@ async def get_organization_daily_activity( # Restrict non-proxy-admins to only organizations where they are org_admin if not _user_has_admin_view(user_api_key_dict): - memberships = await prisma_client.db.litellm_organizationmembership.find_many( - where={"user_id": user_api_key_dict.user_id} - ) + memberships = await OrganizationMembershipRepository( + prisma_client + ).table.find_many(where={"user_id": user_api_key_dict.user_id}) admin_org_ids = [ m.organization_id for m in memberships @@ -400,7 +409,7 @@ async def get_organization_daily_activity( where_condition = {} if org_ids_list: where_condition["organization_id"] = {"in": list(org_ids_list)} - org_aliases = await prisma_client.db.litellm_organizationtable.find_many( + org_aliases = await OrganizationRepository(prisma_client).table.find_many( where=where_condition ) org_alias_metadata = { @@ -439,10 +448,10 @@ 
async def _set_object_permission( return None if data.object_permission is not None: - created_object_permission = ( - await prisma_client.db.litellm_objectpermissiontable.create( - data=data.object_permission.model_dump(exclude_none=True), - ) + created_object_permission = await ObjectPermissionRepository( + prisma_client + ).table.create( + data=data.object_permission.model_dump(exclude_none=True), ) del data.object_permission return created_object_permission.object_permission_id @@ -525,10 +534,10 @@ async def update_organization( prisma_client=prisma_client, ) - existing_organization_row = ( - await prisma_client.db.litellm_organizationtable.find_unique( - where={"organization_id": data.organization_id}, - ) + existing_organization_row = await OrganizationRepository( + prisma_client + ).table.find_unique( + where={"organization_id": data.organization_id}, ) if existing_organization_row is None: @@ -574,7 +583,7 @@ async def update_organization( for field in LiteLLM_BudgetTable.model_fields.keys(): updated_organization_row.pop(field, None) - response = await prisma_client.db.litellm_organizationtable.update( + response = await OrganizationRepository(prisma_client).table.update( where={"organization_id": data.organization_id}, data=updated_organization_row, include={"members": True, "teams": True, "litellm_budget_table": True}, @@ -644,19 +653,19 @@ async def delete_organization( deleted_orgs = [] for organization_id in data.organization_ids: # delete all teams in the organization - await prisma_client.db.litellm_teamtable.delete_many( + await TeamRepository(prisma_client).table.delete_many( where={"organization_id": organization_id} ) # delete all members in the organization - await prisma_client.db.litellm_organizationmembership.delete_many( + await OrganizationMembershipRepository(prisma_client).table.delete_many( where={"organization_id": organization_id} ) # delete all keys in the organization - await prisma_client.db.litellm_verificationtoken.delete_many( + 
await VerificationTokenRepository(prisma_client).table.delete_many( where={"organization_id": organization_id} ) # delete the organization - deleted_org = await prisma_client.db.litellm_organizationtable.delete( + deleted_org = await OrganizationRepository(prisma_client).table.delete( where={"organization_id": organization_id}, include={"members": True, "teams": True, "litellm_budget_table": True}, ) @@ -732,17 +741,15 @@ async def list_organization( # if proxy admin or admin viewer - get all orgs (with optional filters) if _user_has_admin_view(user_api_key_dict): - response = await prisma_client.db.litellm_organizationtable.find_many( + response = await OrganizationRepository(prisma_client).table.find_many( where=where_conditions if where_conditions else None, include={"litellm_budget_table": True, "members": True, "teams": True}, ) # if internal user - get orgs they are a member of (with optional filters) else: - org_memberships = ( - await prisma_client.db.litellm_organizationmembership.find_many( - where={"user_id": user_api_key_dict.user_id} - ) - ) + org_memberships = await OrganizationMembershipRepository( + prisma_client + ).table.find_many(where={"user_id": user_api_key_dict.user_id}) membership_org_ids = [ membership.organization_id for membership in org_memberships ] @@ -756,20 +763,20 @@ async def list_organization( response = [] else: where_conditions["organization_id"] = org_id - response = ( - await prisma_client.db.litellm_organizationtable.find_many( - where=where_conditions, - include={ - "litellm_budget_table": True, - "members": True, - "teams": True, - }, - ) + response = await OrganizationRepository( + prisma_client + ).table.find_many( + where=where_conditions, + include={ + "litellm_budget_table": True, + "members": True, + "teams": True, + }, ) else: # Filter by membership and any additional filters where_conditions["organization_id"] = {"in": membership_org_ids} - response = await prisma_client.db.litellm_organizationtable.find_many( + 
response = await OrganizationRepository(prisma_client).table.find_many( where=where_conditions, include={ "litellm_budget_table": True, @@ -809,20 +816,20 @@ async def info_organization( prisma_client=prisma_client, ) - response: Optional[LiteLLM_OrganizationTableWithMembers] = ( - await prisma_client.db.litellm_organizationtable.find_unique( - where={"organization_id": organization_id}, - include={ - "litellm_budget_table": True, - "members": { - "include": { - "user": True, - } - }, - "teams": True, - "object_permission": True, + response: Optional[ + LiteLLM_OrganizationTableWithMembers + ] = await OrganizationRepository(prisma_client).table.find_unique( + where={"organization_id": organization_id}, + include={ + "litellm_budget_table": True, + "members": { + "include": { + "user": True, + } }, - ) + "teams": True, + "object_permission": True, + }, ) if response is None: @@ -868,7 +875,7 @@ async def deprecated_info_organization( prisma_client=prisma_client, ) - response = await prisma_client.db.litellm_organizationtable.find_many( + response = await OrganizationRepository(prisma_client).table.find_many( where={"organization_id": {"in": data.organizations}}, include={"litellm_budget_table": True}, ) @@ -945,11 +952,9 @@ async def organization_member_add( ) # Check if organization exists - existing_organization_row = ( - await prisma_client.db.litellm_organizationtable.find_unique( - where={"organization_id": data.organization_id} - ) - ) + existing_organization_row = await OrganizationRepository( + prisma_client + ).table.find_unique(where={"organization_id": data.organization_id}) if existing_organization_row is None: raise HTTPException( status_code=404, @@ -1012,11 +1017,9 @@ async def find_member_if_email( """ try: - existing_user_email_row: BaseModel = ( - await prisma_client.db.litellm_usertable.find_unique( - where={"user_email": user_email} - ) - ) + existing_user_email_row: BaseModel = await UserRepository( + prisma_client + 
).table.find_unique(where={"user_email": user_email}) except Exception: raise HTTPException( status_code=400, @@ -1064,11 +1067,9 @@ async def organization_member_update( ) # Check if organization exists - existing_organization_row = ( - await prisma_client.db.litellm_organizationtable.find_unique( - where={"organization_id": data.organization_id} - ) - ) + existing_organization_row = await OrganizationRepository( + prisma_client + ).table.find_unique(where={"organization_id": data.organization_id}) if existing_organization_row is None: raise HTTPException( status_code=400, @@ -1085,15 +1086,15 @@ async def organization_member_update( data.user_id = existing_user_email_row.user_id try: - existing_organization_membership = ( - await prisma_client.db.litellm_organizationmembership.find_unique( - where={ - "user_id_organization_id": { - "user_id": data.user_id, - "organization_id": data.organization_id, - } + existing_organization_membership = await OrganizationMembershipRepository( + prisma_client + ).table.find_unique( + where={ + "user_id_organization_id": { + "user_id": data.user_id, + "organization_id": data.organization_id, } - ) + } ) except Exception as e: raise HTTPException( @@ -1114,7 +1115,7 @@ async def organization_member_update( # org-scoped operations. An org-admin of any org could otherwise # alter a PROXY_ADMIN user's per-org role, which has downstream # effects on admin UI filtering and scope derivation. 
- target_user_row = await prisma_client.db.litellm_usertable.find_unique( + target_user_row = await UserRepository(prisma_client).table.find_unique( where={"user_id": data.user_id} ) if target_user_row is not None and getattr( @@ -1136,7 +1137,7 @@ async def organization_member_update( # Update member role if data.role is not None: - await prisma_client.db.litellm_organizationmembership.update( + await OrganizationMembershipRepository(prisma_client).table.update( where={ "user_id_organization_id": { "user_id": data.user_id, @@ -1165,7 +1166,7 @@ async def organization_member_update( ) # update organization membership with new budget_id - await prisma_client.db.litellm_organizationmembership.update( + await OrganizationMembershipRepository(prisma_client).table.update( where={ "user_id_organization_id": { "user_id": data.user_id, @@ -1174,16 +1175,16 @@ async def organization_member_update( }, data={"budget_id": budget_id}, ) - final_organization_membership: Optional[BaseModel] = ( - await prisma_client.db.litellm_organizationmembership.find_unique( - where={ - "user_id_organization_id": { - "user_id": data.user_id, - "organization_id": data.organization_id, - } - }, - include={"litellm_budget_table": True}, - ) + final_organization_membership: Optional[ + BaseModel + ] = await OrganizationMembershipRepository(prisma_client).table.find_unique( + where={ + "user_id_organization_id": { + "user_id": data.user_id, + "organization_id": data.organization_id, + } + }, + include={"litellm_budget_table": True}, ) if final_organization_membership is None: @@ -1239,7 +1240,9 @@ async def organization_member_delete( ) data.user_id = existing_user_email_row.user_id - member_to_delete = await prisma_client.db.litellm_organizationmembership.delete( + member_to_delete = await OrganizationMembershipRepository( + prisma_client + ).table.delete( where={ "user_id_organization_id": { "user_id": data.user_id, @@ -1273,17 +1276,15 @@ async def add_member_to_organization( 
existing_user_email_row = None ## Check if user exists in LiteLLM_UserTable - user exists - either the user_id or user_email is in LiteLLM_UserTable if member.user_id is not None: - existing_user_id_row = await prisma_client.db.litellm_usertable.find_unique( - where={"user_id": member.user_id} - ) + existing_user_id_row = await UserRepository( + prisma_client + ).table.find_unique(where={"user_id": member.user_id}) if existing_user_id_row is None and member.user_email is not None: try: - existing_user_email_row = ( - await prisma_client.db.litellm_usertable.find_unique( - where={"user_email": member.user_email} - ) - ) + existing_user_email_row = await UserRepository( + prisma_client + ).table.find_unique(where={"user_email": member.user_email}) except Exception as e: raise ValueError( f"Potential NON-Existent or Duplicate user email in DB: Error finding a unique instance of user_email={member.user_email} in LiteLLM_UserTable.: {e}" @@ -1326,14 +1327,14 @@ async def add_member_to_organization( ) # Add user to organization - _organization_membership = ( - await prisma_client.db.litellm_organizationmembership.create( - data={ - "organization_id": organization_id, - "user_id": user_object.user_id, - "user_role": member.role, - } - ) + _organization_membership = await OrganizationMembershipRepository( + prisma_client + ).table.create( + data={ + "organization_id": organization_id, + "user_id": user_object.user_id, + "user_role": member.role, + } ) organization_membership = LiteLLM_OrganizationMembershipTable( **_organization_membership.model_dump() diff --git a/litellm/proxy/management_endpoints/scim/scim_transformations.py b/litellm/proxy/management_endpoints/scim/scim_transformations.py index 28fb87d9b3..d1e00f87b6 100644 --- a/litellm/proxy/management_endpoints/scim/scim_transformations.py +++ b/litellm/proxy/management_endpoints/scim/scim_transformations.py @@ -6,6 +6,7 @@ from litellm.proxy._types import ( Member, NewUserResponse, ) +from 
litellm.repositories.team_repository import TeamRepository from litellm.types.proxy.management_endpoints.scim_v2 import * @@ -29,7 +30,7 @@ class ScimTransformations: # Get user's teams/groups groups = [] for team_id in user.teams or []: - team = await prisma_client.db.litellm_teamtable.find_unique( + team = await TeamRepository(prisma_client).table.find_unique( where={"team_id": team_id} ) if team: diff --git a/litellm/proxy/management_endpoints/scim/scim_v2.py b/litellm/proxy/management_endpoints/scim/scim_v2.py index 1f20764f83..0798d1a510 100644 --- a/litellm/proxy/management_endpoints/scim/scim_v2.py +++ b/litellm/proxy/management_endpoints/scim/scim_v2.py @@ -22,7 +22,6 @@ from typing_extensions import TypedDict import litellm from litellm._logging import verbose_proxy_logger -from litellm.proxy.common_utils.http_parsing_utils import _safe_get_request_headers from litellm._uuid import uuid from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.proxy._types import ( @@ -41,6 +40,7 @@ from litellm.proxy._types import ( ) from litellm.proxy.auth.auth_checks import _delete_cache_key_object from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.common_utils.http_parsing_utils import _safe_get_request_headers from litellm.proxy.management_endpoints.internal_user_endpoints import new_user from litellm.proxy.management_endpoints.scim.scim_transformations import ( ScimTransformations, @@ -51,6 +51,16 @@ from litellm.proxy.management_endpoints.team_endpoints import ( team_member_delete, ) from litellm.proxy.utils import _premium_user_check, handle_exception_on_proxy +from litellm.repositories.table_repositories import ( + InvitationLinkRepository, + OrganizationMembershipRepository, + TeamMembershipRepository, +) +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.user_repository import UserRepository +from litellm.repositories.verification_token_repository import ( + 
VerificationTokenRepository, +) from litellm.types.proxy.management_endpoints.scim_v2 import * @@ -74,7 +84,7 @@ class UserProvisionerHelpers: if not new_user_request.user_email: return None - existing_user = await prisma_client.db.litellm_usertable.find_first( + existing_user = await UserRepository(prisma_client).table.find_first( where={"user_email": new_user_request.user_email} ) @@ -82,7 +92,7 @@ class UserProvisionerHelpers: return None # Update the user - updated_user = await prisma_client.db.litellm_usertable.update( + updated_user = await UserRepository(prisma_client).table.update( where={"user_id": existing_user.user_id}, data={ "user_id": new_user_request.user_id, @@ -139,7 +149,7 @@ async def _check_user_exists(user_id: str): """Check if user exists and return user, raise 404 if not found.""" prisma_client = await _get_prisma_client_or_raise_exception() - user = await prisma_client.db.litellm_usertable.find_unique( + user = await UserRepository(prisma_client).table.find_unique( where={"user_id": user_id} ) @@ -155,7 +165,7 @@ async def _check_team_exists(team_id: str): """Check if team exists and return team, raise 404 if not found.""" prisma_client = await _get_prisma_client_or_raise_exception() - team = await prisma_client.db.litellm_teamtable.find_unique( + team = await TeamRepository(prisma_client).table.find_unique( where={"team_id": team_id} ) @@ -268,7 +278,7 @@ async def _extract_group_member_ids(group: SCIMGroup) -> GroupMemberExtractionRe ) # Check if user exists - user = await prisma_client.db.litellm_usertable.find_unique( + user = await UserRepository(prisma_client).table.find_unique( where={"user_id": user_id} ) @@ -310,7 +320,7 @@ async def _get_team_members_display(member_ids: List[str]) -> List[SCIMMember]: members: List[SCIMMember] = [] for member_id in member_ids: - user = await prisma_client.db.litellm_usertable.find_unique( + user = await UserRepository(prisma_client).table.find_unique( where={"user_id": member_id} ) if user: @@ 
-367,7 +377,7 @@ async def _set_user_keys_blocked(user_id: str, blocked: bool) -> int: # `blocked` is a nullable column with no default, so existing rows # typically hold NULL; treat NULL as "not blocked" since SQL equality # on NULL would otherwise silently skip them. - candidates = await prisma_client.db.litellm_verificationtoken.find_many( + candidates = await VerificationTokenRepository(prisma_client).table.find_many( where={ "user_id": user_id, "OR": [{"blocked": False}, {"blocked": None}], @@ -375,7 +385,7 @@ async def _set_user_keys_blocked(user_id: str, blocked: bool) -> int: ) affected_keys = candidates else: - candidates = await prisma_client.db.litellm_verificationtoken.find_many( + candidates = await VerificationTokenRepository(prisma_client).table.find_many( where={"user_id": user_id, "blocked": True}, ) affected_keys = [k for k in candidates if _key_was_scim_blocked(k.metadata)] @@ -395,7 +405,7 @@ async def _set_user_keys_blocked(user_id: str, blocked: bool) -> int: for k, v in current_metadata.items() if k != SCIM_BLOCKED_METADATA_KEY } - await prisma_client.db.litellm_verificationtoken.update( + await VerificationTokenRepository(prisma_client).table.update( where={"token": key_row.token}, data={"blocked": blocked, "metadata": safe_dumps(new_metadata)}, ) @@ -423,7 +433,7 @@ async def _delete_rows_referencing_user(prisma_client: Any, *, user_id: str) -> the user delete with an FK constraint violation (e.g. ``LiteLLM_InvitationLink_user_id_fkey``). 
""" - await prisma_client.db.litellm_invitationlink.delete_many( + await InvitationLinkRepository(prisma_client).table.delete_many( where={ "OR": [ {"user_id": user_id}, @@ -432,10 +442,10 @@ async def _delete_rows_referencing_user(prisma_client: Any, *, user_id: str) -> ] } ) - await prisma_client.db.litellm_organizationmembership.delete_many( + await OrganizationMembershipRepository(prisma_client).table.delete_many( where={"user_id": user_id} ) - await prisma_client.db.litellm_teammembership.delete_many( + await TeamMembershipRepository(prisma_client).table.delete_many( where={"user_id": user_id} ) @@ -897,17 +907,17 @@ async def get_users( where_conditions["user_email"] = filter_value # Get users from database - users: List[LiteLLM_UserTable] = ( - await prisma_client.db.litellm_usertable.find_many( - where=where_conditions, - skip=(startIndex - 1), - take=count, - order={"created_at": "desc"}, - ) + users: List[LiteLLM_UserTable] = await UserRepository( + prisma_client + ).table.find_many( + where=where_conditions, + skip=(startIndex - 1), + take=count, + order={"created_at": "desc"}, ) # Get total count for pagination - total_count = await prisma_client.db.litellm_usertable.count( + total_count = await UserRepository(prisma_client).table.count( where=where_conditions ) @@ -975,7 +985,7 @@ async def create_user( # Check if user already exists if user.userName: - existing_user = await prisma_client.db.litellm_usertable.find_unique( + existing_user = await UserRepository(prisma_client).table.find_unique( where={"user_id": user.userName} ) if existing_user: @@ -1094,7 +1104,7 @@ async def update_user( "metadata": safe_dumps(metadata), } - updated_user = await prisma_client.db.litellm_usertable.update( + updated_user = await UserRepository(prisma_client).table.update( where={"user_id": user_id}, data=update_data, ) @@ -1137,7 +1147,7 @@ async def delete_user( teams = [] if existing_user.teams: for team_id in existing_user.teams: - team = await 
prisma_client.db.litellm_teamtable.find_unique( + team = await TeamRepository(prisma_client).table.find_unique( where={"team_id": team_id} ) if team: @@ -1148,7 +1158,7 @@ async def delete_user( current_members = team.members or [] if user_id in current_members: new_members = [m for m in current_members if m != user_id] - await prisma_client.db.litellm_teamtable.update( + await TeamRepository(prisma_client).table.update( where={"team_id": team.team_id}, data={"members": new_members} ) @@ -1157,7 +1167,7 @@ async def delete_user( await _delete_rows_referencing_user(prisma_client, user_id=user_id) # Delete user - await prisma_client.db.litellm_usertable.delete(where={"user_id": user_id}) + await UserRepository(prisma_client).table.delete(where={"user_id": user_id}) return Response(status_code=204) except Exception as e: @@ -1413,7 +1423,7 @@ async def patch_user( update_data["metadata"] = safe_dumps(update_data["metadata"]) - updated_user = await prisma_client.db.litellm_usertable.update( + updated_user = await UserRepository(prisma_client).table.update( where={"user_id": user_id}, data=update_data, ) @@ -1465,7 +1475,7 @@ async def get_groups( where_conditions["team_alias"] = team_alias # Get teams from database - teams = await prisma_client.db.litellm_teamtable.find_many( + teams = await TeamRepository(prisma_client).table.find_many( where=where_conditions, skip=(startIndex - 1), take=count, @@ -1473,7 +1483,7 @@ async def get_groups( ) # Get total count for pagination - total_count = await prisma_client.db.litellm_teamtable.count( + total_count = await TeamRepository(prisma_client).table.count( where=where_conditions ) @@ -1561,7 +1571,7 @@ async def create_group( team_id = group.id or group.externalId or str(uuid.uuid4()) # Check if team already exists - existing_team = await prisma_client.db.litellm_teamtable.find_unique( + existing_team = await TeamRepository(prisma_client).table.find_unique( where={"team_id": team_id} ) @@ -1638,7 +1648,7 @@ async def 
update_group( } # Update team in database - updated_team = await prisma_client.db.litellm_teamtable.update( + updated_team = await TeamRepository(prisma_client).table.update( where={"team_id": group_id}, data=update_data, ) @@ -1683,19 +1693,19 @@ async def delete_group( # For each member, remove this team from their teams list for member_id in existing_team.members or []: - user = await prisma_client.db.litellm_usertable.find_unique( + user = await UserRepository(prisma_client).table.find_unique( where={"user_id": member_id} ) if user: current_teams = user.teams or [] if group_id in current_teams: new_teams = [t for t in current_teams if t != group_id] - await prisma_client.db.litellm_usertable.update( + await UserRepository(prisma_client).table.update( where={"user_id": member_id}, data={"teams": new_teams} ) # Delete team - await prisma_client.db.litellm_teamtable.delete(where={"team_id": group_id}) + await TeamRepository(prisma_client).table.delete(where={"team_id": group_id}) return Response(status_code=204) @@ -1748,7 +1758,7 @@ async def _process_group_patch_operations( detail={"error": "Invalid member: user ID cannot be empty."}, ) - user = await prisma_client.db.litellm_usertable.find_unique( + user = await UserRepository(prisma_client).table.find_unique( where={"user_id": member_id} ) if user: @@ -1805,7 +1815,7 @@ async def _apply_group_patch_updates( update_data["members"] = list(final_members) # Update team in database - updated_team = await prisma_client.db.litellm_teamtable.update( + updated_team = await TeamRepository(prisma_client).table.update( where={"team_id": group_id}, data=update_data, ) @@ -1877,7 +1887,7 @@ async def patch_group( # Refresh team data from database to get the latest state after concurrent updates # This prevents race conditions when multiple PATCH requests come in simultaneously - refreshed_team = await prisma_client.db.litellm_teamtable.find_unique( + refreshed_team = await TeamRepository(prisma_client).table.find_unique( 
where={"team_id": group_id} ) if refreshed_team: @@ -1894,7 +1904,7 @@ async def patch_group( await _handle_group_membership_changes(group_id, current_members, final_members) # Refresh team one more time to get final state after membership changes - final_team = await prisma_client.db.litellm_teamtable.find_unique( + final_team = await TeamRepository(prisma_client).table.find_unique( where={"team_id": group_id} ) if final_team: diff --git a/litellm/proxy/management_endpoints/tag_management_endpoints.py b/litellm/proxy/management_endpoints/tag_management_endpoints.py index 49d9b67a28..f0bb8bdb5f 100644 --- a/litellm/proxy/management_endpoints/tag_management_endpoints.py +++ b/litellm/proxy/management_endpoints/tag_management_endpoints.py @@ -25,6 +25,14 @@ from litellm.proxy.management_endpoints.common_daily_activity import ( get_daily_activity, ) from litellm.proxy.management_helpers.utils import handle_budget_for_entity +from litellm.repositories.model_repository import ModelRepository +from litellm.repositories.table_repositories import ( + DailyTagSpendRepository, + TagRepository, +) +from litellm.repositories.verification_token_repository import ( + VerificationTokenRepository, +) from litellm.types.tag_management import ( TagConfig, TagDeleteRequest, @@ -56,7 +64,7 @@ async def _get_internal_user_api_keys( if user_id is None: return sorted(user_api_keys) - key_records = await prisma_client.db.litellm_verificationtoken.find_many( + key_records = await VerificationTokenRepository(prisma_client).table.find_many( where={"user_id": user_id}, select={"token": True}, ) @@ -109,7 +117,7 @@ async def _get_tag_daily_activity_api_key_filter( async def _get_model_names(prisma_client, model_ids: list) -> Dict[str, str]: """Helper function to get model names from model IDs""" try: - models = await prisma_client.db.litellm_proxymodeltable.find_many( + models = await ModelRepository(prisma_client).table.find_many( where={"model_id": {"in": model_ids}} ) return 
{model.model_id: model.model_name for model in models} @@ -189,7 +197,7 @@ async def new_tag( ) try: # Check if tag already exists - existing_tag = await prisma_client.db.litellm_tagtable.find_unique( + existing_tag = await TagRepository(prisma_client).table.find_unique( where={"tag_name": tag.name} ) if existing_tag is not None: @@ -210,7 +218,7 @@ async def new_tag( model_info = await _get_model_names(prisma_client, tag.models or []) # Create new tag in database - new_tag_record = await prisma_client.db.litellm_tagtable.create( + new_tag_record = await TagRepository(prisma_client).table.create( data={ "tag_name": tag.name, "description": tag.description, @@ -267,7 +275,7 @@ async def _add_tag_to_deployment(deployment: "Deployment", tag: str): try: # Get current model from database to preserve encrypted fields - db_model = await prisma_client.db.litellm_proxymodeltable.find_unique( + db_model = await ModelRepository(prisma_client).table.find_unique( where={"model_id": deployment.model_info.id} ) @@ -292,7 +300,7 @@ async def _add_tag_to_deployment(deployment: "Deployment", tag: str): existing_params["tags"].append(tag) # Update database with modified params (keeps encrypted fields encrypted) - await prisma_client.db.litellm_proxymodeltable.update( + await ModelRepository(prisma_client).table.update( where={"model_id": deployment.model_info.id}, data={"litellm_params": json.dumps(existing_params)}, ) @@ -335,7 +343,7 @@ async def update_tag( try: # Check if tag exists - existing_tag = await prisma_client.db.litellm_tagtable.find_unique( + existing_tag = await TagRepository(prisma_client).table.find_unique( where={"tag_name": tag.name} ) if existing_tag is None: @@ -367,7 +375,7 @@ async def update_tag( update_data["budget_id"] = budget_id # Update tag in database - updated_tag_record = await prisma_client.db.litellm_tagtable.update( + updated_tag_record = await TagRepository(prisma_client).table.update( where={"tag_name": tag.name}, data=update_data, ) @@ -414,7 
+422,7 @@ async def info_tag( try: # Query tags from database with budget info - tag_records = await prisma_client.db.litellm_tagtable.find_many( + tag_records = await TagRepository(prisma_client).table.find_many( where={"tag_name": {"in": data.names}}, include={"litellm_budget_table": True}, ) @@ -535,7 +543,7 @@ async def list_tags( if start_date is not None and end_date is not None: dynamic_tag_where["date"] = {"gte": start_date, "lte": end_date} - dynamic_tag_rows = await prisma_client.db.litellm_dailytagspend.group_by( + dynamic_tag_rows = await DailyTagSpendRepository(prisma_client).table.group_by( by=["tag"], where=dynamic_tag_where, min={"created_at": True}, @@ -551,7 +559,7 @@ async def list_tags( ) ## QUERY STORED TAGS ## - tag_records = await prisma_client.db.litellm_tagtable.find_many( + tag_records = await TagRepository(prisma_client).table.find_many( where=stored_tag_where, include={"litellm_budget_table": True}, ) @@ -626,14 +634,14 @@ async def delete_tag( try: # Check if tag exists - existing_tag = await prisma_client.db.litellm_tagtable.find_unique( + existing_tag = await TagRepository(prisma_client).table.find_unique( where={"tag_name": data.name} ) if existing_tag is None: raise HTTPException(status_code=404, detail=f"Tag {data.name} not found") # Delete tag from database - await prisma_client.db.litellm_tagtable.delete(where={"tag_name": data.name}) + await TagRepository(prisma_client).table.delete(where={"tag_name": data.name}) return {"message": f"Tag {data.name} deleted successfully"} except Exception as e: diff --git a/litellm/proxy/management_endpoints/team_callback_endpoints.py b/litellm/proxy/management_endpoints/team_callback_endpoints.py index 63b56425b0..0c11507697 100644 --- a/litellm/proxy/management_endpoints/team_callback_endpoints.py +++ b/litellm/proxy/management_endpoints/team_callback_endpoints.py @@ -30,6 +30,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.common_utils.callback_utils 
import encrypt_callback_vars from litellm.proxy.management_endpoints.team_endpoints import _verify_team_access from litellm.proxy.management_helpers.utils import management_endpoint_wrapper +from litellm.repositories.team_repository import TeamRepository router = APIRouter() @@ -249,7 +250,7 @@ async def add_team_callbacks( team_metadata = encrypt_callback_vars(team_metadata) team_metadata_json = json.dumps(team_metadata) # update team_metadata - new_team_row = await prisma_client.db.litellm_teamtable.update( + new_team_row = await TeamRepository(prisma_client).table.update( where={"team_id": team_id}, data={"metadata": team_metadata_json} # type: ignore ) @@ -353,7 +354,7 @@ async def disable_team_logging( team_metadata_json = json.dumps(team_metadata) # Update team in database - updated_team = await prisma_client.db.litellm_teamtable.update( + updated_team = await TeamRepository(prisma_client).table.update( where={"team_id": team_id}, data={"metadata": team_metadata_json} # type: ignore ) diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index a3ad7a9ea8..7a784ee462 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -10,8 +10,8 @@ All /team management endpoints """ import asyncio -import math import json +import math import traceback from datetime import datetime, timezone from typing import Any, Dict, List, Optional, Tuple, Union, cast @@ -102,6 +102,20 @@ from litellm.proxy.management_helpers.utils import ( management_endpoint_wrapper, ) from litellm.proxy.utils import PrismaClient, handle_exception_on_proxy +from litellm.repositories.budget_repository import BudgetRepository +from litellm.repositories.organization_repository import OrganizationRepository +from litellm.repositories.table_repositories import ( + AccessGroupRepository, + DeletedTeamRepository, + ModelTableRepository, + 
OrganizationMembershipRepository, + TeamMembershipRepository, +) +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.user_repository import UserRepository +from litellm.repositories.verification_token_repository import ( + VerificationTokenRepository, +) from litellm.router import Router from litellm.types.proxy.management_endpoints.common_daily_activity import ( SpendAnalyticsPaginatedResponse, @@ -365,9 +379,9 @@ class TeamMemberBudgetHandler: return # Batch-fetch existing memberships for this team (avoids N+1 queries) - existing_memberships = await prisma_client.db.litellm_teammembership.find_many( - where={"team_id": team_id} - ) + existing_memberships = await TeamMembershipRepository( + prisma_client + ).table.find_many(where={"team_id": team_id}) existing_user_ids = {m.user_id for m in existing_memberships} # Identify members with no existing membership row. @@ -386,7 +400,7 @@ class TeamMemberBudgetHandler: ) if missing: - await prisma_client.db.litellm_teammembership.create_many( + await TeamMembershipRepository(prisma_client).table.create_many( data=missing, skip_duplicates=True, # safety net against concurrent races ) @@ -400,7 +414,7 @@ class TeamMemberBudgetHandler: # Heal existing membership rows that predate the team_member_budget # configuration: populate budget_id where it is currently NULL. # Rows with an explicit budget_id (per-member override) are left alone. 
- updated = await prisma_client.db.litellm_teammembership.update_many( + updated = await TeamMembershipRepository(prisma_client).table.update_many( where={"team_id": team_id, "budget_id": None}, data={"budget_id": team_member_budget_id}, ) @@ -456,7 +470,7 @@ async def get_all_team_memberships( # else: # where_obj = {"user_id": str(user_id), "team_id": {"in": team_id}} - team_memberships = await prisma_client.db.litellm_teammembership.find_many( + team_memberships = await TeamMembershipRepository(prisma_client).table.find_many( where=where_obj, include={"litellm_budget_table": True}, ) @@ -739,7 +753,7 @@ async def _check_org_team_limits( # calculate allocated tpm/rpm limit # check if specified tpm/rpm limit is greater than allocated tpm/rpm limit - teams = await prisma_client.db.litellm_teamtable.find_many( + teams = await TeamRepository(prisma_client).table.find_many( where={"organization_id": org_table.organization_id}, ) @@ -931,6 +945,9 @@ async def new_team( # noqa: PLR0915 ``` """ try: + from litellm.proxy.management_helpers.audit_logs import ( + get_audit_log_changed_by, + ) from litellm.proxy.proxy_server import ( _license_check, create_audit_log_for_update, @@ -938,9 +955,6 @@ async def new_team( # noqa: PLR0915 prisma_client, user_api_key_cache, ) - from litellm.proxy.management_helpers.audit_logs import ( - get_audit_log_changed_by, - ) if prisma_client is None: raise HTTPException(status_code=500, detail={"error": "No db connected"}) @@ -986,7 +1000,7 @@ async def new_team( # noqa: PLR0915 ) # Check if license is over limit - total_teams = await prisma_client.db.litellm_teamtable.count() + total_teams = await TeamRepository(prisma_client).table.count() if total_teams and _license_check.is_team_count_over_limit( team_count=total_teams ): @@ -1092,7 +1106,7 @@ async def new_team( # noqa: PLR0915 created_by=user_api_key_dict.user_id or litellm_proxy_admin_name, updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name, ) - model_dict = await 
prisma_client.db.litellm_modeltable.create( + model_dict = await ModelTableRepository(prisma_client).table.create( {**litellm_modeltable.json(exclude_none=True)} # type: ignore ) # type: ignore @@ -1195,7 +1209,7 @@ async def new_team( # noqa: PLR0915 db_data=complete_team_data_dict ) - team_row: LiteLLM_TeamTable = await prisma_client.db.litellm_teamtable.create( + team_row: LiteLLM_TeamTable = await TeamRepository(prisma_client).table.create( data=complete_team_data_dict, include={"litellm_model_table": True}, # type: ignore ) @@ -1315,11 +1329,11 @@ async def _update_model_table( updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name, ) if model_id is None: - model_dict = await prisma_client.db.litellm_modeltable.create( + model_dict = await ModelTableRepository(prisma_client).table.create( data={**litellm_modeltable.json(exclude_none=True)} # type: ignore ) else: - model_dict = await prisma_client.db.litellm_modeltable.upsert( + model_dict = await ModelTableRepository(prisma_client).table.upsert( where={"id": model_id}, data={ "update": {**litellm_modeltable.json(exclude_none=True)}, # type: ignore @@ -1400,7 +1414,7 @@ async def fetch_and_validate_organization( status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value} ) - organization_row = await prisma_client.db.litellm_organizationtable.find_unique( + organization_row = await OrganizationRepository(prisma_client).table.find_unique( where={"organization_id": organization_id}, include={"litellm_budget_table": True, "members": True, "teams": True}, ) @@ -1669,7 +1683,7 @@ async def update_team( # noqa: PLR0915 }, ) - existing_team_row = await prisma_client.db.litellm_teamtable.find_unique( + existing_team_row = await TeamRepository(prisma_client).table.find_unique( where={"team_id": data.team_id} ) @@ -1738,7 +1752,9 @@ async def update_team( # noqa: PLR0915 ): # Is the caller org_admin of the destination org? 
caller_memberships = ( - await prisma_client.db.litellm_organizationmembership.find_many( + await OrganizationMembershipRepository( + prisma_client + ).table.find_many( where={ "user_id": user_api_key_dict.user_id, "organization_id": data.organization_id, @@ -1885,18 +1901,18 @@ async def update_team( # noqa: PLR0915 updated_kv["router_settings"] = safe_dumps(updated_kv["router_settings"]) updated_kv = prisma_client.jsonify_team_object(db_data=updated_kv) - team_row: Optional[LiteLLM_TeamTable] = ( - await prisma_client.db.litellm_teamtable.update( - where={"team_id": data.team_id}, - data=updated_kv, - # `object_permission` is included so `_refresh_cached_team` - # doesn't write a cached team with the relation nulled out — - # see team_model_add for the full rationale. - include={ - "litellm_model_table": True, - "object_permission": True, - }, # type: ignore - ) + team_row: Optional[LiteLLM_TeamTable] = await TeamRepository( + prisma_client + ).table.update( + where={"team_id": data.team_id}, + data=updated_kv, + # `object_permission` is included so `_refresh_cached_team` + # doesn't write a cached team with the relation nulled out — + # see team_model_add for the full rationale. 
+ include={ + "litellm_model_table": True, + "object_permission": True, + }, # type: ignore ) if team_row is None or team_row.team_id is None: @@ -2306,7 +2322,7 @@ async def _add_team_members_to_team( # ADD MEMBER TO TEAM _db_team_members = [m.model_dump() for m in complete_team_data.members_with_roles] - updated_team = await prisma_client.db.litellm_teamtable.update( + updated_team = await TeamRepository(prisma_client).table.update( where={"team_id": data.team_id}, data={"members_with_roles": json.dumps(_db_team_members)}, # type: ignore ) @@ -2377,7 +2393,7 @@ async def _validate_and_populate_member_user_info( # Case 2: Only user_email provided - populate user_id from DB if member.user_email is not None and member.user_id is None: - user_by_email = await prisma_client.db.litellm_usertable.find_first( + user_by_email = await UserRepository(prisma_client).table.find_first( where={"user_email": {"equals": member.user_email, "mode": "insensitive"}} ) @@ -2410,7 +2426,7 @@ async def _validate_and_populate_member_user_info( # Case 3: Only user_id provided - populate user_email from DB if user exists if member.user_id is not None and member.user_email is None: - user_by_id = await prisma_client.db.litellm_usertable.find_unique( + user_by_id = await UserRepository(prisma_client).table.find_unique( where={"user_id": member.user_id} ) @@ -2608,7 +2624,7 @@ async def team_member_delete( detail={"error": "Either user_id or user_email needs to be passed in"}, ) - _existing_team_row = await prisma_client.db.litellm_teamtable.find_unique( + _existing_team_row = await TeamRepository(prisma_client).table.find_unique( where={"team_id": data.team_id} ) @@ -2652,7 +2668,7 @@ async def team_member_delete( _db_new_team_members: List[dict] = [m.model_dump() for m in new_team_members] - _ = await prisma_client.db.litellm_teamtable.update( + _ = await TeamRepository(prisma_client).table.update( where={ "team_id": data.team_id, }, @@ -2666,7 +2682,7 @@ async def team_member_delete( 
key_val["user_id"] = data.user_id elif data.user_email is not None: key_val["user_email"] = data.user_email - existing_user_rows = await prisma_client.db.litellm_usertable.find_many( + existing_user_rows = await UserRepository(prisma_client).table.find_many( where=key_val # type: ignore ) @@ -2678,7 +2694,7 @@ async def team_member_delete( if data.team_id in existing_user.teams: team_list = existing_user.teams team_list.remove(data.team_id) - await prisma_client.db.litellm_usertable.update( + await UserRepository(prisma_client).table.update( where={ "user_id": existing_user.user_id, }, @@ -2695,7 +2711,7 @@ async def team_member_delete( user_ids_to_delete.add(existing_user.user_id) for _uid in user_ids_to_delete: - await prisma_client.db.litellm_teammembership.delete_many( + await TeamMembershipRepository(prisma_client).table.delete_many( where={"team_id": data.team_id, "user_id": _uid} ) @@ -2706,13 +2722,13 @@ async def team_member_delete( ) # Fetch keys before deletion to persist them - keys_to_delete: List[LiteLLM_VerificationToken] = ( - await prisma_client.db.litellm_verificationtoken.find_many( - where={ - "user_id": {"in": list(user_ids_to_delete)}, - "team_id": data.team_id, - } - ) + keys_to_delete: List[ + LiteLLM_VerificationToken + ] = await VerificationTokenRepository(prisma_client).table.find_many( + where={ + "user_id": {"in": list(user_ids_to_delete)}, + "team_id": data.team_id, + } ) if keys_to_delete: @@ -2723,7 +2739,7 @@ async def team_member_delete( litellm_changed_by=None, ) - await prisma_client.db.litellm_verificationtoken.delete_many( + await VerificationTokenRepository(prisma_client).table.delete_many( where={ "user_id": {"in": list(user_ids_to_delete)}, "team_id": data.team_id, @@ -2818,7 +2834,7 @@ async def team_member_update( _validate_budget_duration(data.budget_duration) - _existing_team_row = await prisma_client.db.litellm_teamtable.find_unique( + _existing_team_row = await TeamRepository(prisma_client).table.find_unique( 
where={"team_id": data.team_id} ) @@ -2921,7 +2937,7 @@ async def team_member_update( team_table.members_with_roles = team_members _db_team_members: List[dict] = [m.model_dump() for m in team_members] - await prisma_client.db.litellm_teamtable.update( + await TeamRepository(prisma_client).table.update( where={"team_id": data.team_id}, data={"members_with_roles": json.dumps(_db_team_members)}, # type: ignore ) @@ -3052,7 +3068,7 @@ async def bulk_team_member_add( }, ) # get all users from the database - all_users_in_db = await prisma_client.db.litellm_usertable.find_many( + all_users_in_db = await UserRepository(prisma_client).table.find_many( order={"created_at": "desc"} ) data.members = [ @@ -3153,14 +3169,14 @@ async def delete_team( }' ``` """ + from litellm.proxy.management_helpers.audit_logs import ( + get_audit_log_changed_by, + ) from litellm.proxy.proxy_server import ( create_audit_log_for_update, litellm_proxy_admin_name, prisma_client, ) - from litellm.proxy.management_helpers.audit_logs import ( - get_audit_log_changed_by, - ) if prisma_client is None: raise HTTPException(status_code=500, detail={"error": "No db connected"}) @@ -3172,11 +3188,9 @@ async def delete_team( team_rows: List[LiteLLM_TeamTable] = [] for team_id in data.team_ids: try: - team_row_base: Optional[BaseModel] = ( - await prisma_client.db.litellm_teamtable.find_unique( - where={"team_id": team_id} - ) - ) + team_row_base: Optional[BaseModel] = await TeamRepository( + prisma_client + ).table.find_unique(where={"team_id": team_id}) if team_row_base is None: raise Exception except Exception: @@ -3243,11 +3257,9 @@ async def delete_team( _persist_deleted_verification_tokens, ) - keys_to_delete: List[LiteLLM_VerificationToken] = ( - await prisma_client.db.litellm_verificationtoken.find_many( - where={"team_id": {"in": data.team_ids}} - ) - ) + keys_to_delete: List[LiteLLM_VerificationToken] = await VerificationTokenRepository( + prisma_client + ).table.find_many(where={"team_id": {"in": 
data.team_ids}}) if keys_to_delete: await _persist_deleted_verification_tokens( @@ -3338,7 +3350,7 @@ async def _save_deleted_team_records( """Save deleted team records to the database.""" if not records: return - await prisma_client.db.litellm_deletedteamtable.create_many(data=records) + await DeletedTeamRepository(prisma_client).table.create_many(data=records) async def _persist_deleted_team_records( @@ -3420,7 +3432,7 @@ async def _add_team_member_budget_table( team_info_response_object: TeamInfoResponseObjectTeamTable, ) -> TeamInfoResponseObjectTeamTable: try: - team_budget = await prisma_client.db.litellm_budgettable.find_unique( + team_budget = await BudgetRepository(prisma_client).table.find_unique( where={"budget_id": team_member_budget_id} ) team_info_response_object.team_member_budget_table = team_budget @@ -3489,11 +3501,11 @@ async def team_info( ) try: - team_info: Optional[BaseModel] = ( - await prisma_client.db.litellm_teamtable.find_unique( - where={"team_id": team_id}, - include={"object_permission": True}, - ) + team_info: Optional[BaseModel] = await TeamRepository( + prisma_client + ).table.find_unique( + where={"team_id": team_id}, + include={"object_permission": True}, ) if team_info is None: raise Exception @@ -3749,7 +3761,7 @@ async def block_team( if prisma_client is None: raise Exception("No DB Connected.") - existing_team = await prisma_client.db.litellm_teamtable.find_unique( + existing_team = await TeamRepository(prisma_client).table.find_unique( where={"team_id": data.team_id} ) if existing_team is None: @@ -3764,7 +3776,7 @@ async def block_team( user_api_key_dict=user_api_key_dict, ) - record = await prisma_client.db.litellm_teamtable.update( + record = await TeamRepository(prisma_client).table.update( where={"team_id": data.team_id}, data={"blocked": True} # type: ignore ) @@ -3801,7 +3813,7 @@ async def unblock_team( if prisma_client is None: raise Exception("No DB Connected.") - existing_team = await 
prisma_client.db.litellm_teamtable.find_unique( + existing_team = await TeamRepository(prisma_client).table.find_unique( where={"team_id": data.team_id} ) if existing_team is None: @@ -3816,7 +3828,7 @@ async def unblock_team( user_api_key_dict=user_api_key_dict, ) - record = await prisma_client.db.litellm_teamtable.update( + record = await TeamRepository(prisma_client).table.update( where={"team_id": data.team_id}, data={"blocked": False} # type: ignore ) @@ -3849,7 +3861,7 @@ async def list_available_teams( return [] # filter out teams that the user is already a member of - user_info = await prisma_client.db.litellm_usertable.find_unique( + user_info = await UserRepository(prisma_client).table.find_unique( where={"user_id": user_api_key_dict.user_id} ) if user_info is None: @@ -3863,7 +3875,7 @@ async def list_available_teams( team for team in available_teams if team not in user_info_correct_type.teams ] - available_teams_db = await prisma_client.db.litellm_teamtable.find_many( + available_teams_db = await TeamRepository(prisma_client).table.find_many( where={"team_id": {"in": available_teams}} ) @@ -4009,7 +4021,7 @@ async def _batch_resolve_access_group_resources( return {} unique_ids = list(set(all_access_group_ids)) - rows = await _prisma_client.db.litellm_accessgrouptable.find_many( + rows = await AccessGroupRepository(_prisma_client).table.find_many( where={"access_group_id": {"in": unique_ids}}, ) @@ -4074,7 +4086,7 @@ async def _get_keys_count_by_team( if not page_team_ids: return {} - grouped = await prisma_client.db.litellm_verificationtoken.group_by( + grouped = await VerificationTokenRepository(prisma_client).table.group_by( by=["team_id"], where={"team_id": {"in": page_team_ids}}, count={"team_id": True}, @@ -4288,25 +4300,25 @@ async def list_team_v2( # Get teams with pagination if use_deleted_table: - teams = await prisma_client.db.litellm_deletedteamtable.find_many( + teams = await DeletedTeamRepository(prisma_client).table.find_many( 
where=where_conditions, skip=skip, take=page_size, order=order_by if order_by else {"created_at": "desc"}, # Default sort ) # Get total count for pagination - total_count = await prisma_client.db.litellm_deletedteamtable.count( + total_count = await DeletedTeamRepository(prisma_client).table.count( where=where_conditions ) else: - teams = await prisma_client.db.litellm_teamtable.find_many( + teams = await TeamRepository(prisma_client).table.find_many( where=where_conditions, skip=skip, take=page_size, order=order_by if order_by else {"created_at": "desc"}, # Default sort ) # Get total count for pagination - total_count = await prisma_client.db.litellm_teamtable.count( + total_count = await TeamRepository(prisma_client).table.count( where=where_conditions ) @@ -4412,7 +4424,7 @@ async def _authorize_and_filter_teams( if allowed_org_ids is not None: # Org admin: query DB for teams in their orgs - org_teams = await prisma_client.db.litellm_teamtable.find_many( + org_teams = await TeamRepository(prisma_client).table.find_many( where={"organization_id": {"in": allowed_org_ids}}, include={"litellm_model_table": True}, ) @@ -4427,7 +4439,7 @@ async def _authorize_and_filter_teams( ] elif user_id: # Regular user: fetch all and filter by membership (Prisma can't filter JSON arrays) - response = await prisma_client.db.litellm_teamtable.find_many( + response = await TeamRepository(prisma_client).table.find_many( include={"litellm_model_table": True} ) return [ @@ -4439,7 +4451,7 @@ async def _authorize_and_filter_teams( else: # Proxy admin: all teams return list( - await prisma_client.db.litellm_teamtable.find_many( + await TeamRepository(prisma_client).table.find_many( include={"litellm_model_table": True} ) ) @@ -4500,7 +4512,7 @@ async def list_team( _team_memberships.append(tm) # add all keys that belong to the team - keys = await prisma_client.db.litellm_verificationtoken.find_many( + keys = await VerificationTokenRepository(prisma_client).table.find_many( 
where={"team_id": team.team_id} ) @@ -4556,10 +4568,10 @@ async def get_paginated_teams( # Calculate skip for pagination skip = (page - 1) * page_size # Get total count - total_count = await prisma_client.db.litellm_teamtable.count() + total_count = await TeamRepository(prisma_client).table.count() # Get paginated teams - teams = await prisma_client.db.litellm_teamtable.find_many( + teams = await TeamRepository(prisma_client).table.find_many( skip=skip, take=page_size, order={"team_alias": "asc"} # Sort by team_alias ) return teams, total_count @@ -4632,7 +4644,7 @@ async def ui_view_teams( } # Query users with pagination and filters - teams = await prisma_client.db.litellm_teamtable.find_many( + teams = await TeamRepository(prisma_client).table.find_many( where=where_conditions, skip=skip, take=page_size, @@ -4704,7 +4716,7 @@ async def team_model_add( raise HTTPException(status_code=500, detail={"error": "No db connected"}) # Get existing team - team_row = await prisma_client.db.litellm_teamtable.find_unique( + team_row = await TeamRepository(prisma_client).table.find_unique( where={"team_id": data.team_id} ) @@ -4737,7 +4749,7 @@ async def team_model_add( # null them out — see object_permission_utils.validate_key_search_tools_against_team # and the MCP/agent authz paths, which treat a missing object_permission # as "no team-level restriction". 
- updated_team = await prisma_client.db.litellm_teamtable.update( + updated_team = await TeamRepository(prisma_client).table.update( where={"team_id": data.team_id}, data={"models": updated_models}, include={"object_permission": True}, # type: ignore @@ -4791,7 +4803,7 @@ async def team_model_delete( raise HTTPException(status_code=500, detail={"error": "No db connected"}) # Get existing team - team_row = await prisma_client.db.litellm_teamtable.find_unique( + team_row = await TeamRepository(prisma_client).table.find_unique( where={"team_id": data.team_id} ) @@ -4825,7 +4837,7 @@ async def team_model_delete( updated_models = [m for m in current_models if m not in data.models] # Update team. See team_model_add for the rationale on `include`. - updated_team = await prisma_client.db.litellm_teamtable.update( + updated_team = await TeamRepository(prisma_client).table.update( where={"team_id": data.team_id}, data={"models": updated_models}, include={"object_permission": True}, # type: ignore @@ -4972,7 +4984,7 @@ async def update_team_member_permissions( }, ) # Update the team member permissions - updated_team = await prisma_client.db.litellm_teamtable.update( + updated_team = await TeamRepository(prisma_client).table.update( where={"team_id": data.team_id}, data={"team_member_permissions": data.team_member_permissions}, ) @@ -5076,7 +5088,7 @@ async def _append_permissions_to_specific_teams( prisma_client, team_ids: List[str], permissions_to_add: set ) -> int: """Fetch specific teams by ID and append permissions.""" - teams = await prisma_client.db.litellm_teamtable.find_many( + teams = await TeamRepository(prisma_client).table.find_many( where={"team_id": {"in": team_ids}}, ) @@ -5108,7 +5120,7 @@ async def _append_permissions_to_all_teams( find_args["cursor"] = {"team_id": cursor} find_args["skip"] = 1 - teams = await prisma_client.db.litellm_teamtable.find_many(**find_args) + teams = await TeamRepository(prisma_client).table.find_many(**find_args) if not teams: 
break @@ -5214,7 +5226,7 @@ async def get_team_daily_activity( where_condition = {} if team_ids_list: where_condition["team_id"] = {"in": list(team_ids_list)} - team_aliases = await prisma_client.db.litellm_teamtable.find_many( + team_aliases = await TeamRepository(prisma_client).table.find_many( where=where_condition ) team_alias_metadata = { @@ -5251,9 +5263,9 @@ async def get_team_daily_activity( # If user does not have full team view, filter by their API keys if not has_full_team_view: # Get all API keys for this user - user_keys = await prisma_client.db.litellm_verificationtoken.find_many( - where={"user_id": user_api_key_dict.user_id} - ) + user_keys = await VerificationTokenRepository( + prisma_client + ).table.find_many(where={"user_id": user_api_key_dict.user_id}) user_api_keys = [key.token for key in user_keys if key.token] # If user has no API keys, return empty result if not user_api_keys: diff --git a/litellm/proxy/management_endpoints/tool_management_endpoints.py b/litellm/proxy/management_endpoints/tool_management_endpoints.py index 19ca2c9f6b..a9b57db8a6 100644 --- a/litellm/proxy/management_endpoints/tool_management_endpoints.py +++ b/litellm/proxy/management_endpoints/tool_management_endpoints.py @@ -21,6 +21,15 @@ if TYPE_CHECKING: from litellm._logging import verbose_proxy_logger from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.repositories.object_permission_repository import ObjectPermissionRepository +from litellm.repositories.table_repositories import ( + SpendLogsRepository, + SpendLogToolIndexRepository, +) +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.verification_token_repository import ( + VerificationTokenRepository, +) from litellm.types.tool_management import ( LiteLLM_ToolTableRow, ToolDetailResponse, @@ -256,8 +265,10 @@ async def get_tool_usage_logs( if end_time_filter is not None: 
where["start_time"]["lte"] = end_time_filter - total = await prisma_client.db.litellm_spendlogtoolindex.count(where=where) - index_rows = await prisma_client.db.litellm_spendlogtoolindex.find_many( + total = await SpendLogToolIndexRepository(prisma_client).table.count( + where=where + ) + index_rows = await SpendLogToolIndexRepository(prisma_client).table.find_many( where=where, order={"start_time": "desc"}, skip=(page - 1) * page_size, @@ -269,7 +280,7 @@ async def get_tool_usage_logs( logs=[], total=total, page=page, page_size=page_size ) - spend_logs = await prisma_client.db.litellm_spendlogs.find_many( + spend_logs = await SpendLogsRepository(prisma_client).table.find_many( where={"request_id": {"in": request_ids}} ) log_by_id = {s.request_id: s for s in spend_logs} @@ -348,7 +359,7 @@ async def _resolve_key_hash_to_object_permission_id( hashed = key_hash if "sk-" not in (key_hash or "") else hash_token(key_hash) if not hashed: return None - row = await prisma_client.db.litellm_verificationtoken.find_unique( + row = await VerificationTokenRepository(prisma_client).table.find_unique( where={"token": hashed} ) if row is None: @@ -357,18 +368,18 @@ async def _resolve_key_hash_to_object_permission_id( if op_id: return op_id new_id = str(uuid.uuid4()) - await prisma_client.db.litellm_objectpermissiontable.create( + await ObjectPermissionRepository(prisma_client).table.create( data={"object_permission_id": new_id, "blocked_tools": []} ) - updated_count = await prisma_client.db.litellm_verificationtoken.update_many( + updated_count = await VerificationTokenRepository(prisma_client).table.update_many( where={"token": hashed, "object_permission_id": None}, data={"object_permission_id": new_id}, ) if updated_count == 0: - await prisma_client.db.litellm_objectpermissiontable.delete( + await ObjectPermissionRepository(prisma_client).table.delete( where={"object_permission_id": new_id} ) - row = await prisma_client.db.litellm_verificationtoken.find_unique( + row = await 
VerificationTokenRepository(prisma_client).table.find_unique( where={"token": hashed} ) return getattr(row, "object_permission_id", None) if row else None @@ -383,7 +394,7 @@ async def _resolve_team_id_to_object_permission_id( if not team_id or not team_id.strip(): return None team_id_clean = team_id.strip() - row = await prisma_client.db.litellm_teamtable.find_unique( + row = await TeamRepository(prisma_client).table.find_unique( where={"team_id": team_id_clean}, select={"object_permission_id": True}, ) @@ -393,18 +404,18 @@ async def _resolve_team_id_to_object_permission_id( if op_id: return op_id new_id = str(uuid.uuid4()) - await prisma_client.db.litellm_objectpermissiontable.create( + await ObjectPermissionRepository(prisma_client).table.create( data={"object_permission_id": new_id, "blocked_tools": []} ) - updated_count = await prisma_client.db.litellm_teamtable.update_many( + updated_count = await TeamRepository(prisma_client).table.update_many( where={"team_id": team_id_clean, "object_permission_id": None}, data={"object_permission_id": new_id}, ) if updated_count == 0: - await prisma_client.db.litellm_objectpermissiontable.delete( + await ObjectPermissionRepository(prisma_client).table.delete( where={"object_permission_id": new_id} ) - row = await prisma_client.db.litellm_teamtable.find_unique( + row = await TeamRepository(prisma_client).table.find_unique( where={"team_id": team_id_clean}, select={"object_permission_id": True}, ) diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index d6082899c0..a957061685 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -15,8 +15,8 @@ import inspect import os import re import secrets -from html import escape from copy import deepcopy +from html import escape from typing import ( TYPE_CHECKING, Any, @@ -39,9 +39,9 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Request, status from 
fastapi.responses import RedirectResponse import litellm -from litellm.caching.dual_cache import DualCache from litellm._logging import verbose_proxy_logger from litellm._uuid import uuid +from litellm.caching.dual_cache import DualCache from litellm.constants import ( CLI_SSO_CLAIM_MAP, CLI_SSO_CLAIM_MAX_SCALAR_LENGTH, @@ -77,7 +77,6 @@ from litellm.proxy._types import ( UserAPIKeyAuth, ) from litellm.proxy.auth.auth_checks import ExperimentalUIJWTToken, get_user_object -from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache from litellm.proxy.auth.auth_utils import ( _get_request_ip_address, _has_user_setup_sso, @@ -92,6 +91,7 @@ from litellm.proxy.common_utils.html_forms.jwt_display_template import ( jwt_display_template, ) from litellm.proxy.common_utils.html_forms.ui_login import html_form +from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache from litellm.proxy.management_endpoints.internal_user_endpoints import new_user from litellm.proxy.management_endpoints.sso import CustomMicrosoftSSO from litellm.proxy.management_endpoints.sso_helper_utils import ( @@ -110,6 +110,9 @@ from litellm.proxy.utils import ( get_custom_url, get_server_root_path, ) +from litellm.repositories.table_repositories import SSOConfigRepository +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.user_repository import UserRepository from litellm.secret_managers.main import get_secret_bool, str_to_bool from litellm.types.proxy.management_endpoints.ui_sso import * # noqa: F403, F401 from litellm.types.proxy.management_endpoints.ui_sso import ( @@ -438,7 +441,7 @@ async def _persist_cli_sso_user_metadata( return try: - user_row = await prisma_client.db.litellm_usertable.find_unique( + user_row = await UserRepository(prisma_client).table.find_unique( where={"user_id": user_id} ) existing_metadata: Dict[str, Any] = {} @@ -451,7 +454,7 @@ async def _persist_cli_sso_user_metadata( 
existing_metadata=existing_metadata, attribution_metadata=attribution_metadata, ) - await prisma_client.db.litellm_usertable.update_many( + await UserRepository(prisma_client).table.update_many( where={"user_id": user_id}, data={"metadata": merged_metadata}, ) @@ -859,7 +862,7 @@ async def google_login( if premium_user is not True: # Check if under 'free SSO user' limit if prisma_client is not None: - total_users = await prisma_client.db.litellm_usertable.count() + total_users = await UserRepository(prisma_client).table.count() if total_users and total_users > 5: raise ProxyException( message="You must be a LiteLLM Enterprise user to use SSO for more than 5 users. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://enterprise.litellm.ai/demo You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this", @@ -1150,7 +1153,7 @@ async def _setup_team_mappings() -> Optional["TeamMappings"]: "Prisma client is None, connect a database to your proxy" ) - sso_db_record = await prisma_client.db.litellm_ssoconfig.find_unique( + sso_db_record = await SSOConfigRepository(prisma_client).table.find_unique( where={"id": "sso_config"} ) @@ -1188,7 +1191,7 @@ async def _setup_role_mappings() -> Optional["RoleMappings"]: "Prisma client is None, connect a database to your proxy" ) - sso_db_record = await prisma_client.db.litellm_ssoconfig.find_unique( + sso_db_record = await SSOConfigRepository(prisma_client).table.find_unique( where={"id": "sso_config"} ) @@ -1755,7 +1758,7 @@ async def _sync_user_role_from_jwt_role_map( # Update existing DB record if role differs if user_info is not None and user_info.user_role != mapped_role.value: - await prisma_client.db.litellm_usertable.update( + await UserRepository(prisma_client).table.update( where={"user_id": user_info.user_id}, data={"user_role": mapped_role.value}, ) @@ -1819,7 
+1822,7 @@ async def check_and_update_if_proxy_admin_id( return user_role if prisma_client: - await prisma_client.db.litellm_usertable.update( + await UserRepository(prisma_client).table.update( where={"user_id": user_id}, data={"user_role": LitellmUserRoles.PROXY_ADMIN.value}, ) @@ -1976,7 +1979,7 @@ async def _fetch_cli_sso_team_details( team_details: List[Dict[str, Any]] = [] try: if teams: - prisma_teams = await prisma_client.db.litellm_teamtable.find_many( + prisma_teams = await TeamRepository(prisma_client).table.find_many( where={"team_id": {"in": teams}} ) for team_row in prisma_teams: @@ -2884,7 +2887,7 @@ class SSOAuthenticationHandler: user_id=user_id, ) - await prisma_client.db.litellm_usertable.update_many( + await UserRepository(prisma_client).table.update_many( where={"user_id": user_id}, data=update_data ) else: @@ -2986,7 +2989,7 @@ class SSOAuthenticationHandler: code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) try: - team_obj = await prisma_client.db.litellm_teamtable.find_first( + team_obj = await TeamRepository(prisma_client).table.find_first( where={"team_id": litellm_team_id} ) verbose_proxy_logger.debug(f"Team object: {team_obj}") diff --git a/litellm/proxy/management_endpoints/user_agent_analytics_endpoints.py b/litellm/proxy/management_endpoints/user_agent_analytics_endpoints.py index ebd276fbee..661487577c 100644 --- a/litellm/proxy/management_endpoints/user_agent_analytics_endpoints.py +++ b/litellm/proxy/management_endpoints/user_agent_analytics_endpoints.py @@ -19,6 +19,11 @@ from pydantic import BaseModel from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.repositories.table_repositories import DailyTagSpendRepository +from litellm.repositories.user_repository import UserRepository +from litellm.repositories.verification_token_repository import ( + VerificationTokenRepository, +) # Constants for analytics periods MAX_DAYS = 7 # Number of days 
to show in DAU analytics @@ -676,7 +681,7 @@ async def get_per_user_analytics( where_clause["tag"] = {"contains": tag_filter} # Get all tag records in the date range with optional tag filtering - tag_records = await prisma_client.db.litellm_dailytagspend.find_many( + tag_records = await DailyTagSpendRepository(prisma_client).table.find_many( where=where_clause ) @@ -693,9 +698,9 @@ async def get_per_user_analytics( ) # Lookup user_id for each api_key - api_key_records = await prisma_client.db.litellm_verificationtoken.find_many( - where={"token": {"in": list(api_keys)}} - ) + api_key_records = await VerificationTokenRepository( + prisma_client + ).table.find_many(where={"token": {"in": list(api_keys)}}) # Create mapping from api_key to user_id api_key_to_user_id = { @@ -704,7 +709,7 @@ async def get_per_user_analytics( # Get user emails for the user_ids user_ids = list(set(api_key_to_user_id.values())) - user_records = await prisma_client.db.litellm_usertable.find_many( + user_records = await UserRepository(prisma_client).table.find_many( where={"user_id": {"in": user_ids}} ) diff --git a/litellm/proxy/management_endpoints/workflow_management_endpoints.py b/litellm/proxy/management_endpoints/workflow_management_endpoints.py index a19af4dd48..57cc0dc674 100644 --- a/litellm/proxy/management_endpoints/workflow_management_endpoints.py +++ b/litellm/proxy/management_endpoints/workflow_management_endpoints.py @@ -27,6 +27,11 @@ from pydantic import BaseModel from litellm._logging import verbose_proxy_logger from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.repositories.table_repositories import ( + WorkflowEventRepository, + WorkflowMessageRepository, + WorkflowRunRepository, +) router = APIRouter() @@ -96,13 +101,13 @@ class WorkflowMessageCreateRequest(BaseModel): async def _get_next_sequence_number(prisma_client: Any, run_id: str, table: str) -> int: 
"""Return MAX(sequence_number) + 1 for the given run, for either events or messages.""" if table == "events": - rows = await prisma_client.db.litellm_workflowevent.find_many( + rows = await WorkflowEventRepository(prisma_client).table.find_many( where={"run_id": run_id}, order={"sequence_number": "desc"}, take=1, ) else: - rows = await prisma_client.db.litellm_workflowmessage.find_many( + rows = await WorkflowMessageRepository(prisma_client).table.find_many( where={"run_id": run_id}, order={"sequence_number": "desc"}, take=1, @@ -116,7 +121,7 @@ async def _require_run( user_api_key_dict: Optional[UserAPIKeyAuth] = None, ) -> Any: """Return the run or raise 404. For non-admin callers, also enforce key ownership.""" - run = await prisma_client.db.litellm_workflowrun.find_unique( + run = await WorkflowRunRepository(prisma_client).table.find_unique( where={"run_id": run_id} ) if run is None: @@ -163,7 +168,7 @@ async def create_workflow_run( create_data["input"] = _json(data.input) if data.metadata is not None: create_data["metadata"] = _json(data.metadata) - run = await prisma_client.db.litellm_workflowrun.create(data=create_data) + run = await WorkflowRunRepository(prisma_client).table.create(data=create_data) return run except Exception as e: verbose_proxy_logger.exception("Error creating workflow run: %s", e) @@ -206,7 +211,7 @@ async def list_workflow_runs( where["created_by"] = caller try: - runs = await prisma_client.db.litellm_workflowrun.find_many( + runs = await WorkflowRunRepository(prisma_client).table.find_many( where=where, order={"created_at": "desc"}, take=limit, @@ -235,7 +240,7 @@ async def get_workflow_run( ) try: - run = await prisma_client.db.litellm_workflowrun.find_unique( + run = await WorkflowRunRepository(prisma_client).table.find_unique( where={"run_id": run_id}, include={"events": {"order_by": {"sequence_number": "desc"}, "take": 1}}, ) @@ -286,7 +291,7 @@ async def update_workflow_run( await _require_run(prisma_client, run_id, 
user_api_key_dict) try: - run = await prisma_client.db.litellm_workflowrun.update( + run = await WorkflowRunRepository(prisma_client).table.update( where={"run_id": run_id}, data=update, ) @@ -391,7 +396,7 @@ async def list_workflow_events( await _require_run(prisma_client, run_id, user_api_key_dict) try: - events = await prisma_client.db.litellm_workflowevent.find_many( + events = await WorkflowEventRepository(prisma_client).table.find_many( where={"run_id": run_id}, order={"sequence_number": "asc"}, take=limit, @@ -436,7 +441,9 @@ async def append_workflow_message( } if data.session_id is not None: msg_data["session_id"] = data.session_id - msg = await prisma_client.db.litellm_workflowmessage.create(data=msg_data) + msg = await WorkflowMessageRepository(prisma_client).table.create( + data=msg_data + ) return msg except Exception as e: @@ -481,7 +488,7 @@ async def list_workflow_messages( await _require_run(prisma_client, run_id, user_api_key_dict) try: - messages = await prisma_client.db.litellm_workflowmessage.find_many( + messages = await WorkflowMessageRepository(prisma_client).table.find_many( where={"run_id": run_id}, order={"sequence_number": "asc"}, take=limit, diff --git a/litellm/proxy/management_helpers/audit_logs.py b/litellm/proxy/management_helpers/audit_logs.py index 439c3b2118..33599c3c62 100644 --- a/litellm/proxy/management_helpers/audit_logs.py +++ b/litellm/proxy/management_helpers/audit_logs.py @@ -18,6 +18,7 @@ from litellm.proxy._types import ( Optional, UserAPIKeyAuth, ) +from litellm.repositories.table_repositories import AuditLogRepository from litellm.types.utils import StandardAuditLogPayload _audit_log_callback_cache: Dict[str, CustomLogger] = {} @@ -244,7 +245,7 @@ async def create_audit_log_for_update(request_data: LiteLLM_AuditLogs): _request_data = request_data.model_dump(exclude_none=True) try: - await prisma_client.db.litellm_auditlog.create( + await AuditLogRepository(prisma_client).table.create( data={ **_request_data, # type: 
ignore } diff --git a/litellm/proxy/management_helpers/object_permission_utils.py b/litellm/proxy/management_helpers/object_permission_utils.py index 4c966b2541..f2ddae40d8 100644 --- a/litellm/proxy/management_helpers/object_permission_utils.py +++ b/litellm/proxy/management_helpers/object_permission_utils.py @@ -12,6 +12,8 @@ from litellm._logging import verbose_proxy_logger from litellm._uuid import uuid from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.proxy.utils import PrismaClient +from litellm.repositories.object_permission_repository import ObjectPermissionRepository +from litellm.repositories.table_repositories import MCPServerRepository if TYPE_CHECKING: from litellm.proxy._types import ( @@ -48,10 +50,10 @@ async def attach_object_permission_to_dict( object_permission_id = data_dict.get("object_permission_id") if object_permission_id: - object_permission = ( - await prisma_client.db.litellm_objectpermissiontable.find_unique( - where={"object_permission_id": object_permission_id}, - ) + object_permission = await ObjectPermissionRepository( + prisma_client + ).table.find_unique( + where={"object_permission_id": object_permission_id}, ) if object_permission: # Convert to dict if needed @@ -106,10 +108,10 @@ async def handle_update_object_permission_common( ) existing_object_permissions_dict: Dict = {} - existing_object_permission = ( - await prisma_client.db.litellm_objectpermissiontable.find_unique( - where={"object_permission_id": object_permission_id_to_use}, - ) + existing_object_permission = await ObjectPermissionRepository( + prisma_client + ).table.find_unique( + where={"object_permission_id": object_permission_id_to_use}, ) # Update the object permission @@ -137,14 +139,14 @@ async def handle_update_object_permission_common( ######################################################### # Commit the update to the LiteLLM_ObjectPermissionTable ######################################################### - 
created_object_permission_row = ( - await prisma_client.db.litellm_objectpermissiontable.upsert( - where={"object_permission_id": object_permission_id_to_use}, - data={ - "create": existing_object_permissions_dict, - "update": existing_object_permissions_dict, - }, - ) + created_object_permission_row = await ObjectPermissionRepository( + prisma_client + ).table.upsert( + where={"object_permission_id": object_permission_id_to_use}, + data={ + "create": existing_object_permissions_dict, + "update": existing_object_permissions_dict, + }, ) verbose_proxy_logger.debug( @@ -183,7 +185,7 @@ async def _set_object_permission( clean_data["mcp_tool_permissions"] ) - created_permission = await prisma_client.db.litellm_objectpermissiontable.create( + created_permission = await ObjectPermissionRepository(prisma_client).table.create( data=clean_data ) @@ -220,7 +222,7 @@ async def _get_db_mcp_servers_by_identifiers( return [] identifier_list = list(identifiers) - return await prisma_client.db.litellm_mcpservertable.find_many( + return await MCPServerRepository(prisma_client).table.find_many( where={ "OR": [ {"server_id": {"in": identifier_list}}, diff --git a/litellm/proxy/management_helpers/user_invitation.py b/litellm/proxy/management_helpers/user_invitation.py index d2d800aa77..babc920189 100644 --- a/litellm/proxy/management_helpers/user_invitation.py +++ b/litellm/proxy/management_helpers/user_invitation.py @@ -4,6 +4,7 @@ from fastapi import HTTPException import litellm from litellm.proxy._types import CommonProxyErrors, InvitationNew, UserAPIKeyAuth +from litellm.repositories.table_repositories import InvitationLinkRepository async def create_invitation_for_user( @@ -25,7 +26,7 @@ async def create_invitation_for_user( expires_at = current_time + timedelta(days=7) try: - response = await prisma_client.db.litellm_invitationlink.create( + response = await InvitationLinkRepository(prisma_client).table.create( data={ "user_id": data.user_id, "created_at": current_time, diff 
--git a/litellm/proxy/management_helpers/utils.py b/litellm/proxy/management_helpers/utils.py index 0b175db3c8..830d6f84b8 100644 --- a/litellm/proxy/management_helpers/utils.py +++ b/litellm/proxy/management_helpers/utils.py @@ -9,9 +9,8 @@ from pydantic import BaseModel import litellm from litellm._logging import verbose_logger -from litellm.integrations.otel.model.config import is_otel_v2_enabled from litellm._uuid import uuid -from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time +from litellm.integrations.otel.model.config import is_otel_v2_enabled from litellm.proxy._types import ( # key request types; user request types; team request types; customer request types BudgetNewRequest, DeleteCustomerRequest, @@ -32,7 +31,11 @@ from litellm.proxy._types import ( # key request types; user request types; tea VirtualKeyEvent, ) from litellm.proxy.common_utils.http_parsing_utils import _read_request_body +from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time from litellm.proxy.utils import PrismaClient +from litellm.repositories.budget_repository import BudgetRepository +from litellm.repositories.table_repositories import TeamMembershipRepository +from litellm.repositories.user_repository import UserRepository def get_new_internal_user_defaults( @@ -111,7 +114,7 @@ async def handle_budget_for_entity( budget_row.model_dump(exclude_none=True) ) - _budget = await prisma_client.db.litellm_budgettable.create( + _budget = await BudgetRepository(prisma_client).table.create( data={ **new_budget_data, # type: ignore "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, @@ -174,7 +177,7 @@ async def _clone_team_default_budget_for_member( so the member starts with the team default's values but gets their own private budget row (which can be edited independently). 
""" - default_budget = await prisma_client.db.litellm_budgettable.find_unique( + default_budget = await BudgetRepository(prisma_client).table.find_unique( where={"budget_id": default_team_budget_id} ) if default_budget is None: @@ -202,7 +205,7 @@ async def _clone_team_default_budget_for_member( cloned_data["budget_duration"] ) - new_budget = await prisma_client.db.litellm_budgettable.create(data=cloned_data) + new_budget = await BudgetRepository(prisma_client).table.create(data=cloned_data) return new_budget.budget_id @@ -229,7 +232,7 @@ async def add_new_member( ## ADD TEAM ID, to USER TABLE IF NEW ## if new_member.user_id is not None: new_user_defaults = get_new_internal_user_defaults(user_id=new_member.user_id) - _returned_user = await prisma_client.db.litellm_usertable.upsert( + _returned_user = await UserRepository(prisma_client).table.upsert( where={"user_id": new_member.user_id}, data={ "update": {"teams": {"push": [team_id]}}, @@ -259,7 +262,7 @@ async def add_new_member( returned_user = LiteLLM_UserTable(**_returned_user.model_dump()) elif len(existing_user_row) == 1: user_info = existing_user_row[0] - _returned_user = await prisma_client.db.litellm_usertable.update( + _returned_user = await UserRepository(prisma_client).table.update( where={"user_id": user_info.user_id}, # type: ignore data={"teams": {"push": [team_id]}}, ) @@ -284,7 +287,7 @@ async def add_new_member( budget_data["max_budget"] = max_budget_in_team if allowed_models is not None: budget_data["allowed_models"] = allowed_models - response = await prisma_client.db.litellm_budgettable.create(data=budget_data) + response = await BudgetRepository(prisma_client).table.create(data=budget_data) _budget_id = response.budget_id elif default_team_budget_id is not None: @@ -303,15 +306,15 @@ async def add_new_member( _budget_id = None if _budget_id and returned_user is not None and returned_user.user_id is not None: - _returned_team_membership = ( - await 
prisma_client.db.litellm_teammembership.create( - data={ - "team_id": team_id, - "user_id": returned_user.user_id, - "budget_id": _budget_id, - }, - include={"litellm_budget_table": True}, - ) + _returned_team_membership = await TeamMembershipRepository( + prisma_client + ).table.create( + data={ + "team_id": team_id, + "user_id": returned_user.user_id, + "budget_id": _budget_id, + }, + include={"litellm_budget_table": True}, ) returned_team_membership = LiteLLM_TeamMembership( diff --git a/litellm/proxy/memory/memory_endpoints.py b/litellm/proxy/memory/memory_endpoints.py index 4d161be426..6f1ca3196f 100644 --- a/litellm/proxy/memory/memory_endpoints.py +++ b/litellm/proxy/memory/memory_endpoints.py @@ -29,6 +29,8 @@ from litellm.proxy._types import ( UserAPIKeyAuth, ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.repositories.table_repositories import MemoryRepository +from litellm.repositories.team_repository import TeamRepository from litellm.types.memory_management import ( LiteLLM_MemoryRow, MemoryCreateRequest, @@ -173,7 +175,7 @@ async def _is_team_admin_for( ) try: - team_obj = await prisma_client.db.litellm_teamtable.find_unique( + team_obj = await TeamRepository(prisma_client).table.find_unique( where={"team_id": team_id} ) except Exception as e: @@ -304,7 +306,7 @@ async def create_memory( create_data["metadata"] = _serialize_metadata_for_prisma(body.metadata) try: - row = await prisma_client.db.litellm_memorytable.create(data=create_data) + row = await MemoryRepository(prisma_client).table.create(data=create_data) except Exception as e: # Key is globally unique. Any duplicate → 409. 
if _is_unique_violation(e): @@ -364,8 +366,8 @@ async def list_memory( where = {"AND": [key_filter, vis]} try: - total = await prisma_client.db.litellm_memorytable.count(where=where) - rows = await prisma_client.db.litellm_memorytable.find_many( + total = await MemoryRepository(prisma_client).table.count(where=where) + rows = await MemoryRepository(prisma_client).table.find_many( where=where, order={"updated_at": "desc"}, skip=(page - 1) * page_size, @@ -386,7 +388,7 @@ async def _find_memory_for_caller( key_filter: dict = {"key": key} vis = _visibility_filter(user_api_key_dict) where: dict = key_filter if vis is None else {"AND": [key_filter, vis]} - rows = await prisma_client.db.litellm_memorytable.find_many( + rows = await MemoryRepository(prisma_client).table.find_many( where=where, take=1, order={"updated_at": "desc"} ) if not rows: @@ -475,7 +477,7 @@ async def upsert_memory( # their team) — otherwise a teammate could overwrite a personal # entry through the OR-based visibility filter. await _assert_write_access(prisma_client, existing, user_api_key_dict) - row = await prisma_client.db.litellm_memorytable.update( + row = await MemoryRepository(prisma_client).table.update( where={"memory_id": existing.memory_id}, data=data, ) @@ -503,7 +505,7 @@ async def upsert_memory( if body.metadata is not None: create_data["metadata"] = _serialize_metadata_for_prisma(body.metadata) try: - row = await prisma_client.db.litellm_memorytable.create( + row = await MemoryRepository(prisma_client).table.create( data=create_data ) except Exception as e: @@ -524,7 +526,7 @@ async def upsert_memory( await _assert_write_access( prisma_client, existing_after_race, user_api_key_dict ) - row = await prisma_client.db.litellm_memorytable.update( + row = await MemoryRepository(prisma_client).table.update( where={"memory_id": existing_after_race.memory_id}, data=data, ) @@ -554,7 +556,7 @@ async def delete_memory( # Visibility != write authority — see the upsert handler for the rationale. 
await _assert_write_access(prisma_client, row, user_api_key_dict) try: - await prisma_client.db.litellm_memorytable.delete( + await MemoryRepository(prisma_client).table.delete( where={"memory_id": row.memory_id} ) except Exception as e: diff --git a/litellm/proxy/openai_files_endpoints/common_utils.py b/litellm/proxy/openai_files_endpoints/common_utils.py index cc0d06e4f4..b2834e5230 100644 --- a/litellm/proxy/openai_files_endpoints/common_utils.py +++ b/litellm/proxy/openai_files_endpoints/common_utils.py @@ -5,6 +5,10 @@ from dataclasses import dataclass, field from types import MappingProxyType from typing import TYPE_CHECKING, List, Literal, Optional, Union +from litellm.repositories.table_repositories import ( + ManagedFileRepository, + ManagedObjectRepository, +) from litellm.types.utils import SpecialEnums if TYPE_CHECKING: @@ -697,7 +701,7 @@ async def resolve_input_file_id_to_unified(response, prisma_client) -> None: and prisma_client ): try: - managed_file = await prisma_client.db.litellm_managedfiletable.find_first( + managed_file = await ManagedFileRepository(prisma_client).table.find_first( where={"flat_model_file_ids": {"has": response.input_file_id}} ) if managed_file: @@ -719,7 +723,7 @@ async def resolve_output_file_ids_to_unified(response, prisma_client) -> None: if not raw_id or _is_base64_encoded_unified_file_id(raw_id): continue try: - managed_file = await prisma_client.db.litellm_managedfiletable.find_first( + managed_file = await ManagedFileRepository(prisma_client).table.find_first( where={"flat_model_file_ids": {"has": raw_id}} ) if managed_file: @@ -821,6 +825,7 @@ async def get_batch_from_database( - response_batch: Parsed LiteLLMBatch object (or None) """ import json + from litellm.types.utils import LiteLLMBatch if managed_files_obj is None or not unified_batch_id: @@ -830,7 +835,7 @@ async def get_batch_from_database( if not prisma_client: return None, None - db_batch_object = await 
prisma_client.db.litellm_managedobjecttable.find_first( + db_batch_object = await ManagedObjectRepository(prisma_client).table.find_first( where={"unified_object_id": batch_id} ) @@ -942,7 +947,7 @@ async def update_batch_in_database( update_data["batch_processed"] = True try: - await prisma_client.db.litellm_managedobjecttable.update( + await ManagedObjectRepository(prisma_client).table.update( where={"unified_object_id": batch_id}, data=update_data, ) @@ -958,7 +963,7 @@ async def update_batch_in_database( f"batch_processed column not found, retrying update without it: {col_err}" ) update_data.pop("batch_processed", None) - await prisma_client.db.litellm_managedobjecttable.update( + await ManagedObjectRepository(prisma_client).table.update( where={"unified_object_id": batch_id}, data=update_data, ) diff --git a/litellm/proxy/openai_files_endpoints/files_endpoints.py b/litellm/proxy/openai_files_endpoints/files_endpoints.py index 378cbbda89..9eef7cd7e8 100644 --- a/litellm/proxy/openai_files_endpoints/files_endpoints.py +++ b/litellm/proxy/openai_files_endpoints/files_endpoints.py @@ -21,6 +21,7 @@ from fastapi import ( UploadFile, status, ) + import litellm from litellm import CreateFileRequest, get_secret_str from litellm._logging import verbose_proxy_logger @@ -37,15 +38,6 @@ from litellm.proxy.common_utils.openai_endpoint_utils import ( get_custom_llm_provider_from_request_headers, get_custom_llm_provider_from_request_query, ) -from litellm.proxy.utils import ProxyLogging, is_known_model -from litellm.router import Router -from litellm.types.llms.openai import ( - CREATE_FILE_REQUESTS_PURPOSE, - FileExpiresAfter, - OpenAIFileObject, - OpenAIFilesPurpose, -) - from litellm.proxy.openai_files_endpoints.common_utils import ( _is_base64_encoded_unified_file_id, encode_file_id_with_model, @@ -54,6 +46,15 @@ from litellm.proxy.openai_files_endpoints.common_utils import ( handle_model_based_routing, prepare_data_with_credentials, ) +from litellm.proxy.utils import 
ProxyLogging, is_known_model +from litellm.repositories.table_repositories import ManagedFileRepository +from litellm.router import Router +from litellm.types.llms.openai import ( + CREATE_FILE_REQUESTS_PURPOSE, + FileExpiresAfter, + OpenAIFileObject, + OpenAIFilesPurpose, +) router = APIRouter() @@ -666,7 +667,7 @@ async def get_file_content( # noqa: PLR0915 managed_files_obj, "prisma_client", None ): prisma_client = getattr(managed_files_obj, "prisma_client") - db_file = await prisma_client.db.litellm_managedfiletable.find_first( + db_file = await ManagedFileRepository(prisma_client).table.find_first( where={"unified_file_id": file_id} ) if db_file and db_file.storage_backend and db_file.storage_url: diff --git a/litellm/proxy/pass_through_endpoints/managed_id_rewriter.py b/litellm/proxy/pass_through_endpoints/managed_id_rewriter.py index a267c97c0e..9c0fbe30fc 100644 --- a/litellm/proxy/pass_through_endpoints/managed_id_rewriter.py +++ b/litellm/proxy/pass_through_endpoints/managed_id_rewriter.py @@ -43,6 +43,10 @@ from litellm.llms.base_llm.managed_resources.isolation import ( can_access_resource, ) from litellm.proxy._types import UserAPIKeyAuth +from litellm.repositories.table_repositories import ( + ManagedFileRepository, + ManagedObjectRepository, +) from litellm.types.llms.openai import OpenAIFileObject from .managed_id_codec import ManagedIdPayload, decode, is_managed, new_managed_id @@ -323,7 +327,7 @@ async def _resolve_one( ) if not found and prisma_client is not None: try: - db_row = await prisma_client.db.litellm_managedfiletable.find_first( + db_row = await ManagedFileRepository(prisma_client).table.find_first( where={"unified_file_id": managed_id} ) if db_row is not None: @@ -339,7 +343,7 @@ async def _resolve_one( # Object table (batches, responses) if prisma_client is not None: try: - obj_row = await prisma_client.db.litellm_managedobjecttable.find_first( + obj_row = await ManagedObjectRepository(prisma_client).table.find_first( 
where={"unified_object_id": managed_id} ) if obj_row is not None: @@ -399,7 +403,7 @@ async def _guard_raw_provider_id( # id and scope to the current provider in the application layer (same as # _mint_or_reuse_file's dedup). try: - candidates = await prisma_client.db.litellm_managedfiletable.find_many( + candidates = await ManagedFileRepository(prisma_client).table.find_many( where={"flat_model_file_ids": {"has": raw_id}}, ) except Exception: @@ -425,7 +429,7 @@ async def _guard_raw_provider_id( # Object rows store model_object_id as "passthrough:{provider}:{raw}", so # the lookup is exact and already provider-scoped. try: - existing = await prisma_client.db.litellm_managedobjecttable.find_first( + existing = await ManagedObjectRepository(prisma_client).table.find_first( where={"model_object_id": f"passthrough:{provider}:{raw_id}"} ) except Exception: @@ -492,7 +496,7 @@ async def _mint_or_reuse_file( # reuse a stable row instead of minting duplicate rows on every call. if prisma_client is not None: try: - candidates = await prisma_client.db.litellm_managedfiletable.find_many( + candidates = await ManagedFileRepository(prisma_client).table.find_many( where={"flat_model_file_ids": {"has": raw_id}}, order={"created_at": "asc"}, ) @@ -627,7 +631,7 @@ async def _mint_or_reuse_object( # the batch's latest state (e.g. output_file_id / error_file_id that # were null at creation but populated once the batch completed). try: - await prisma_client.db.litellm_managedobjecttable.update( + await ManagedObjectRepository(prisma_client).table.update( where={"unified_object_id": existing.unified_object_id}, data={ "file_object": json.dumps(body_snapshot), @@ -647,7 +651,7 @@ async def _mint_or_reuse_object( # Dedup: look up by the namespaced key — guaranteed unique per provider. 
try: - existing = await prisma_client.db.litellm_managedobjecttable.find_first( + existing = await ManagedObjectRepository(prisma_client).table.find_first( where={"model_object_id": namespaced_model_object_id} ) except Exception: @@ -666,7 +670,7 @@ async def _mint_or_reuse_object( raw_id.split("_", 1)[0], ) try: - await prisma_client.db.litellm_managedobjecttable.upsert( + await ManagedObjectRepository(prisma_client).table.upsert( where={"unified_object_id": managed_id}, data={ "create": { @@ -690,7 +694,7 @@ async def _mint_or_reuse_object( # the winner's managed ID so both callers converge on one ID instead of # the loser silently keeping the raw id. try: - raced = await prisma_client.db.litellm_managedobjecttable.find_first( + raced = await ManagedObjectRepository(prisma_client).table.find_first( where={"model_object_id": namespaced_model_object_id} ) except Exception: @@ -883,9 +887,9 @@ async def _build_list_where_with_cursor( return where, fetch_order cursor_table = ( - prisma_client.db.litellm_managedfiletable + ManagedFileRepository(prisma_client).table if resource_kind == "files" - else prisma_client.db.litellm_managedobjecttable + else ManagedObjectRepository(prisma_client).table ) cursor_field = ( "unified_file_id" if resource_kind == "files" else "unified_object_id" @@ -932,12 +936,12 @@ async def _fetch_list_rows( # across rows that share a created_at timestamp. 
try: if resource_kind == "files": - return await prisma_client.db.litellm_managedfiletable.find_many( + return await ManagedFileRepository(prisma_client).table.find_many( where=where, order=[{"created_at": fetch_order}, {"unified_file_id": fetch_order}], take=fetch_limit, ) - return await prisma_client.db.litellm_managedobjecttable.find_many( + return await ManagedObjectRepository(prisma_client).table.find_many( where={**where, "file_purpose": "batch"}, order=[{"created_at": fetch_order}, {"unified_object_id": fetch_order}], take=fetch_limit, diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 6667010447..45e264b1cd 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -60,12 +60,13 @@ from litellm.proxy.common_utils.http_parsing_utils import ( ) from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup from litellm.proxy.utils import normalize_route_for_root_path +from litellm.repositories.team_repository import TeamRepository from litellm.secret_managers.main import get_secret_str from litellm.types.llms.custom_http import httpxSpecialProvider from litellm.types.passthrough_endpoints.pass_through_endpoints import ( - EndpointType, LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY, LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY, + EndpointType, PassthroughStandardLoggingPayload, ) @@ -1002,9 +1003,7 @@ async def pass_through_request( # noqa: PLR0915 is_passthrough_list_route, list_passthrough_ids_from_db, ) - from litellm.proxy.proxy_server import ( - prisma_client as _list_prisma, - ) + from litellm.proxy.proxy_server import prisma_client as _list_prisma if ( is_passthrough_list_route( @@ -2875,7 +2874,7 @@ async def _filter_endpoints_by_team_allowed_routes( HTTPException: If team is not found """ # retrieve team from db - team = await 
prisma_client.db.litellm_teamtable.find_unique( + team = await TeamRepository(prisma_client).table.find_unique( where={"team_id": team_id}, ) if team is None: diff --git a/litellm/proxy/policy_engine/attachment_registry.py b/litellm/proxy/policy_engine/attachment_registry.py index 8d5d811691..fb1e2652e8 100644 --- a/litellm/proxy/policy_engine/attachment_registry.py +++ b/litellm/proxy/policy_engine/attachment_registry.py @@ -9,6 +9,7 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Dict, List, Optional from litellm._logging import verbose_proxy_logger +from litellm.repositories.table_repositories import PolicyAttachmentRepository from litellm.types.proxy.policy_engine import ( PolicyAttachment, PolicyAttachmentCreateRequest, @@ -278,21 +279,21 @@ class AttachmentRegistry: PolicyAttachmentDBResponse with the created attachment """ try: - created_attachment = ( - await prisma_client.db.litellm_policyattachmenttable.create( - data={ - "policy_name": attachment_request.policy_name, - "scope": attachment_request.scope, - "teams": attachment_request.teams or [], - "keys": attachment_request.keys or [], - "models": attachment_request.models or [], - "tags": attachment_request.tags or [], - "created_at": datetime.now(timezone.utc), - "updated_at": datetime.now(timezone.utc), - "created_by": created_by, - "updated_by": created_by, - } - ) + created_attachment = await PolicyAttachmentRepository( + prisma_client + ).table.create( + data={ + "policy_name": attachment_request.policy_name, + "scope": attachment_request.scope, + "teams": attachment_request.teams or [], + "keys": attachment_request.keys or [], + "models": attachment_request.models or [], + "tags": attachment_request.tags or [], + "created_at": datetime.now(timezone.utc), + "updated_at": datetime.now(timezone.utc), + "created_by": created_by, + "updated_by": created_by, + } ) # Also add to in-memory registry @@ -340,17 +341,15 @@ class AttachmentRegistry: """ try: # Get attachment 
before deleting - attachment = ( - await prisma_client.db.litellm_policyattachmenttable.find_unique( - where={"attachment_id": attachment_id} - ) - ) + attachment = await PolicyAttachmentRepository( + prisma_client + ).table.find_unique(where={"attachment_id": attachment_id}) if attachment is None: raise Exception(f"Attachment with ID {attachment_id} not found") # Delete from DB - await prisma_client.db.litellm_policyattachmenttable.delete( + await PolicyAttachmentRepository(prisma_client).table.delete( where={"attachment_id": attachment_id} ) @@ -379,11 +378,9 @@ class AttachmentRegistry: PolicyAttachmentDBResponse if found, None otherwise """ try: - attachment = ( - await prisma_client.db.litellm_policyattachmenttable.find_unique( - where={"attachment_id": attachment_id} - ) - ) + attachment = await PolicyAttachmentRepository( + prisma_client + ).table.find_unique(where={"attachment_id": attachment_id}) if attachment is None: return None @@ -419,10 +416,10 @@ class AttachmentRegistry: List of PolicyAttachmentDBResponse objects """ try: - attachments = ( - await prisma_client.db.litellm_policyattachmenttable.find_many( - order={"created_at": "desc"}, - ) + attachments = await PolicyAttachmentRepository( + prisma_client + ).table.find_many( + order={"created_at": "desc"}, ) return [ diff --git a/litellm/proxy/policy_engine/policy_registry.py b/litellm/proxy/policy_engine/policy_registry.py index 75017c4660..d626551626 100644 --- a/litellm/proxy/policy_engine/policy_registry.py +++ b/litellm/proxy/policy_engine/policy_registry.py @@ -12,6 +12,7 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from litellm._logging import verbose_proxy_logger +from litellm.repositories.table_repositories import PolicyRepository from litellm.types.proxy.policy_engine import ( GuardrailPipeline, PipelineStep, @@ -295,7 +296,7 @@ class PolicyRegistry: validated_pipeline = GuardrailPipeline(**policy_request.pipeline) 
data["pipeline"] = json.dumps(validated_pipeline.model_dump()) - created_policy = await prisma_client.db.litellm_policytable.create( + created_policy = await PolicyRepository(prisma_client).table.create( data=data ) @@ -347,7 +348,7 @@ class PolicyRegistry: Exception: If policy is not in draft status (only drafts are editable). """ try: - existing = await prisma_client.db.litellm_policytable.find_unique( + existing = await PolicyRepository(prisma_client).table.find_unique( where={"policy_id": policy_id} ) if existing is None: @@ -382,7 +383,7 @@ class PolicyRegistry: validated_pipeline = GuardrailPipeline(**policy_request.pipeline) update_data["pipeline"] = json.dumps(validated_pipeline.model_dump()) - updated_policy = await prisma_client.db.litellm_policytable.update( + updated_policy = await PolicyRepository(prisma_client).table.update( where={"policy_id": policy_id}, data=update_data, ) @@ -413,7 +414,7 @@ class PolicyRegistry: Dict with "message" and optional "warning" if production was deleted. 
""" try: - policy = await prisma_client.db.litellm_policytable.find_unique( + policy = await PolicyRepository(prisma_client).table.find_unique( where={"policy_id": policy_id} ) @@ -424,7 +425,7 @@ class PolicyRegistry: policy_name = policy.policy_name # Delete from DB - await prisma_client.db.litellm_policytable.delete( + await PolicyRepository(prisma_client).table.delete( where={"policy_id": policy_id} ) @@ -461,7 +462,7 @@ class PolicyRegistry: PolicyDBResponse if found, None otherwise """ try: - policy = await prisma_client.db.litellm_policytable.find_unique( + policy = await PolicyRepository(prisma_client).table.find_unique( where={"policy_id": policy_id} ) @@ -512,7 +513,7 @@ class PolicyRegistry: if version_status is not None: where["version_status"] = version_status - policies = await prisma_client.db.litellm_policytable.find_many( + policies = await PolicyRepository(prisma_client).table.find_many( where=where if where else None, order={"created_at": "desc"}, ) @@ -554,7 +555,7 @@ class PolicyRegistry: self.add_policy(policy_response.policy_name, policy) self._policies_by_id = {} - non_production = await prisma_client.db.litellm_policytable.find_many( + non_production = await PolicyRepository(prisma_client).table.find_many( where={"version_status": {"in": ["draft", "published"]}}, order={"created_at": "desc"}, ) @@ -654,7 +655,7 @@ class PolicyRegistry: PolicyVersionListResponse with policy_name and list of versions """ try: - rows = await prisma_client.db.litellm_policytable.find_many( + rows = await PolicyRepository(prisma_client).table.find_many( where={"policy_name": policy_name}, order={"version_number": "desc"}, ) @@ -690,7 +691,7 @@ class PolicyRegistry: """ try: if source_policy_id is not None: - source = await prisma_client.db.litellm_policytable.find_unique( + source = await PolicyRepository(prisma_client).table.find_unique( where={"policy_id": source_policy_id} ) if source is None: @@ -701,7 +702,7 @@ class PolicyRegistry: ) else: # Find current 
production version for this policy_name - prod = await prisma_client.db.litellm_policytable.find_first( + prod = await PolicyRepository(prisma_client).table.find_first( where={ "policy_name": policy_name, "version_status": "production", @@ -714,7 +715,7 @@ class PolicyRegistry: source = prod # Next version number - latest = await prisma_client.db.litellm_policytable.find_first( + latest = await PolicyRepository(prisma_client).table.find_first( where={"policy_name": policy_name}, order={"version_number": "desc"}, ) @@ -722,7 +723,7 @@ class PolicyRegistry: now = datetime.now(timezone.utc) # Set is_latest=False on all existing versions for this policy_name - await prisma_client.db.litellm_policytable.update_many( + await PolicyRepository(prisma_client).table.update_many( where={"policy_name": policy_name}, data={"is_latest": False}, ) @@ -758,7 +759,7 @@ class PolicyRegistry: else source.pipeline ) - created = await prisma_client.db.litellm_policytable.create(data=data) + created = await PolicyRepository(prisma_client).table.create(data=data) return _row_to_policy_db_response(created) except Exception as e: verbose_proxy_logger.exception(f"Error creating new version: {e}") @@ -794,7 +795,7 @@ class PolicyRegistry: f"Invalid status '{new_status}'. Use 'published' or 'production'." ) - row = await prisma_client.db.litellm_policytable.find_unique( + row = await PolicyRepository(prisma_client).table.find_unique( where={"policy_id": policy_id} ) if row is None: @@ -809,7 +810,7 @@ class PolicyRegistry: raise Exception( f"Only draft versions can be published. Current status: '{current}'." 
) - updated = await prisma_client.db.litellm_policytable.update( + updated = await PolicyRepository(prisma_client).table.update( where={"policy_id": policy_id}, data={ "version_status": "published", @@ -832,7 +833,7 @@ class PolicyRegistry: ) # Demote current production to published - await prisma_client.db.litellm_policytable.update_many( + await PolicyRepository(prisma_client).table.update_many( where={ "policy_name": policy_name, "version_status": "production", @@ -845,7 +846,7 @@ class PolicyRegistry: ) # Promote this version to production - updated = await prisma_client.db.litellm_policytable.update( + updated = await PolicyRepository(prisma_client).table.update( where={"policy_id": policy_id}, data={ "version_status": "production", @@ -895,10 +896,10 @@ class PolicyRegistry: PolicyVersionCompareResponse with both versions and field_diffs """ try: - a = await prisma_client.db.litellm_policytable.find_unique( + a = await PolicyRepository(prisma_client).table.find_unique( where={"policy_id": policy_id_a} ) - b = await prisma_client.db.litellm_policytable.find_unique( + b = await PolicyRepository(prisma_client).table.find_unique( where={"policy_id": policy_id_b} ) if a is None: @@ -950,7 +951,7 @@ class PolicyRegistry: Dict with success message """ try: - await prisma_client.db.litellm_policytable.delete_many( + await PolicyRepository(prisma_client).table.delete_many( where={"policy_name": policy_name} ) self.remove_policy(policy_name) diff --git a/litellm/proxy/policy_engine/policy_resolve_endpoints.py b/litellm/proxy/policy_engine/policy_resolve_endpoints.py index 54374d90a1..84dcbcfd74 100644 --- a/litellm/proxy/policy_engine/policy_resolve_endpoints.py +++ b/litellm/proxy/policy_engine/policy_resolve_endpoints.py @@ -16,6 +16,10 @@ from litellm.proxy.auth.route_checks import RouteChecks from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry from 
litellm.proxy.policy_engine.policy_registry import get_policy_registry +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.verification_token_repository import ( + VerificationTokenRepository, +) from litellm.types.proxy.policy_engine import ( AttachmentImpactResponse, PolicyAttachmentCreateRequest, @@ -76,7 +80,7 @@ def _get_tags_from_metadata(metadata: object, json_metadata: object = None) -> l async def _fetch_all_teams(prisma_client: object) -> list: """Fetch teams from DB once. Reuse the result across tag and alias lookups.""" - return await prisma_client.db.litellm_teamtable.find_many( # type: ignore + return await TeamRepository(prisma_client).table.find_many( # type: ignore where={}, order={"created_at": "desc"}, take=MAX_POLICY_ESTIMATE_IMPACT_ROWS, @@ -159,7 +163,7 @@ async def _find_affected_by_team_patterns( new_keys: list = [] unnamed_keys_count = 0 if matched_team_ids: - keys = await prisma_client.db.litellm_verificationtoken.find_many( # type: ignore + keys = await VerificationTokenRepository(prisma_client).table.find_many( # type: ignore where={"team_id": {"in": matched_team_ids}}, order={"created_at": "desc"}, take=MAX_POLICY_ESTIMATE_IMPACT_ROWS, @@ -182,7 +186,7 @@ async def _find_affected_keys_by_alias( affected: list = [] - keys = await prisma_client.db.litellm_verificationtoken.find_many( # type: ignore + keys = await VerificationTokenRepository(prisma_client).table.find_many( # type: ignore where=_build_alias_where("key_alias", key_patterns), order={"created_at": "desc"}, take=MAX_POLICY_ESTIMATE_IMPACT_ROWS, @@ -367,7 +371,7 @@ async def estimate_attachment_impact( # Tag-based impact if tag_patterns: - keys = await prisma_client.db.litellm_verificationtoken.find_many( # type: ignore + keys = await VerificationTokenRepository(prisma_client).table.find_many( # type: ignore where={}, order={"created_at": "desc"}, take=MAX_POLICY_ESTIMATE_IMPACT_ROWS, diff --git 
a/litellm/proxy/policy_engine/policy_validator.py b/litellm/proxy/policy_engine/policy_validator.py index b587e3432b..46796fbae2 100644 --- a/litellm/proxy/policy_engine/policy_validator.py +++ b/litellm/proxy/policy_engine/policy_validator.py @@ -12,6 +12,10 @@ Validates: from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set from litellm._logging import verbose_proxy_logger +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.verification_token_repository import ( + VerificationTokenRepository, +) from litellm.types.proxy.policy_engine import ( Policy, PolicyValidationError, @@ -95,7 +99,7 @@ class PolicyValidator: return True # Can't validate without DB, assume valid try: - team = await self.prisma_client.db.litellm_teamtable.find_first( + team = await TeamRepository(self.prisma_client).table.find_first( where={"team_alias": team_alias}, ) return team is not None @@ -119,7 +123,9 @@ class PolicyValidator: return True # Can't validate without DB, assume valid try: - key = await self.prisma_client.db.litellm_verificationtoken.find_first( + key = await VerificationTokenRepository( + self.prisma_client + ).table.find_first( where={"key_alias": key_alias}, ) return key is not None diff --git a/litellm/proxy/prompts/prompt_endpoints.py b/litellm/proxy/prompts/prompt_endpoints.py index 399a0ff3af..c0d6794108 100644 --- a/litellm/proxy/prompts/prompt_endpoints.py +++ b/litellm/proxy/prompts/prompt_endpoints.py @@ -22,6 +22,7 @@ from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKey from litellm.proxy.auth.auth_utils import is_request_body_safe from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.common_utils.path_utils import safe_filename +from litellm.repositories.table_repositories import PromptRepository from litellm.types.prompts.init_prompts import ( ListPromptsResponse, PromptInfo, @@ -208,7 +209,7 @@ async def get_next_version_for_prompt( Returns: Next 
version number (1 if no versions exist, max_version + 1 otherwise) """ - existing_prompts = await prisma_client.db.litellm_prompttable.find_many( + existing_prompts = await PromptRepository(prisma_client).table.find_many( where={"prompt_id": prompt_id, "environment": environment} ) @@ -441,7 +442,7 @@ async def get_prompt_versions( where_clause: Dict[str, Any] = {"prompt_id": base_prompt_id} if environment: where_clause["environment"] = environment - db_prompts = await prisma_client.db.litellm_prompttable.find_many( + db_prompts = await PromptRepository(prisma_client).table.find_many( where=where_clause, order={"version": "desc"}, ) @@ -612,7 +613,7 @@ async def get_prompt_info( # Query all environments this prompt exists in (lightweight: distinct on environment) all_environments: List[str] = [] if prisma_client is not None: - all_prompt_rows = await prisma_client.db.litellm_prompttable.find_many( + all_prompt_rows = await PromptRepository(prisma_client).table.find_many( where={"prompt_id": base_prompt_id}, distinct=["environment"], ) @@ -634,7 +635,7 @@ async def get_prompt_info( } if requested_version is not None: where_clause["version"] = requested_version - env_prompts = await prisma_client.db.litellm_prompttable.find_many( + env_prompts = await PromptRepository(prisma_client).table.find_many( where=where_clause, order={"version": "desc"}, take=1, @@ -752,7 +753,7 @@ async def create_prompt( ) # Store prompt in db with version - prompt_db_entry = await prisma_client.db.litellm_prompttable.create( + prompt_db_entry = await PromptRepository(prisma_client).table.create( data={ "prompt_id": request.prompt_id, "version": new_version, @@ -848,7 +849,7 @@ async def update_prompt( ) # Check if any version of this prompt exists (in any environment) - existing_prompts = await prisma_client.db.litellm_prompttable.find_many( + existing_prompts = await PromptRepository(prisma_client).table.find_many( where={"prompt_id": base_prompt_id} ) @@ -877,7 +878,7 @@ async def 
update_prompt( ) # Store new version in db - prompt_db_entry = await prisma_client.db.litellm_prompttable.create( + prompt_db_entry = await PromptRepository(prisma_client).table.create( data={ "prompt_id": base_prompt_id, "version": new_version, @@ -993,7 +994,7 @@ async def delete_prompt( delete_where["environment"] = environment # Delete versions from the database (scoped to environment if provided) - await prisma_client.db.litellm_prompttable.delete_many(where=delete_where) + await PromptRepository(prisma_client).table.delete_many(where=delete_where) # Remove matching prompts from memory — scope to environment if provided if environment: @@ -1105,7 +1106,7 @@ async def patch_prompt( if requested_version is not None: find_where["version"] = requested_version - db_rows = await prisma_client.db.litellm_prompttable.find_many( + db_rows = await PromptRepository(prisma_client).table.find_many( where=find_where, order={"version": "desc"}, take=1, @@ -1163,7 +1164,7 @@ async def patch_prompt( update_data["created_by"] = user_api_key_dict.user_id # Update by primary key (id) to target exactly one row - updated_prompt_db_entry = await prisma_client.db.litellm_prompttable.update( + updated_prompt_db_entry = await PromptRepository(prisma_client).table.update( where={"id": target_row.id}, data=update_data, ) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 72423b2a79..2c21e19dce 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -48,6 +48,7 @@ from litellm.constants import ( AIOHTTP_TTL_DNS_CACHE, AUDIO_SPEECH_CHUNK_SIZE, BASE_MCP_ROUTE, + DAILY_TAG_SPEND_BATCH_MULTIPLIER, DEFAULT_MAX_RECURSE_DEPTH, DEFAULT_SHARED_HEALTH_CHECK_LOCK_TTL, DEFAULT_SHARED_HEALTH_CHECK_TTL, @@ -56,13 +57,13 @@ from litellm.constants import ( LITELLM_SETTINGS_SAFE_DB_OVERRIDES, LITELLM_UI_ALLOW_HEADERS, LITELLM_UI_SESSION_DURATION, - DAILY_TAG_SPEND_BATCH_MULTIPLIER, ) from litellm.litellm_core_utils.litellm_logging import ( 
_init_custom_logger_compatible_class, ) from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.proxy._types import ( + UI_TEAM_ID, CallbackDelete, CallInfo, CommonProxyErrors, @@ -79,8 +80,8 @@ from litellm.proxy._types import ( InvitationModel, InvitationNew, InvitationUpdate, - Litellm_EntityType, LiteLLM_EndUserTable, + Litellm_EntityType, LiteLLM_JWTAuth, LiteLLM_TagTable, LiteLLM_TeamTable, @@ -96,7 +97,6 @@ from litellm.proxy._types import ( TeamDefaultSettings, TokenCountRequest, TransformRequestBody, - UI_TEAM_ID, UserAPIKeyAuth, ) from litellm.proxy.common_utils.cache_pydantic_utils import CacheCodec @@ -212,8 +212,6 @@ from litellm import Router from litellm._logging import verbose_proxy_logger, verbose_router_logger from litellm.caching.caching import DualCache, RedisCache from litellm.caching.redis_cluster_cache import RedisClusterCache -from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time -from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache from litellm.constants import ( _REALTIME_BODY_CACHE_SIZE, APSCHEDULER_COALESCE, @@ -247,8 +245,8 @@ from litellm.litellm_core_utils.sensitive_data_masker import ( ) from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.vertex_ai.vertex_llm_base import VertexBase -from litellm.proxy._types import * from litellm.proxy._lazy_features import attach_lazy_features +from litellm.proxy._types import * from litellm.proxy.analytics_endpoints.analytics_endpoints import ( router as analytics_router, ) @@ -308,6 +306,8 @@ from litellm.proxy.common_utils.openai_endpoint_utils import ( from litellm.proxy.common_utils.proxy_state import ProxyState from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob from litellm.proxy.common_utils.swagger_utils import ERROR_RESPONSES +from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time +from litellm.proxy.common_utils.user_api_key_cache import 
UserApiKeyCache from litellm.proxy.container_endpoints.endpoints import router as container_router from litellm.proxy.credential_endpoints.endpoints import router as credential_router from litellm.proxy.db.db_transaction_queue.spend_log_cleanup import SpendLogCleanup @@ -361,7 +361,9 @@ from litellm.proxy.management_endpoints.fallback_management_endpoints import ( from litellm.proxy.management_endpoints.internal_user_endpoints import ( router as internal_user_router, ) -from litellm.proxy.management_endpoints.internal_user_endpoints import user_update +from litellm.proxy.management_endpoints.internal_user_endpoints import ( + user_update, +) from litellm.proxy.management_endpoints.key_management_endpoints import ( delete_verification_tokens, duration_in_seconds, @@ -398,10 +400,6 @@ from litellm.proxy.management_endpoints.team_endpoints import ( update_team, validate_membership, ) -from litellm.proxy.management_endpoints.workflow_management_endpoints import ( - router as workflow_management_router, -) -from litellm.proxy.memory.memory_endpoints import router as memory_router from litellm.proxy.management_endpoints.ui_sso import ( get_disabled_non_admin_personal_key_creation, ) @@ -409,7 +407,11 @@ from litellm.proxy.management_endpoints.ui_sso import router as ui_sso_router from litellm.proxy.management_endpoints.user_agent_analytics_endpoints import ( router as user_agent_analytics_router, ) +from litellm.proxy.management_endpoints.workflow_management_endpoints import ( + router as workflow_management_router, +) from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update +from litellm.proxy.memory.memory_endpoints import router as memory_router from litellm.proxy.middleware.in_flight_requests_middleware import ( InFlightRequestsMiddleware, ) @@ -417,12 +419,13 @@ from litellm.proxy.middleware.prometheus_auth_middleware import PrometheusAuthMi from litellm.proxy.middleware.request_size_limit_middleware import ( RequestSizeLimitMiddleware, ) 
-from litellm.proxy.shutdown.graceful_shutdown_manager import GracefulShutdownManager from litellm.proxy.ocr_endpoints.endpoints import router as ocr_router from litellm.proxy.openai_files_endpoints.files_endpoints import ( router as openai_files_router, ) -from litellm.proxy.openai_files_endpoints.files_endpoints import set_files_config +from litellm.proxy.openai_files_endpoints.files_endpoints import ( + set_files_config, +) from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( passthrough_endpoint_router, ) @@ -444,6 +447,7 @@ from litellm.proxy.rerank_endpoints.endpoints import router as rerank_router from litellm.proxy.response_api_endpoints.endpoints import router as response_router from litellm.proxy.route_llm_request import route_request from litellm.proxy.search_endpoints.endpoints import router as search_router +from litellm.proxy.shutdown.graceful_shutdown_manager import GracefulShutdownManager from litellm.proxy.spend_tracking.spend_management_endpoints import ( router as spend_management_router, ) @@ -478,6 +482,7 @@ from litellm.proxy.utils import ( update_spend, ) from litellm.proxy.video_endpoints.endpoints import router as video_router +from litellm.repositories.credentials_repository import CredentialsRepository from litellm.router import ( AssistantsTypedDict, Deployment, @@ -511,7 +516,9 @@ from litellm.types.proxy.management_endpoints.ui_sso import ( LiteLLM_UpperboundKeyGenerateParams, ) from litellm.types.realtime import RealtimeQueryParams -from litellm.types.router import DeploymentTypedDict +from litellm.types.router import ( + DeploymentTypedDict, +) from litellm.types.router import ModelInfo as RouterModelInfo from litellm.types.router import ( RouterGeneralSettings, @@ -4248,13 +4255,13 @@ class ProxyConfig: ) setattr(litellm, key, value) if key in {"s3_audit_callback_params", "s3_callback_params"}: - from litellm.proxy.management_helpers.audit_logs import ( - reset_audit_log_callback_cache, - ) + from 
litellm.integrations.s3_v2 import S3Logger as S3V2Logger from litellm.litellm_core_utils.litellm_logging import ( _in_memory_loggers, ) - from litellm.integrations.s3_v2 import S3Logger as S3V2Logger + from litellm.proxy.management_helpers.audit_logs import ( + reset_audit_log_callback_cache, + ) reset_audit_log_callback_cache() _in_memory_loggers[:] = [ @@ -5335,7 +5342,7 @@ class ProxyConfig: 4. Update router settings """ if llm_router is not None and prisma_client is not None: - db_router_settings = await prisma_client.db.litellm_config.find_first( + db_router_settings = await ConfigRepository(prisma_client).table.find_first( where={"param_name": "router_settings"} ) @@ -5761,7 +5768,7 @@ class ProxyConfig: async def _get_models_from_db(self, prisma_client: PrismaClient) -> list: try: - new_models = await prisma_client.db.litellm_proxymodeltable.find_many() + new_models = await ModelRepository(prisma_client).table.find_many() except Exception as e: verbose_proxy_logger.exception( "litellm.proxy_server.py::add_deployment() - Error getting new models from DB - {}".format( @@ -5975,7 +5982,7 @@ class ProxyConfig: """ try: - sso_settings = await prisma_client.db.litellm_ssoconfig.find_unique( + sso_settings = await SSOConfigRepository(prisma_client).table.find_unique( where={"id": "sso_config"} ) if sso_settings is not None: @@ -6011,9 +6018,9 @@ class ProxyConfig: ) try: - db_record = await prisma_client.db.litellm_configoverrides.find_unique( - where={"config_type": "hashicorp_vault"} - ) + db_record = await ConfigOverridesRepository( + prisma_client + ).table.find_unique(where={"config_type": "hashicorp_vault"}) if db_record is None or db_record.config_value is None: if self._last_hashicorp_vault_config is not None: @@ -6130,7 +6137,7 @@ class ProxyConfig: last_model_cost_map_reload = current_time.isoformat() # Clear force reload flag in database - await prisma_client.db.litellm_config.upsert( + await ConfigRepository(prisma_client).table.upsert( 
where={"param_name": "model_cost_map_reload_config"}, data={ "create": { @@ -6239,7 +6246,7 @@ class ProxyConfig: last_anthropic_beta_headers_reload = current_time.isoformat() # Clear force reload flag in database - await prisma_client.db.litellm_config.upsert( + await ConfigRepository(prisma_client).table.upsert( where={"param_name": "anthropic_beta_headers_reload_config"}, data={ "create": { @@ -6299,7 +6306,7 @@ class ProxyConfig: from litellm.types.prompts.init_prompts import PromptSpec try: - prompts_in_db = await prisma_client.db.litellm_prompttable.find_many() + prompts_in_db = await PromptRepository(prisma_client).table.find_many() for prompt in prompts_in_db: # Convert DB object to dict and create versioned prompt_id prompt_spec = self._get_prompt_spec_for_db_prompt(db_prompt=prompt) @@ -6589,7 +6596,7 @@ class ProxyConfig: async def get_credentials(self, prisma_client: PrismaClient): try: - credentials = await prisma_client.db.litellm_credentialstable.find_many() + credentials = await CredentialsRepository(prisma_client).find_all() credentials = [self.decrypt_credentials(cred) for cred in credentials] await self.delete_credentials( credentials @@ -7352,7 +7359,7 @@ class ProxyStartupEvent: # spend cap blocks forever once it's hit. if prisma_client is not None and litellm.budget_duration is not None: try: - await prisma_client.db.litellm_usertable.update_many( + await UserRepository(prisma_client).table.update_many( where={ "user_id": litellm_proxy_budget_name, "budget_reset_at": None, @@ -7420,7 +7427,7 @@ class ProxyStartupEvent: if prisma_client is None: return - db_record = await prisma_client.db.litellm_uisettings.find_unique( + db_record = await UISettingsRepository(prisma_client).table.find_unique( where={"id": "ui_settings"} ) if db_record and db_record.ui_settings: @@ -7569,7 +7576,7 @@ class ProxyStartupEvent: # but YAML config has False. 
if store_model_in_db is not True and prisma_client is not None: try: - _db_gs_record = await prisma_client.db.litellm_config.find_first( + _db_gs_record = await ConfigRepository(prisma_client).table.find_first( where={"param_name": "general_settings"} ) if _db_gs_record is not None and isinstance( @@ -10361,6 +10368,18 @@ async def run_thread( # ) # async def get_available_routes(user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)): from litellm.llms.base_llm.base_utils import BaseTokenCounter +from litellm.repositories.config_repository import ConfigRepository +from litellm.repositories.model_repository import ModelRepository +from litellm.repositories.table_repositories import ( + AccessGroupRepository, + ConfigOverridesRepository, + InvitationLinkRepository, + PromptRepository, + SSOConfigRepository, + UISettingsRepository, +) +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.user_repository import UserRepository def _get_provider_token_counter( @@ -10666,7 +10685,7 @@ async def _check_if_model_is_user_added( id = model.get("model_info", {}).get("id", None) if id is None: continue - db_model = await prisma_client.db.litellm_proxymodeltable.find_unique( + db_model = await ModelRepository(prisma_client).table.find_unique( where={"model_id": id} ) if db_model is not None: @@ -10723,7 +10742,7 @@ async def non_admin_all_models( if user_api_key_dict.user_id: try: - user_row = await prisma_client.db.litellm_usertable.find_unique( + user_row = await UserRepository(prisma_client).table.find_unique( where={"user_id": user_api_key_dict.user_id} ) except Exception: @@ -10823,7 +10842,7 @@ async def _add_access_group_models_to_team_models( return team_models # Single batch fetch for all access groups - access_group_rows = await prisma_client.db.litellm_accessgrouptable.find_many( + access_group_rows = await AccessGroupRepository(prisma_client).table.find_many( where={"access_group_id": {"in": list(all_access_group_ids)}} 
) ag_model_map: Dict[str, List[str]] = { @@ -10865,13 +10884,13 @@ async def get_all_team_models( team_db_objects_typed: List[LiteLLM_TeamTable] = [] if user_teams == "*": - team_db_objects = await prisma_client.db.litellm_teamtable.find_many() + team_db_objects = await TeamRepository(prisma_client).table.find_many() team_db_objects_typed = [ LiteLLM_TeamTable(**team_db_object.model_dump()) for team_db_object in team_db_objects ] else: - team_db_objects = await prisma_client.db.litellm_teamtable.find_many( + team_db_objects = await TeamRepository(prisma_client).table.find_many( where={"team_id": {"in": user_teams}} ) @@ -10938,7 +10957,7 @@ async def get_all_team_and_direct_access_models( exclude_team_models=True ) # has access to all models elif user_api_key_dict.user_id is not None: - user_db_object = await prisma_client.db.litellm_usertable.find_unique( + user_db_object = await UserRepository(prisma_client).table.find_unique( where={"user_id": user_api_key_dict.user_id} ) if user_db_object is not None: @@ -11082,7 +11101,7 @@ async def _get_caller_byok_team_scope( if user_id is None: return set() try: - user_row = await prisma_client.db.litellm_usertable.find_unique( + user_row = await UserRepository(prisma_client).table.find_unique( where={"user_id": user_id} ) except Exception: @@ -11143,13 +11162,13 @@ async def _fetch_db_models_for_search( else: take_limit = max(0, page * size - router_models_count) - db_models_total_count = await prisma_client.db.litellm_proxymodeltable.count( + db_models_total_count = await ModelRepository(prisma_client).table.count( where=db_where_condition ) db_models_raw: list = [] if take_limit > 0: - db_models_raw = await prisma_client.db.litellm_proxymodeltable.find_many( + db_models_raw = await ModelRepository(prisma_client).table.find_many( where=db_where_condition, take=take_limit, ) @@ -11484,7 +11503,7 @@ async def _load_team_object_for_model_filter( ) -> Optional[LiteLLM_TeamTable]: """Load team row from DB; returns None if 
missing or on error.""" try: - team_db_object = await prisma_client.db.litellm_teamtable.find_unique( + team_db_object = await TeamRepository(prisma_client).table.find_unique( where={"team_id": team_id} ) if team_db_object is None: @@ -11547,7 +11566,7 @@ async def _gather_team_accessible_model_ids( _resolved_names = _team_models_resolve_to_names( team_object.models, access_groups ) - db_models = await prisma_client.db.litellm_proxymodeltable.find_many( + db_models = await ModelRepository(prisma_client).table.find_many( where={"model_name": {"in": _resolved_names}} ) for db_model in db_models: @@ -11586,7 +11605,7 @@ async def _authorize_team_id_query( detail={"error": "Not authorized to view this team's models"}, ) try: - user_row = await prisma_client.db.litellm_usertable.find_unique( + user_row = await UserRepository(prisma_client).table.find_unique( where={"user_id": user_id} ) except Exception: @@ -11693,7 +11712,7 @@ async def _find_model_by_id( # If not found in config, search in database if found_model is None: try: - db_model = await prisma_client.db.litellm_proxymodeltable.find_unique( + db_model = await ModelRepository(prisma_client).table.find_unique( where={"model_id": model_id} ) if db_model: @@ -12845,7 +12864,7 @@ async def alerting_settings( ) ## get general settings from db - db_general_settings = await prisma_client.db.litellm_config.find_first( + db_general_settings = await ConfigRepository(prisma_client).table.find_first( where={"param_name": "general_settings"} ) @@ -13384,7 +13403,7 @@ async def onboarding(invite_link: str, request: Request): detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - invite_obj = await prisma_client.db.litellm_invitationlink.find_unique( + invite_obj = await InvitationLinkRepository(prisma_client).table.find_unique( where={"id": invite_link} ) if invite_obj is None: @@ -13408,7 +13427,7 @@ async def onboarding(invite_link: str, request: Request): ) ### GET USER OBJECT ### - user_obj = await 
prisma_client.db.litellm_usertable.find_unique( + user_obj = await UserRepository(prisma_client).table.find_unique( where={"user_id": invite_obj.user_id} ) @@ -13513,7 +13532,7 @@ async def _rollback_onboarding_invite_claim( return try: - await prisma_client.db.litellm_invitationlink.update_many( + await InvitationLinkRepository(prisma_client).table.update_many( where={"id": invitation_link, "is_accepted": True}, data={ "accepted_at": None, @@ -13547,10 +13566,10 @@ async def _generate_onboarding_ui_session_token(user_obj: Any) -> str: ) key = response["token"] # type: ignore - from litellm.types.proxy.ui_sso import ReturnedUITokenObject - import jwt + from litellm.types.proxy.ui_sso import ReturnedUITokenObject + disabled_non_admin_personal_key_creation = ( get_disabled_non_admin_personal_key_creation() ) @@ -13596,7 +13615,7 @@ async def claim_onboarding_link(data: InvitationClaim, request: Request): detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - invite_obj = await prisma_client.db.litellm_invitationlink.find_unique( + invite_obj = await InvitationLinkRepository(prisma_client).table.find_unique( where={"id": data.invitation_link} ) if invite_obj is None: @@ -13956,7 +13975,7 @@ async def invitation_info( }, ) - response = await prisma_client.db.litellm_invitationlink.find_unique( + response = await InvitationLinkRepository(prisma_client).table.find_unique( where={"id": invitation_id} ) @@ -14010,7 +14029,7 @@ async def invitation_update( ) current_time = litellm.utils.get_utc_datetime() - response = await prisma_client.db.litellm_invitationlink.update( + response = await InvitationLinkRepository(prisma_client).table.update( where={"id": data.invitation_id}, data={ "id": data.invitation_id, @@ -14081,7 +14100,7 @@ async def invitation_delete( # Org admins can only delete invitations they created if is_other_admin and not is_proxy_admin: - invitation = await prisma_client.db.litellm_invitationlink.find_unique( + invitation = await 
InvitationLinkRepository(prisma_client).table.find_unique( where={"id": data.invitation_id} ) if invitation is None: @@ -14097,7 +14116,7 @@ async def invitation_delete( }, ) - response = await prisma_client.db.litellm_invitationlink.delete( + response = await InvitationLinkRepository(prisma_client).table.delete( where={"id": data.invitation_id} ) @@ -14139,7 +14158,7 @@ async def update_config( # noqa: PLR0915 raise Exception("No DB Connected") async def _read_section(param_name: str) -> dict: - row = await prisma_client.db.litellm_config.find_first( + row = await ConfigRepository(prisma_client).table.find_first( where={"param_name": param_name} ) if row is None or row.param_value is None: @@ -14148,7 +14167,7 @@ async def update_config( # noqa: PLR0915 async def _upsert_section(param_name: str, value: dict) -> None: serialized = json.dumps(value) - await prisma_client.db.litellm_config.upsert( + await ConfigRepository(prisma_client).table.upsert( where={"param_name": param_name}, data={ "create": {"param_name": param_name, "param_value": serialized}, @@ -14323,7 +14342,7 @@ async def update_config_general_settings( ) ## get general settings from db - db_general_settings = await prisma_client.db.litellm_config.find_first( + db_general_settings = await ConfigRepository(prisma_client).table.find_first( where={"param_name": "general_settings"} ) ### update value @@ -14337,7 +14356,7 @@ async def update_config_general_settings( general_settings[data.field_name] = data.field_value - response = await prisma_client.db.litellm_config.upsert( + response = await ConfigRepository(prisma_client).table.upsert( where={"param_name": "general_settings"}, data={ "create": {"param_name": "general_settings", "param_value": json.dumps(general_settings)}, # type: ignore @@ -14387,7 +14406,7 @@ async def get_config_general_settings( ) ## get general settings from db - db_general_settings = await prisma_client.db.litellm_config.find_first( + db_general_settings = await 
ConfigRepository(prisma_client).table.find_first( where={"param_name": "general_settings"} ) ### pop the value @@ -14450,7 +14469,7 @@ async def get_config_list( ) ## get general settings from db - db_general_settings = await prisma_client.db.litellm_config.find_first( + db_general_settings = await ConfigRepository(prisma_client).table.find_first( where={"param_name": "general_settings"} ) @@ -14604,7 +14623,7 @@ async def delete_config_general_settings( ) ## get general settings from db - db_general_settings = await prisma_client.db.litellm_config.find_first( + db_general_settings = await ConfigRepository(prisma_client).table.find_first( where={"param_name": "general_settings"} ) ### pop the value @@ -14621,7 +14640,7 @@ async def delete_config_general_settings( general_settings.pop(data.field_name, None) - response = await prisma_client.db.litellm_config.upsert( + response = await ConfigRepository(prisma_client).table.upsert( where={"param_name": "general_settings"}, data={ "create": {"param_name": "general_settings", "param_value": json.dumps(general_settings)}, # type: ignore @@ -14975,14 +14994,14 @@ async def reload_model_cost_map( last_model_cost_map_reload = current_time.isoformat() # Set force reload flag in database for other pods, preserving existing interval_hours - existing_config = await prisma_client.db.litellm_config.find_unique( + existing_config = await ConfigRepository(prisma_client).table.find_unique( where={"param_name": "model_cost_map_reload_config"} ) existing_interval = None if existing_config and existing_config.param_value: existing_interval = existing_config.param_value.get("interval_hours") - await prisma_client.db.litellm_config.upsert( + await ConfigRepository(prisma_client).table.upsert( where={"param_name": "model_cost_map_reload_config"}, data={ "create": { @@ -15052,7 +15071,7 @@ async def schedule_model_cost_map_reload( ) # Update database with new reload configuration - await prisma_client.db.litellm_config.upsert( + await 
ConfigRepository(prisma_client).table.upsert( where={"param_name": "model_cost_map_reload_config"}, data={ "create": { @@ -15119,7 +15138,7 @@ async def cancel_model_cost_map_reload( ) # Remove reload configuration from database - await prisma_client.db.litellm_config.delete( + await ConfigRepository(prisma_client).table.delete( where={"param_name": "model_cost_map_reload_config"} ) await invalidate_config_param("model_cost_map_reload_config") @@ -15178,7 +15197,7 @@ async def get_model_cost_map_reload_status( } # Get reload configuration from database - config_record = await prisma_client.db.litellm_config.find_unique( + config_record = await ConfigRepository(prisma_client).table.find_unique( where={"param_name": "model_cost_map_reload_config"} ) @@ -15329,7 +15348,7 @@ async def reload_anthropic_beta_headers( last_anthropic_beta_headers_reload = current_time.isoformat() # Set force reload flag in database for other pods, preserving existing interval_hours - existing_beta_config = await prisma_client.db.litellm_config.find_unique( + existing_beta_config = await ConfigRepository(prisma_client).table.find_unique( where={"param_name": "anthropic_beta_headers_reload_config"} ) existing_beta_interval = None @@ -15338,7 +15357,7 @@ async def reload_anthropic_beta_headers( "interval_hours" ) - await prisma_client.db.litellm_config.upsert( + await ConfigRepository(prisma_client).table.upsert( where={"param_name": "anthropic_beta_headers_reload_config"}, data={ "create": { @@ -15412,7 +15431,7 @@ async def schedule_anthropic_beta_headers_reload( ) # Update database with new reload configuration - await prisma_client.db.litellm_config.upsert( + await ConfigRepository(prisma_client).table.upsert( where={"param_name": "anthropic_beta_headers_reload_config"}, data={ "create": { @@ -15479,7 +15498,7 @@ async def cancel_anthropic_beta_headers_reload( ) # Remove reload configuration from database - await prisma_client.db.litellm_config.delete( + await 
ConfigRepository(prisma_client).table.delete( where={"param_name": "anthropic_beta_headers_reload_config"} ) await invalidate_config_param("anthropic_beta_headers_reload_config") @@ -15539,7 +15558,7 @@ async def get_anthropic_beta_headers_reload_status( } # Get reload configuration from database - config_record = await prisma_client.db.litellm_config.find_unique( + config_record = await ConfigRepository(prisma_client).table.find_unique( where={"param_name": "anthropic_beta_headers_reload_config"} ) diff --git a/litellm/proxy/public_endpoints/public_endpoints.py b/litellm/proxy/public_endpoints/public_endpoints.py index d12e7a35fb..78467c4b2e 100644 --- a/litellm/proxy/public_endpoints/public_endpoints.py +++ b/litellm/proxy/public_endpoints/public_endpoints.py @@ -4,9 +4,9 @@ import re from importlib.resources import files from typing import Any, Dict, List, Optional -import litellm from fastapi import APIRouter, HTTPException, Request +import litellm from litellm._logging import verbose_logger from litellm.litellm_core_utils.get_blog_posts import ( BlogPost, @@ -17,6 +17,7 @@ from litellm.litellm_core_utils.get_blog_posts import ( from litellm.proxy._types import ( CommonProxyErrors, ) +from litellm.repositories.table_repositories import ClaudeCodePluginRepository from litellm.types.agents import AgentCard from litellm.types.mcp import MCPPublicServer from litellm.types.proxy.management_endpoints.model_management_endpoints import ( @@ -159,14 +160,14 @@ def _load_endpoints() -> List[Dict[str, Any]]: ) async def public_model_hub(): import litellm + from litellm.proxy.health_endpoints._health_endpoints import ( + _convert_health_check_to_dict, + ) from litellm.proxy.proxy_server import ( _get_model_group_info, llm_router, prisma_client, ) - from litellm.proxy.health_endpoints._health_endpoints import ( - _convert_health_check_to_dict, - ) if llm_router is None: raise HTTPException( @@ -266,7 +267,7 @@ async def public_skill_hub(): try: prisma_client = await 
_get_prisma_client() - plugins = await prisma_client.db.litellm_claudecodeplugintable.find_many( + plugins = await ClaudeCodePluginRepository(prisma_client).table.find_many( where={"enabled": True} ) items = [] diff --git a/litellm/proxy/rag_endpoints/endpoints.py b/litellm/proxy/rag_endpoints/endpoints.py index a44e478149..7ff54ac4c5 100644 --- a/litellm/proxy/rag_endpoints/endpoints.py +++ b/litellm/proxy/rag_endpoints/endpoints.py @@ -17,16 +17,17 @@ import litellm from litellm._logging import verbose_proxy_logger from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH from litellm.proxy._types import * +from litellm.proxy.auth.auth_utils import is_request_body_safe from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth, user_api_key_auth from litellm.proxy.common_utils.http_parsing_utils import ( _read_request_body, _safe_get_request_headers, get_form_data, ) -from litellm.proxy.auth.auth_utils import is_request_body_safe from litellm.proxy.vector_store_endpoints.utils import ( assert_user_can_access_vector_store_id, ) +from litellm.repositories.table_repositories import ManagedVectorStoresRepository router = APIRouter() @@ -230,11 +231,9 @@ async def _save_vector_store_to_db_from_rag_ingest( try: # Check if vector store already exists in database - existing_vector_store = ( - await prisma_client.db.litellm_managedvectorstorestable.find_unique( - where={"vector_store_id": vector_store_id} - ) - ) + existing_vector_store = await ManagedVectorStoresRepository( + prisma_client + ).table.find_unique(where={"vector_store_id": vector_store_id}) # Only create if it doesn't exist if existing_vector_store is None: @@ -289,7 +288,7 @@ async def _save_vector_store_to_db_from_rag_ingest( # Update the vector store from litellm.proxy.utils import safe_dumps - await prisma_client.db.litellm_managedvectorstorestable.update( + await ManagedVectorStoresRepository(prisma_client).table.update( where={"vector_store_id": vector_store_id}, data={"vector_store_metadata": 
safe_dumps(existing_metadata)}, ) diff --git a/litellm/proxy/search_endpoints/search_tool_registry.py b/litellm/proxy/search_endpoints/search_tool_registry.py index d4adc2573e..588d71b77f 100644 --- a/litellm/proxy/search_endpoints/search_tool_registry.py +++ b/litellm/proxy/search_endpoints/search_tool_registry.py @@ -8,6 +8,7 @@ from typing import List, Optional from litellm._logging import verbose_proxy_logger from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.proxy.utils import PrismaClient +from litellm.repositories.table_repositories import SearchToolsRepository from litellm.types.search import SearchTool @@ -63,16 +64,16 @@ class SearchToolRegistry: search_tool_info: str = safe_dumps(search_tool.get("search_tool_info", {})) # Create search tool in DB - created_search_tool = ( - await prisma_client.db.litellm_searchtoolstable.create( - data={ - "search_tool_name": search_tool_name, - "litellm_params": litellm_params, - "search_tool_info": search_tool_info, - "created_at": datetime.now(timezone.utc), - "updated_at": datetime.now(timezone.utc), - } - ) + created_search_tool = await SearchToolsRepository( + prisma_client + ).table.create( + data={ + "search_tool_name": search_tool_name, + "litellm_params": litellm_params, + "search_tool_info": search_tool_info, + "created_at": datetime.now(timezone.utc), + "updated_at": datetime.now(timezone.utc), + } ) # Add search_tool_id to the returned search tool object @@ -101,15 +102,15 @@ class SearchToolRegistry: """ try: # Get search tool before deletion for response - existing_tool = await prisma_client.db.litellm_searchtoolstable.find_unique( - where={"search_tool_id": search_tool_id} - ) + existing_tool = await SearchToolsRepository( + prisma_client + ).table.find_unique(where={"search_tool_id": search_tool_id}) if not existing_tool: raise Exception(f"Search tool with ID {search_tool_id} not found") # Delete from DB - await prisma_client.db.litellm_searchtoolstable.delete( + await 
SearchToolsRepository(prisma_client).table.delete( where={"search_tool_id": search_tool_id} ) @@ -145,16 +146,16 @@ class SearchToolRegistry: search_tool_info: str = safe_dumps(search_tool.get("search_tool_info", {})) # Update in DB - updated_search_tool = ( - await prisma_client.db.litellm_searchtoolstable.update( - where={"search_tool_id": search_tool_id}, - data={ - "search_tool_name": search_tool_name, - "litellm_params": litellm_params, - "search_tool_info": search_tool_info, - "updated_at": datetime.now(timezone.utc), - }, - ) + updated_search_tool = await SearchToolsRepository( + prisma_client + ).table.update( + where={"search_tool_id": search_tool_id}, + data={ + "search_tool_name": search_tool_name, + "litellm_params": litellm_params, + "search_tool_info": search_tool_info, + "updated_at": datetime.now(timezone.utc), + }, ) # Convert to dict with ISO formatted datetimes @@ -179,10 +180,10 @@ class SearchToolRegistry: List of search tool configurations """ try: - search_tools_from_db = ( - await prisma_client.db.litellm_searchtoolstable.find_many( - order={"created_at": "desc"}, - ) + search_tools_from_db = await SearchToolsRepository( + prisma_client + ).table.find_many( + order={"created_at": "desc"}, ) search_tools: List[SearchTool] = [] @@ -214,7 +215,7 @@ class SearchToolRegistry: Search tool configuration or None if not found """ try: - search_tool = await prisma_client.db.litellm_searchtoolstable.find_unique( + search_tool = await SearchToolsRepository(prisma_client).table.find_unique( where={"search_tool_id": search_tool_id} ) @@ -244,7 +245,7 @@ class SearchToolRegistry: Search tool configuration or None if not found """ try: - search_tool = await prisma_client.db.litellm_searchtoolstable.find_unique( + search_tool = await SearchToolsRepository(prisma_client).table.find_unique( where={"search_tool_name": search_tool_name} ) diff --git a/litellm/proxy/spend_tracking/cloudzero_endpoints.py b/litellm/proxy/spend_tracking/cloudzero_endpoints.py index 
1f551d5ffe..71f4a8af11 100644 --- a/litellm/proxy/spend_tracking/cloudzero_endpoints.py +++ b/litellm/proxy/spend_tracking/cloudzero_endpoints.py @@ -6,11 +6,12 @@ from litellm._logging import verbose_proxy_logger from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth -from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view from litellm.proxy.common_utils.encrypt_decrypt_utils import ( decrypt_value_helper, encrypt_value_helper, ) +from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view +from litellm.repositories.config_repository import ConfigRepository from litellm.types.proxy.cloudzero_endpoints import ( CloudZeroExportRequest, CloudZeroExportResponse, @@ -53,7 +54,7 @@ async def _set_cloudzero_settings(api_key: str, connection_id: str, timezone: st "timezone": timezone, } - await prisma_client.db.litellm_config.upsert( + await ConfigRepository(prisma_client).table.upsert( where={"param_name": "cloudzero_settings"}, data={ "create": { @@ -80,7 +81,7 @@ async def _get_cloudzero_settings(): detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - cloudzero_config = await prisma_client.db.litellm_config.find_first( + cloudzero_config = await ConfigRepository(prisma_client).table.find_first( where={"param_name": "cloudzero_settings"} ) if cloudzero_config is None or cloudzero_config.param_value is None: @@ -282,7 +283,7 @@ async def is_cloudzero_setup_in_db() -> bool: return False # Check for CloudZero settings in database - cloudzero_config = await prisma_client.db.litellm_config.find_first( + cloudzero_config = await ConfigRepository(prisma_client).table.find_first( where={"param_name": "cloudzero_settings"} ) @@ -548,7 +549,7 @@ async def delete_cloudzero_settings( ) # Check if CloudZero settings exist - cloudzero_config = 
await prisma_client.db.litellm_config.find_first( + cloudzero_config = await ConfigRepository(prisma_client).table.find_first( where={"param_name": "cloudzero_settings"} ) @@ -560,7 +561,7 @@ async def delete_cloudzero_settings( # Delete only the CloudZero settings entry # This uses a specific where clause to target only the cloudzero_settings row - await prisma_client.db.litellm_config.delete( + await ConfigRepository(prisma_client).table.delete( where={"param_name": "cloudzero_settings"} ) diff --git a/litellm/proxy/spend_tracking/spend_management_endpoints.py b/litellm/proxy/spend_tracking/spend_management_endpoints.py index ca5e047365..f651e6e5f7 100644 --- a/litellm/proxy/spend_tracking/spend_management_endpoints.py +++ b/litellm/proxy/spend_tracking/spend_management_endpoints.py @@ -21,6 +21,11 @@ from litellm.proxy.spend_tracking.spend_tracking_utils import ( get_spend_by_team_and_customer, ) from litellm.proxy.utils import handle_exception_on_proxy +from litellm.repositories.table_repositories import SpendLogsRepository +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.verification_token_repository import ( + VerificationTokenRepository, +) if TYPE_CHECKING: from litellm.proxy.proxy_server import PrismaClient @@ -2010,7 +2015,7 @@ async def ui_view_spend_logs( # noqa: PLR0915 order_direction = (sort_order or "desc").lower() # Get total count of records - total_records = await prisma_client.db.litellm_spendlogs.count( + total_records = await SpendLogsRepository(prisma_client).table.count( where=where_conditions, ) @@ -2374,7 +2379,7 @@ async def view_spend_logs( # noqa: PLR0915 # Check if user wants unsummarized data if not summarize: # Return filtered individual log entries (similar to UI endpoint) - data = await prisma_client.db.litellm_spendlogs.find_many( + data = await SpendLogsRepository(prisma_client).table.find_many( where=filter_query, # type: ignore order={ "startTime": "desc", @@ -2384,7 +2389,7 @@ async 
def view_spend_logs( # noqa: PLR0915 # Legacy behavior: return summarized data (when summarize=true) # SQL query - response = await prisma_client.db.litellm_spendlogs.group_by( + response = await SpendLogsRepository(prisma_client).table.group_by( by=["api_key", "user", "model", "startTime"], where=filter_query, # type: ignore sum={ @@ -2462,7 +2467,7 @@ async def view_spend_logs( # noqa: PLR0915 ) return spend_logs - data = await prisma_client.db.litellm_spendlogs.find_many( + data = await SpendLogsRepository(prisma_client).table.find_many( where=scoped_filter, # type: ignore order={"startTime": "desc"}, ) @@ -2514,10 +2519,10 @@ async def global_spend_reset(): code=status.HTTP_401_UNAUTHORIZED, ) - await prisma_client.db.litellm_verificationtoken.update_many( + await VerificationTokenRepository(prisma_client).table.update_many( data={"spend": 0.0}, where={} ) - await prisma_client.db.litellm_teamtable.update_many(data={"spend": 0.0}, where={}) + await TeamRepository(prisma_client).table.update_many(data={"spend": 0.0}, where={}) return { "message": "Spend for all API Keys and Teams reset successfully", @@ -3384,7 +3389,7 @@ async def ui_view_session_spend_logs( skip = (page - 1) * page_size # Get total count for pagination metadata - total_records = await prisma_client.db.litellm_spendlogs.count( + total_records = await SpendLogsRepository(prisma_client).table.count( where=where_conditions ) @@ -3485,7 +3490,7 @@ async def _build_ui_spend_logs_response( # is bounded by page_size (typically 25-50 distinct session IDs). # If performance degrades at scale, consider short-lived caching or # folding the count into the main query via a window function. 
- counts = await prisma_client.db.litellm_spendlogs.group_by( + counts = await SpendLogsRepository(prisma_client).table.group_by( by=["session_id"], where={"session_id": {"in": session_ids}}, count={"session_id": True}, @@ -3572,7 +3577,7 @@ async def _can_team_member_view_log( if team_id is None: return False - team_row = await prisma_client.db.litellm_teamtable.find_unique( + team_row = await TeamRepository(prisma_client).table.find_unique( where={"team_id": team_id} ) if team_row is None: @@ -3614,7 +3619,7 @@ async def _assert_user_can_view_request_id( permitted teams (admin or ``/spend/logs`` permission). Raises HTTP 403 if not. """ - row = await prisma_client.db.litellm_spendlogs.find_unique( + row = await SpendLogsRepository(prisma_client).table.find_unique( where={"request_id": request_id}, include=None, ) @@ -3669,7 +3674,7 @@ async def _get_permitted_team_ids_for_spend_logs( if user_obj is None or not user_obj.teams: return [] - team_rows = await prisma_client.db.litellm_teamtable.find_many( + team_rows = await TeamRepository(prisma_client).table.find_many( where={"team_id": {"in": user_obj.teams}} ) diff --git a/litellm/proxy/spend_tracking/vantage_endpoints.py b/litellm/proxy/spend_tracking/vantage_endpoints.py index 60e54d005b..1dde31b54c 100644 --- a/litellm/proxy/spend_tracking/vantage_endpoints.py +++ b/litellm/proxy/spend_tracking/vantage_endpoints.py @@ -1,17 +1,18 @@ import json -import litellm from fastapi import APIRouter, Depends, HTTPException +import litellm from litellm._logging import verbose_proxy_logger from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth -from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view from litellm.proxy.common_utils.encrypt_decrypt_utils import ( decrypt_value_helper, encrypt_value_helper, ) +from 
litellm.proxy.management_endpoints.common_utils import _user_has_admin_view +from litellm.repositories.config_repository import ConfigRepository from litellm.types.proxy.vantage_endpoints import ( VantageDryRunRequest, VantageExportRequest, @@ -60,7 +61,7 @@ async def _set_vantage_settings(api_key: str, integration_token: str, base_url: "base_url": base_url, } - await prisma_client.db.litellm_config.upsert( + await ConfigRepository(prisma_client).table.upsert( where={"param_name": VANTAGE_SETTINGS_PARAM_NAME}, data={ "create": { @@ -82,7 +83,7 @@ async def _get_vantage_settings(): detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - vantage_config = await prisma_client.db.litellm_config.find_first( + vantage_config = await ConfigRepository(prisma_client).table.find_first( where={"param_name": VANTAGE_SETTINGS_PARAM_NAME} ) if vantage_config is None or vantage_config.param_value is None: @@ -265,7 +266,7 @@ async def is_vantage_setup_in_db() -> bool: if prisma_client is None: return False - vantage_config = await prisma_client.db.litellm_config.find_first( + vantage_config = await ConfigRepository(prisma_client).table.find_first( where={"param_name": VANTAGE_SETTINGS_PARAM_NAME} ) @@ -553,7 +554,7 @@ async def delete_vantage_settings( detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) - vantage_config = await prisma_client.db.litellm_config.find_first( + vantage_config = await ConfigRepository(prisma_client).table.find_first( where={"param_name": VANTAGE_SETTINGS_PARAM_NAME} ) @@ -563,7 +564,7 @@ async def delete_vantage_settings( detail={"error": "Vantage settings not found"}, ) - await prisma_client.db.litellm_config.delete( + await ConfigRepository(prisma_client).table.delete( where={"param_name": VANTAGE_SETTINGS_PARAM_NAME} ) diff --git a/litellm/proxy/ui_crud_endpoints/proxy_setting_endpoints.py b/litellm/proxy/ui_crud_endpoints/proxy_setting_endpoints.py index 07e2ca7195..ea634289cb 100644 --- 
a/litellm/proxy/ui_crud_endpoints/proxy_setting_endpoints.py +++ b/litellm/proxy/ui_crud_endpoints/proxy_setting_endpoints.py @@ -12,6 +12,12 @@ from litellm._logging import verbose_proxy_logger from litellm.litellm_core_utils.sensitive_data_masker import mask_sensitive_keys from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.repositories.config_repository import ConfigRepository +from litellm.repositories.table_repositories import ( + DailyTagSpendRepository, + SSOConfigRepository, + UISettingsRepository, +) from litellm.types.proxy.management_endpoints.ui_sso import ( DefaultTeamSSOParams, InProductNudgeResponse, @@ -665,7 +671,7 @@ async def get_sso_settings(): ) # Get SSO config from dedicated table - sso_db_record = await prisma_client.db.litellm_ssoconfig.find_unique( + sso_db_record = await SSOConfigRepository(prisma_client).table.find_unique( where={"id": "sso_config"} ) @@ -836,7 +842,7 @@ async def update_sso_settings(sso_config: SSOConfig): ) # Save to dedicated SSO table - await prisma_client.db.litellm_ssoconfig.upsert( + await SSOConfigRepository(prisma_client).table.upsert( where={"id": "sso_config"}, data={ "create": { @@ -851,7 +857,7 @@ async def update_sso_settings(sso_config: SSOConfig): # Remove SSO-related env vars from config.environment_variables try: - env_var_entry = await prisma_client.db.litellm_config.find_unique( + env_var_entry = await ConfigRepository(prisma_client).table.find_unique( where={"param_name": "environment_variables"} ) @@ -872,7 +878,7 @@ async def update_sso_settings(sso_config: SSOConfig): if key not in env_vars_to_remove } - await prisma_client.db.litellm_config.update( + await ConfigRepository(prisma_client).table.update( where={"param_name": "environment_variables"}, data={ "param_value": json.dumps(filtered_env_vars, default=str), @@ -1123,7 +1129,7 @@ async def get_in_product_nudges(): detail={"error": "Database not connected. 
Please connect a database."}, ) - db_record = await prisma_client.db.litellm_dailytagspend.find_first( + db_record = await DailyTagSpendRepository(prisma_client).table.find_first( where={"tag": "User-Agent: claude-cli"} ) @@ -1155,7 +1161,7 @@ async def get_ui_settings_cached() -> Dict[str, Any]: if prisma_client is None: return {} - db_record = await prisma_client.db.litellm_uisettings.find_unique( + db_record = await UISettingsRepository(prisma_client).table.find_unique( where={"id": "ui_settings"} ) ui_settings: Dict[str, Any] = {} @@ -1196,7 +1202,7 @@ async def get_ui_settings(): ui_settings: Dict[str, Any] = {} - db_record = await prisma_client.db.litellm_uisettings.find_unique( + db_record = await UISettingsRepository(prisma_client).table.find_unique( where={"id": "ui_settings"} ) @@ -1309,7 +1315,7 @@ async def update_ui_settings( # Merge with existing persisted settings so a partial PATCH doesn't # overwrite fields the caller didn't send. existing: dict = {} - db_existing = await prisma_client.db.litellm_uisettings.find_unique( + db_existing = await UISettingsRepository(prisma_client).table.find_unique( where={"id": "ui_settings"} ) if db_existing and db_existing.ui_settings: @@ -1318,7 +1324,7 @@ async def update_ui_settings( ui_settings = {**existing, **incoming} - await prisma_client.db.litellm_uisettings.upsert( + await UISettingsRepository(prisma_client).table.upsert( where={"id": "ui_settings"}, data={ "create": { diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index e77e24c9e7..5ad42b5e1b 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -9,10 +9,10 @@ import sys import threading import time import traceback +from dataclasses import dataclass, field from datetime import date, datetime, timedelta, timezone from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from dataclasses import dataclass, field from typing import ( TYPE_CHECKING, Any, @@ -139,6 +139,19 @@ from 
litellm.proxy.hooks.parallel_request_limiter import ( ) from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup from litellm.proxy.policy_engine.pipeline_executor import PipelineExecutor +from litellm.repositories.budget_repository import BudgetRepository +from litellm.repositories.config_repository import ConfigRepository +from litellm.repositories.table_repositories import ( + EndUserRepository, + HealthCheckRepository, + SpendLogsRepository, + UserNotificationsRepository, +) +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.user_repository import UserRepository +from litellm.repositories.verification_token_repository import ( + VerificationTokenRepository, +) from litellm.secret_managers.main import str_to_bool from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES from litellm.types.mcp import ( @@ -2831,7 +2844,7 @@ async def prefetch_config_params(prisma_client: Any, param_names: List[str]) -> if not param_names: return try: - rows = await prisma_client.db.litellm_config.find_many( + rows = await ConfigRepository(prisma_client).table.find_many( where={"param_name": {"in": param_names}} # type: ignore ) except Exception as e: @@ -3194,15 +3207,15 @@ class PrismaClient: async def _do_query(): if table_name == "users": - return await self.db.litellm_usertable.find_first( + return await UserRepository(self).table.find_first( where={key: value} # type: ignore ) elif table_name == "keys": - return await self.db.litellm_verificationtoken.find_first( # type: ignore + return await VerificationTokenRepository(self).table.find_first( # type: ignore where={key: value} # type: ignore ) elif table_name == "config": - return await self.db.litellm_config.find_first( # type: ignore + return await ConfigRepository(self).table.find_first( # type: ignore where={key: value} # type: ignore ) elif table_name == "spend": @@ -3336,7 +3349,9 @@ class PrismaClient: status_code=400, detail={"error": f"No 
token passed in. Token={token}"}, ) - response = await self.db.litellm_verificationtoken.find_unique( + response = await VerificationTokenRepository( + self + ).table.find_unique( where={"token": hashed_token}, # type: ignore include={"litellm_budget_table": True}, ) @@ -3353,7 +3368,7 @@ class PrismaClient: detail=f"Authentication Error: invalid user key - user key does not exist in db. User Key={token}", ) elif query_type == "find_all" and user_id is not None: - response = await self.db.litellm_verificationtoken.find_many( + response = await VerificationTokenRepository(self).table.find_many( where={"user_id": user_id}, include={"litellm_budget_table": True}, ) @@ -3362,7 +3377,7 @@ class PrismaClient: if isinstance(r.expires, datetime): r.expires = r.expires.isoformat() elif query_type == "find_all" and team_id is not None: - response = await self.db.litellm_verificationtoken.find_many( + response = await VerificationTokenRepository(self).table.find_many( where={"team_id": team_id}, include={"litellm_budget_table": True}, ) @@ -3375,7 +3390,7 @@ class PrismaClient: and expires is not None and reset_at is not None ): - response = await self.db.litellm_verificationtoken.find_many( + response = await VerificationTokenRepository(self).table.find_many( where={ # type: ignore "OR": [ {"expires": None}, @@ -3405,7 +3420,7 @@ class PrismaClient: else: hashed_tokens.append(t) where_filter["token"]["in"] = hashed_tokens - response = await self.db.litellm_verificationtoken.find_many( + response = await VerificationTokenRepository(self).table.find_many( order={"spend": "desc"}, where=where_filter, # type: ignore include={"litellm_budget_table": True}, @@ -3425,28 +3440,28 @@ class PrismaClient: if key_val is None: key_val = {"user_id": user_id} - response = await self.db.litellm_usertable.find_unique( # type: ignore + response = await UserRepository(self).table.find_unique( # type: ignore where=key_val, # type: ignore include={"organization_memberships": True}, ) elif 
query_type == "find_all" and key_val is not None: - response = await self.db.litellm_usertable.find_many( + response = await UserRepository(self).table.find_many( where=key_val # type: ignore ) # type: ignore elif query_type == "find_all" and reset_at is not None: - response = await self.db.litellm_usertable.find_many( + response = await UserRepository(self).table.find_many( where={ # type: ignore "budget_reset_at": {"lt": reset_at}, } ) elif query_type == "find_all" and user_id_list is not None: - response = await self.db.litellm_usertable.find_many( + response = await UserRepository(self).table.find_many( where={"user_id": {"in": user_id_list}} ) elif query_type == "find_all": if expires is not None: - response = await self.db.litellm_usertable.find_many( # type: ignore + response = await UserRepository(self).table.find_many( # type: ignore order={"spend": "desc"}, where={ # type: ignore "OR": [ @@ -3478,26 +3493,26 @@ class PrismaClient: ) if key_val is not None: if query_type == "find_unique": - response = await self.db.litellm_spendlogs.find_unique( # type: ignore + response = await SpendLogsRepository(self).table.find_unique( # type: ignore where={ # type: ignore key_val["key"]: key_val["value"], # type: ignore } ) elif query_type == "find_all": - response = await self.db.litellm_spendlogs.find_many( # type: ignore + response = await SpendLogsRepository(self).table.find_many( # type: ignore where={ key_val["key"]: key_val["value"], # type: ignore } ) return response else: - response = await self.db.litellm_spendlogs.find_many( # type: ignore + response = await SpendLogsRepository(self).table.find_many( # type: ignore order={"startTime": "desc"}, ) return response elif table_name == "budget" and reset_at is not None: if query_type == "find_all": - response = await self.db.litellm_budgettable.find_many( + response = await BudgetRepository(self).table.find_many( where={ # type: ignore "OR": [ { @@ -3514,45 +3529,45 @@ class PrismaClient: elif table_name == 
"enduser" and budget_id_list is not None: if query_type == "find_all": - response = await self.db.litellm_endusertable.find_many( + response = await EndUserRepository(self).table.find_many( where={"budget_id": {"in": budget_id_list}} ) return response elif table_name == "team": if query_type == "find_unique": - response = await self.db.litellm_teamtable.find_unique( + response = await TeamRepository(self).table.find_unique( where={"team_id": team_id}, # type: ignore include={"litellm_model_table": True}, # type: ignore ) elif query_type == "find_all" and reset_at is not None: - response = await self.db.litellm_teamtable.find_many( + response = await TeamRepository(self).table.find_many( where={ # type: ignore "budget_reset_at": {"lt": reset_at}, } ) elif query_type == "find_all" and user_id is not None: - response = await self.db.litellm_teamtable.find_many( + response = await TeamRepository(self).table.find_many( where={ "members": {"has": user_id}, }, include={"litellm_budget_table": True}, ) elif query_type == "find_all" and team_id_list is not None: - response = await self.db.litellm_teamtable.find_many( + response = await TeamRepository(self).table.find_many( where={"team_id": {"in": team_id_list}} ) elif query_type == "find_all" and team_id_list is None: - response = await self.db.litellm_teamtable.find_many( + response = await TeamRepository(self).table.find_many( take=MAX_TEAM_LIST_LIMIT ) return response elif table_name == "user_notification": if query_type == "find_unique": - response = await self.db.litellm_usernotifications.find_unique( # type: ignore + response = await UserNotificationsRepository(self).table.find_unique( # type: ignore where={"user_id": user_id} # type: ignore ) elif query_type == "find_all": - response = await self.db.litellm_usernotifications.find_many() # type: ignore + response = await UserNotificationsRepository(self).table.find_many() # type: ignore return response elif table_name == "combined_view": # check if plain text or hash 
@@ -3744,7 +3759,7 @@ class PrismaClient: print_verbose( "PrismaClient: Before upsert into litellm_verificationtoken" ) - new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore + new_verification_token = await VerificationTokenRepository(self).table.upsert( # type: ignore where={ "token": hashed_token, }, @@ -3759,7 +3774,7 @@ class PrismaClient: elif table_name == "user": db_data = self.jsonify_object(data=data) try: - new_user_row = await self.db.litellm_usertable.upsert( + new_user_row = await UserRepository(self).table.upsert( where={"user_id": data["user_id"]}, data={ "create": {**db_data}, # type: ignore @@ -3782,7 +3797,7 @@ class PrismaClient: return new_user_row elif table_name == "team": db_data = self.jsonify_team_object(db_data=data) - new_team_row = await self.db.litellm_teamtable.upsert( + new_team_row = await TeamRepository(self).table.upsert( where={"team_id": data["team_id"]}, data={ "create": {**db_data}, # type: ignore @@ -3804,7 +3819,7 @@ class PrismaClient: for k, v in data.items(): updated_data = v updated_data = json.dumps(updated_data) - updated_table_row = self.db.litellm_config.upsert( + updated_table_row = ConfigRepository(self).table.upsert( where={"param_name": k}, # type: ignore data={ "create": {"param_name": k, "param_value": updated_data}, # type: ignore @@ -3820,7 +3835,7 @@ class PrismaClient: verbose_proxy_logger.info("Data Inserted into Config Table") elif table_name == "spend": db_data = self.jsonify_object(data=data) - new_spend_row = await self.db.litellm_spendlogs.upsert( + new_spend_row = await SpendLogsRepository(self).table.upsert( where={"request_id": data["request_id"]}, data={ "create": {**db_data}, # type: ignore @@ -3831,14 +3846,14 @@ class PrismaClient: return new_spend_row elif table_name == "user_notification": db_data = self.jsonify_object(data=data) - new_user_notification_row = ( - await self.db.litellm_usernotifications.upsert( # type: ignore - where={"request_id": 
data["request_id"]}, - data={ - "create": {**db_data}, # type: ignore - "update": {}, # don't do anything if it already exists - }, - ) + new_user_notification_row = await UserNotificationsRepository( + self + ).table.upsert( # type: ignore + where={"request_id": data["request_id"]}, + data={ + "create": {**db_data}, # type: ignore + "update": {}, # don't do anything if it already exists + }, ) verbose_proxy_logger.info("Data Inserted into Model Request Table") return new_user_notification_row @@ -3899,7 +3914,7 @@ class PrismaClient: # check if plain text or hash token = _hash_token_if_needed(token=token) db_data["token"] = token - response = await self.db.litellm_verificationtoken.update( + response = await VerificationTokenRepository(self).table.update( where={"token": token}, # type: ignore data={**db_data}, # type: ignore ) @@ -3930,7 +3945,7 @@ class PrismaClient: update_key_values = update_key_values_custom_query else: update_key_values = db_data - update_user_row = await self.db.litellm_usertable.upsert( + update_user_row = await UserRepository(self).table.upsert( where={"user_id": user_id}, # type: ignore data={ "create": {**db_data}, # type: ignore @@ -3971,7 +3986,7 @@ class PrismaClient: update_key_values["members_with_roles"] = json.dumps( update_key_values["members_with_roles"] ) - update_team_row = await self.db.litellm_teamtable.upsert( + update_team_row = await TeamRepository(self).table.upsert( where={"team_id": team_id}, # type: ignore data={ "create": {**db_data}, # type: ignore @@ -4196,7 +4211,9 @@ class PrismaClient: else: filter_query = {"token": {"in": hashed_tokens}} - deleted_tokens = await self.db.litellm_verificationtoken.delete_many( + deleted_tokens = await VerificationTokenRepository( + self + ).table.delete_many( where=filter_query # type: ignore ) verbose_proxy_logger.debug("deleted_tokens: %s", deleted_tokens) @@ -4207,7 +4224,7 @@ class PrismaClient: and isinstance(team_id_list, List) ): # admin only endpoint -> `/team/delete` - 
await self.db.litellm_teamtable.delete_many( + await TeamRepository(self).table.delete_many( where={"team_id": {"in": team_id_list}} ) return {"deleted_teams": team_id_list} @@ -4217,7 +4234,7 @@ class PrismaClient: and isinstance(team_id_list, List) ): # admin only endpoint -> `/team/delete` - await self.db.litellm_verificationtoken.delete_many( + await VerificationTokenRepository(self).table.delete_many( where={"team_id": {"in": team_id_list}} ) except Exception as e: @@ -5024,7 +5041,9 @@ class PrismaClient: ) verbose_proxy_logger.debug(f"Saving health check data: {health_check_data}") - return await self.db.litellm_healthchecktable.create(data=health_check_data) + return await HealthCheckRepository(self).table.create( + data=health_check_data + ) except Exception as e: verbose_proxy_logger.error( @@ -5049,7 +5068,7 @@ class PrismaClient: if status_filter: where_clause["status"] = status_filter - results = await self.db.litellm_healthchecktable.find_many( + results = await HealthCheckRepository(self).table.find_many( where=where_clause, order={"checked_at": "desc"}, take=limit, @@ -5068,7 +5087,7 @@ class PrismaClient: (via Prisma ``distinct`` + ``order``) so we never load the full history into memory. """ try: - return await self.db.litellm_healthchecktable.find_many( + return await HealthCheckRepository(self).table.find_many( distinct=["model_id", "model_name"], order=[ {"model_id": "asc"}, @@ -5228,7 +5247,7 @@ async def migrate_passwords_to_scrypt_async(prisma_client) -> str: are left alone (they migrate on next login via the SHA256 fallback). Skips quickly if no plaintext passwords exist. 
""" - all_with_pw = await prisma_client.db.litellm_usertable.find_many( + all_with_pw = await UserRepository(prisma_client).table.find_many( where={"password": {"not": None}}, ) @@ -5246,7 +5265,7 @@ async def migrate_passwords_to_scrypt_async(prisma_client) -> str: return "No plaintext passwords found" for user in plaintext_users: - await prisma_client.db.litellm_usertable.update( + await UserRepository(prisma_client).table.update( where={"user_id": user.user_id}, data={"password": hash_password(user.password)}, ) @@ -5370,7 +5389,7 @@ class ProxyUpdateSpend: prisma_client.jsonify_object({**entry}) for entry in batch ] - await prisma_client.db.litellm_spendlogs.create_many( + await SpendLogsRepository(prisma_client).table.create_many( data=batch_with_dates, skip_duplicates=True ) verbose_proxy_logger.debug( diff --git a/litellm/proxy/vector_store_endpoints/endpoints.py b/litellm/proxy/vector_store_endpoints/endpoints.py index b3bdfecbe5..9c2d305034 100644 --- a/litellm/proxy/vector_store_endpoints/endpoints.py +++ b/litellm/proxy/vector_store_endpoints/endpoints.py @@ -1,6 +1,7 @@ from typing import Any, Dict, Optional from fastapi import APIRouter, Depends, HTTPException, Request, Response + from litellm.integrations.vector_store_integrations.vector_store_pre_call_hook import ( LiteLLM_ManagedVectorStore, ) @@ -16,6 +17,7 @@ from litellm.proxy.vector_store_endpoints.utils import ( assert_user_can_access_vector_store, get_litellm_managed_vector_store, ) +from litellm.repositories.table_repositories import ManagedVectorStoreIndexRepository from litellm.types.vector_stores import IndexCreateRequest router = APIRouter() @@ -587,11 +589,9 @@ async def index_create( detail=CommonProxyErrors.db_not_connected_error.value, ) ## 1. 
check if index already exists - existing_index = ( - await prisma_client.db.litellm_managedvectorstoreindextable.find_unique( - where={"index_name": index_create_request.index_name} - ) - ) + existing_index = await ManagedVectorStoreIndexRepository( + prisma_client + ).table.find_unique(where={"index_name": index_create_request.index_name}) ## 2. set created_by and updated_by @@ -605,7 +605,7 @@ async def index_create( index_data = index_create_request.model_dump(exclude_none=True) index_data["created_by"] = user_api_key_dict.user_id index_data["updated_by"] = user_api_key_dict.user_id - new_index = await prisma_client.db.litellm_managedvectorstoreindextable.create( + new_index = await ManagedVectorStoreIndexRepository(prisma_client).table.create( data=jsonify_object(index_data) ) diff --git a/litellm/proxy/vector_store_endpoints/management_endpoints.py b/litellm/proxy/vector_store_endpoints/management_endpoints.py index cbb3d92718..032a3302fd 100644 --- a/litellm/proxy/vector_store_endpoints/management_endpoints.py +++ b/litellm/proxy/vector_store_endpoints/management_endpoints.py @@ -29,6 +29,8 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.common_utils.encrypt_decrypt_utils import decrypt_value_helper from litellm.proxy.common_utils.rbac_utils import check_feature_access_for_user from litellm.proxy.vector_store_endpoints.utils import can_user_access_vector_store +from litellm.repositories.model_repository import ModelRepository +from litellm.repositories.table_repositories import ManagedVectorStoresRepository from litellm.secret_managers.main import get_secret from litellm.types.vector_stores import ( LiteLLM_ManagedVectorStore, @@ -122,7 +124,7 @@ async def _fetch_and_authorize_vector_store( Raises HTTPException(404) on miss and HTTPException(403) on access denial. 
""" - row = await prisma_client.db.litellm_managedvectorstorestable.find_unique( + row = await ManagedVectorStoresRepository(prisma_client).table.find_unique( where={"vector_store_id": vector_store_id} ) if row is None: @@ -252,7 +254,7 @@ async def _resolve_embedding_config_from_db( # Try to find model in database for model_name in model_name_candidates: try: - db_model = await prisma_client.db.litellm_proxymodeltable.find_first( + db_model = await ModelRepository(prisma_client).table.find_first( where={"model_name": model_name} ) @@ -437,11 +439,9 @@ async def create_vector_store_in_db( raise HTTPException(status_code=500, detail="Database not connected") # Check if vector store already exists - existing_vector_store = ( - await prisma_client.db.litellm_managedvectorstorestable.find_unique( - where={"vector_store_id": vector_store_id} - ) - ) + existing_vector_store = await ManagedVectorStoresRepository( + prisma_client + ).table.find_unique(where={"vector_store_id": vector_store_id}) if existing_vector_store is not None: raise HTTPException( status_code=400, @@ -487,7 +487,7 @@ async def create_vector_store_in_db( data_to_create["litellm_params"] = safe_dumps({}) # Create in database - _new_vector_store = await prisma_client.db.litellm_managedvectorstorestable.create( + _new_vector_store = await ManagedVectorStoresRepository(prisma_client).table.create( data=data_to_create ) @@ -725,11 +725,9 @@ async def delete_vector_store( memory_vector_store_exists = False vector_store_to_check = None - existing_vector_store = ( - await prisma_client.db.litellm_managedvectorstorestable.find_unique( - where={"vector_store_id": data.vector_store_id} - ) - ) + existing_vector_store = await ManagedVectorStoresRepository( + prisma_client + ).table.find_unique(where={"vector_store_id": data.vector_store_id}) if existing_vector_store is not None: db_vector_store_exists = True vector_store_to_check = LiteLLM_ManagedVectorStore( @@ -764,7 +762,7 @@ async def delete_vector_store( # 
Delete from database if exists if db_vector_store_exists: - await prisma_client.db.litellm_managedvectorstorestable.delete( + await ManagedVectorStoresRepository(prisma_client).table.delete( where={"vector_store_id": data.vector_store_id} ) @@ -921,7 +919,7 @@ async def update_vector_store( update_data["litellm_params"] = safe_dumps(litellm_params_dict) # Update in database - updated = await prisma_client.db.litellm_managedvectorstorestable.update( + updated = await ManagedVectorStoresRepository(prisma_client).table.update( where={"vector_store_id": vector_store_id}, data=update_data, ) diff --git a/litellm/repositories/__init__.py b/litellm/repositories/__init__.py new file mode 100644 index 0000000000..4451f0865d --- /dev/null +++ b/litellm/repositories/__init__.py @@ -0,0 +1,127 @@ +""" +Repository classes for database operations. +""" + +from litellm.repositories.budget_repository import BudgetRepository +from litellm.repositories.config_repository import ConfigRepository +from litellm.repositories.credentials_repository import CredentialsRepository +from litellm.repositories.model_repository import ModelRepository +from litellm.repositories.object_permission_repository import ( + ObjectPermissionRepository, +) +from litellm.repositories.organization_repository import OrganizationRepository +from litellm.repositories.project_repository import ProjectRepository +from litellm.repositories.table_repositories import ( + AccessGroupRepository, + AdaptiveRouterSessionRepository, + AdaptiveRouterStateRepository, + AgentsRepository, + AuditLogRepository, + CacheConfigRepository, + ClaudeCodePluginRepository, + ConfigOverridesRepository, + DailyGuardrailMetricsRepository, + DailyPolicyMetricsRepository, + DailyTagSpendRepository, + DeletedTeamRepository, + DeletedVerificationTokenRepository, + DeprecatedVerificationTokenRepository, + EndUserRepository, + GuardrailsRepository, + HealthCheckRepository, + InvitationLinkRepository, + JWTKeyMappingRepository, + 
ManagedFileRepository, + ManagedObjectRepository, + ManagedVectorStoreIndexRepository, + ManagedVectorStoresRepository, + MCPServerRepository, + MCPToolsetRepository, + MCPUserCredentialsRepository, + MemoryRepository, + ModelTableRepository, + OrganizationMembershipRepository, + PolicyAttachmentRepository, + PolicyRepository, + PrismaTableRepository, + PromptRepository, + SearchToolsRepository, + SkillsRepository, + SpendLogGuardrailIndexRepository, + SpendLogsRepository, + SpendLogToolIndexRepository, + SSOConfigRepository, + TagRepository, + TeamMembershipRepository, + ToolRepository, + UISettingsRepository, + UserNotificationsRepository, + WorkflowEventRepository, + WorkflowMessageRepository, + WorkflowRunRepository, +) +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.user_repository import UserRepository +from litellm.repositories.verification_token_repository import ( + VerificationTokenRepository, +) + +__all__ = [ + "PrismaTableRepository", + "PolicyRepository", + "AgentsRepository", + "GuardrailsRepository", + "MCPServerRepository", + "ManagedObjectRepository", + "OrganizationMembershipRepository", + "SpendLogsRepository", + "ClaudeCodePluginRepository", + "TeamMembershipRepository", + "EndUserRepository", + "ManagedVectorStoresRepository", + "MCPUserCredentialsRepository", + "PromptRepository", + "TagRepository", + "InvitationLinkRepository", + "JWTKeyMappingRepository", + "ManagedFileRepository", + "MemoryRepository", + "SearchToolsRepository", + "ConfigOverridesRepository", + "MCPToolsetRepository", + "ToolRepository", + "DeletedVerificationTokenRepository", + "WorkflowRunRepository", + "ModelTableRepository", + "AccessGroupRepository", + "SSOConfigRepository", + "UISettingsRepository", + "DailyGuardrailMetricsRepository", + "PolicyAttachmentRepository", + "DeletedTeamRepository", + "SkillsRepository", + "CacheConfigRepository", + "ManagedVectorStoreIndexRepository", + "WorkflowMessageRepository", + 
def _record_to_dict(record: Any) -> Dict[str, Any]:
    """Return ``record`` as a plain dict.

    A dict is handed back unchanged (same object, no copy). Otherwise the
    first callable found among the ``model_dump`` and ``dict`` attributes is
    invoked; when neither is available the record is passed to ``dict()``.
    """
    if isinstance(record, dict):
        return record
    for dumper_name in ("model_dump", "dict"):
        dumper = getattr(record, dumper_name, None)
        if callable(dumper):
            return dumper()
    return dict(record)


class BaseRepository(ABC, Generic[T]):
    """Abstract base for repositories that wrap one Prisma table.

    Subclasses supply ``table`` (the Prisma table accessor) and
    ``model_class`` (the domain model built from each returned record).
    """

    def __init__(self, prisma_client: Any):
        self._prisma_client = prisma_client

    @property
    def prisma_client(self) -> Any:
        """Return the wrapped client.

        Raises:
            RuntimeError: if the repository was constructed with ``None``.
        """
        client = self._prisma_client
        if client is not None:
            return client
        raise RuntimeError(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    @property
    @abstractmethod
    def table(self) -> Any:
        """Return the Prisma table for this repository."""
        ...

    @property
    @abstractmethod
    def model_class(self) -> Type[T]:
        """Return the domain model class for this repository."""
        ...

    def _to_model(self, record: Any) -> Optional[T]:
        """Build a domain model from ``record``; ``None`` maps to ``None``."""
        return None if record is None else self.model_class(**_record_to_dict(record))

    def _to_model_list(self, records: List[Any]) -> List[T]:
        """Convert ``records`` to domain models, dropping ``None`` entries."""
        converted = (self._to_model(item) for item in records if item is not None)
        return [model for model in converted if model is not None]

    async def find_by_id(self, id_value: str, id_field: str = "id") -> Optional[T]:
        """Look up one record where ``id_field`` equals ``id_value``."""
        return self._to_model(await self.table.find_unique(where={id_field: id_value}))

    async def find_many(
        self,
        where: Optional[Dict[str, Any]] = None,
        skip: Optional[int] = None,
        take: Optional[int] = None,
        order: Optional[Dict[str, str]] = None,
    ) -> List[T]:
        """Find multiple records matching the criteria.

        ``where`` and ``order`` are forwarded only when truthy; ``skip`` and
        ``take`` are forwarded whenever they are not ``None`` (so ``0`` is
        passed through).
        """
        candidates = (
            ("where", where, bool(where)),
            ("skip", skip, skip is not None),
            ("take", take, take is not None),
            ("order", order, bool(order)),
        )
        query: Dict[str, Any] = {
            name: value for name, value, wanted in candidates if wanted
        }
        return self._to_model_list(await self.table.find_many(**query))

    async def create(self, data: Dict[str, Any]) -> T:
        """Insert ``data`` and return the created row as a domain model."""
        created = self._to_model(await self.table.create(data=data))
        # Narrows Optional[T] to T for type checkers.
        assert created is not None
        return created

    async def update(
        self, id_value: str, data: Dict[str, Any], id_field: str = "id"
    ) -> Optional[T]:
        """Apply ``data`` to the record where ``id_field`` equals ``id_value``."""
        updated = await self.table.update(where={id_field: id_value}, data=data)
        return self._to_model(updated)

    async def delete(self, id_value: str, id_field: str = "id") -> Optional[T]:
        """Delete the record where ``id_field`` equals ``id_value``."""
        removed = await self.table.delete(where={id_field: id_value})
        return self._to_model(removed)

    async def count(self, where: Optional[Dict[str, Any]] = None) -> int:
        """Return the table's ``count`` for ``where`` (``None`` is passed as-is)."""
        return await self.table.count(where=where)

    async def exists(self, id_value: str, id_field: str = "id") -> bool:
        """Return True when a record with ``id_field == id_value`` is found."""
        return (await self.table.find_unique(where={id_field: id_value})) is not None
None: + data["soft_budget"] = soft_budget + if max_parallel_requests is not None: + data["max_parallel_requests"] = max_parallel_requests + if tpm_limit is not None: + data["tpm_limit"] = tpm_limit + if rpm_limit is not None: + data["rpm_limit"] = rpm_limit + if model_max_budget is not None: + data["model_max_budget"] = model_max_budget + if budget_duration is not None: + data["budget_duration"] = budget_duration + if allowed_models is not None: + data["allowed_models"] = allowed_models + + return await self.create(data) + + async def update_budget( + self, + budget_id: str, + updated_by: str, + max_budget: Optional[float] = None, + soft_budget: Optional[float] = None, + max_parallel_requests: Optional[int] = None, + tpm_limit: Optional[int] = None, + rpm_limit: Optional[int] = None, + model_max_budget: Optional[Dict[str, Any]] = None, + budget_duration: Optional[str] = None, + allowed_models: Optional[List[str]] = None, + ) -> Optional[LiteLLM_BudgetTable]: + """Update an existing budget record.""" + data: Dict[str, Any] = {"updated_by": updated_by} + if max_budget is not None: + data["max_budget"] = max_budget + if soft_budget is not None: + data["soft_budget"] = soft_budget + if max_parallel_requests is not None: + data["max_parallel_requests"] = max_parallel_requests + if tpm_limit is not None: + data["tpm_limit"] = tpm_limit + if rpm_limit is not None: + data["rpm_limit"] = rpm_limit + if model_max_budget is not None: + data["model_max_budget"] = model_max_budget + if budget_duration is not None: + data["budget_duration"] = budget_duration + if allowed_models is not None: + data["allowed_models"] = allowed_models + + return await self.update(budget_id, data, id_field="budget_id") + + async def delete_budget(self, budget_id: str) -> Optional[LiteLLM_BudgetTable]: + """Delete a budget record.""" + return await self.delete(budget_id, id_field="budget_id") diff --git a/litellm/repositories/config_repository.py b/litellm/repositories/config_repository.py new file 
class ConfigParam:
    """Simple holder pairing a config parameter name with its decoded value."""

    def __init__(self, param_name: str, param_value: Any):
        self.param_name = param_name
        self.param_value = param_value


class ConfigRepository:
    """Repository for config database operations with reconciliation support."""

    # Top-level config sections that reconcile_config() loads from the DB.
    CONFIG_PARAMS = [
        "general_settings",
        "router_settings",
        "litellm_settings",
        "environment_variables",
    ]

    def __init__(self, prisma_client: Any):
        self._prisma_client = prisma_client

    @property
    def prisma_client(self) -> Any:
        """Return the wrapped client.

        Raises:
            RuntimeError: if the repository was constructed with ``None``.
        """
        if self._prisma_client is None:
            raise RuntimeError(
                "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
            )
        return self._prisma_client

    @property
    def table(self) -> Any:
        """Return the ``litellm_config`` table accessor of the client."""
        return self.prisma_client.db.litellm_config

    @staticmethod
    def _decode_param_value(param_value: Any) -> Any:
        """Decode a stored value: JSON text is parsed, anything else is kept.

        ``set_param`` writes ``str`` values verbatim, so a stored string is
        not guaranteed to be a JSON document. Falling back to the raw string
        keeps the read side consistent with the write side instead of raising
        ``json.JSONDecodeError``.
        """
        if not isinstance(param_value, str):
            return param_value
        try:
            return json.loads(param_value)
        except json.JSONDecodeError:
            return param_value

    async def get_param(self, param_name: str) -> Optional[ConfigParam]:
        """Get a config parameter from the database.

        Returns:
            ``None`` when no row exists for ``param_name``; otherwise a
            ``ConfigParam`` whose value has been JSON-decoded when possible.
        """
        record = await self.table.find_unique(where={"param_name": param_name})
        if record is None:
            return None
        return ConfigParam(
            param_name=param_name,
            param_value=self._decode_param_value(record.param_value),
        )

    async def set_param(self, param_name: str, param_value: Any) -> ConfigParam:
        """Create or overwrite a config parameter.

        Non-string values are serialized with ``json.dumps``; strings are
        written unchanged. The returned ``ConfigParam`` carries the value as
        passed in, not the serialized form.
        """
        value_json = (
            json.dumps(param_value) if not isinstance(param_value, str) else param_value
        )
        await self.table.upsert(
            where={"param_name": param_name},
            data={
                "create": {"param_name": param_name, "param_value": value_json},
                "update": {"param_value": value_json},
            },
        )
        return ConfigParam(param_name=param_name, param_value=param_value)

    async def delete_param(self, param_name: str) -> bool:
        """Delete a config parameter; return False if the delete call raised.

        This stays best-effort (no exception propagates), but the failure is
        now logged at debug level instead of vanishing silently.
        """
        try:
            await self.table.delete(where={"param_name": param_name})
        except Exception as e:
            verbose_proxy_logger.debug(
                "ConfigRepository.delete_param(%s) failed: %s", param_name, e
            )
            return False
        return True

    async def get_all_params(self) -> Dict[str, Any]:
        """Return every stored parameter as ``{param_name: decoded_value}``."""
        records = await self.table.find_many()
        return {
            record.param_name: self._decode_param_value(record.param_value)
            for record in records
        }

    def _deep_merge_dicts(self, dst: dict, src: dict) -> None:
        """Deep-merge src into dst in place, skipping None and empty lists.

        On conflicts, src (DB) wins, but ``None`` and ``[]`` are treated as
        "no value" and never overwrite the destination. Nested dicts are
        merged recursively (via an explicit stack) only when both sides hold
        a dict for the key; otherwise the src value replaces the dst value.
        """
        stack = [(dst, src)]
        while stack:
            d, s = stack.pop()
            for k, v in s.items():
                if v is None:
                    continue
                if isinstance(v, list) and len(v) == 0:
                    continue
                if isinstance(v, dict) and isinstance(d.get(k), dict):
                    stack.append((d[k], v))
                else:
                    d[k] = v

    def _decrypt_env_variables(
        self, env_vars: Dict[str, Any], return_original_value: bool = True
    ) -> Dict[str, str]:
        """Decrypt environment variables from database.

        String values go through ``decrypt_value_helper`` and are dropped when
        it returns ``None``; non-string values are stringified with ``str()``.
        """
        decrypted: Dict[str, str] = {}
        for key, value in env_vars.items():
            if isinstance(value, str):
                decrypted_value = decrypt_value_helper(
                    value=value,
                    key=key,
                    exception_type="debug",
                    return_original_value=return_original_value,
                )
                if decrypted_value is not None:
                    decrypted[key] = decrypted_value
            else:
                decrypted[key] = str(value)
        return decrypted

    def _normalize_env_variable_keys(self, env_vars: Dict[str, str]) -> Dict[str, str]:
        """Return a copy holding each entry under its original and upper-case key."""
        normalized: Dict[str, str] = {}
        for key, value in env_vars.items():
            normalized[key] = value
            normalized[key.upper()] = value
        return normalized

    def _update_config_fields(
        self,
        current_config: dict,
        param_name: Literal[
            "general_settings",
            "router_settings",
            "litellm_settings",
            "environment_variables",
        ],
        db_param_value: Any,
    ) -> dict:
        """Apply one DB-stored section to ``current_config`` (mutated in place).

        ``environment_variables`` are decrypted, exported to ``os.environ``
        and merged into the config section. Other sections are deep-merged
        when both sides are dicts and replaced by the DB value otherwise.
        """
        if param_name == "environment_variables":
            decrypted_env_vars = self._decrypt_env_variables(
                db_param_value, return_original_value=True
            )
            merged_env_vars = self._normalize_env_variable_keys(decrypted_env_vars)
            for env_key, value in merged_env_vars.items():
                os.environ[env_key] = value

            # A YAML file with a bare ``environment_variables:`` key yields
            # None here; setdefault() would hand that None back and crash on
            # .update(), so replace it with a fresh dict first.
            env_section = current_config.get("environment_variables")
            if env_section is None:
                env_section = current_config["environment_variables"] = {}
            env_section.update(merged_env_vars)
            return current_config

        if param_name not in current_config:
            current_config[param_name] = db_param_value
            return current_config

        if isinstance(current_config[param_name], dict) and isinstance(
            db_param_value, dict
        ):
            self._deep_merge_dicts(current_config[param_name], db_param_value)
        else:
            current_config[param_name] = db_param_value

        return current_config

    async def reconcile_config(
        self,
        yaml_config: dict,
        store_model_in_db: Optional[bool] = None,
    ) -> dict:
        """Reconcile config from YAML with database overrides.

        Loads every section named in ``CONFIG_PARAMS`` from the database and
        merges it over a deep copy of the YAML config. DB values override
        YAML values except for None values and empty lists.

        Args:
            yaml_config: The configuration loaded from YAML file
            store_model_in_db: Whether to load config from DB

        Returns:
            ``yaml_config`` itself (not a copy) when ``store_model_in_db`` is
            not True; otherwise a merged deep copy with DB overrides applied.
        """
        if store_model_in_db is not True:
            verbose_proxy_logger.info(
                "'store_model_in_db' is not True, skipping db config reconciliation"
            )
            return yaml_config

        tasks = [self.get_param(k) for k in self.CONFIG_PARAMS]
        responses = await asyncio.gather(*tasks)

        config = copy.deepcopy(yaml_config)
        for response in responses:
            if response is None:
                continue

            param_name = response.param_name
            param_value = response.param_value
            verbose_proxy_logger.debug(
                f"param_name={param_name}, param_value={param_value}"
            )

            if param_name is not None and param_value is not None:
                config = self._update_config_fields(
                    current_config=config,
                    param_name=cast(
                        Literal[
                            "general_settings",
                            "router_settings",
                            "litellm_settings",
                            "environment_variables",
                        ],
                        param_name,
                    ),
                    db_param_value=param_value,
                )

        return config

    async def prefetch_params(self, param_names: List[str]) -> None:
        """Fetch the given params concurrently and discard the results.

        One ``get_param`` query is issued per name. This repository keeps no
        cache of its own, so the call only has an effect if a layer below it
        caches reads.
        """
        await asyncio.gather(*[self.get_param(k) for k in param_names])
See - https://docs.litellm.ai/docs/proxy/virtual_keys" + ) + return self._prisma_client + + @property + def table(self) -> Any: + return self.prisma_client.db.litellm_credentialstable + + @staticmethod + def _to_model(record: Any) -> Optional[CredentialItem]: + if record is None: + return None + data = record.dict() if hasattr(record, "dict") else dict(record) + return CredentialItem( + credential_name=data["credential_name"], + credential_values=data.get("credential_values") or {}, + credential_info=data.get("credential_info") or {}, + ) + + async def find_all(self) -> Any: + return await self.table.find_many() + + async def create(self, data: Dict[str, Any]) -> Any: + return await self.table.create(data=data) + + async def find_by_name(self, credential_name: str) -> Optional[CredentialItem]: + record = await self.table.find_unique( + where={"credential_name": credential_name} + ) + return self._to_model(record) + + async def update_by_name(self, credential_name: str, data: Dict[str, Any]) -> Any: + return await self.table.update( + where={"credential_name": credential_name}, data=data + ) + + async def delete_by_name(self, credential_name: str) -> Any: + return await self.table.delete(where={"credential_name": credential_name}) diff --git a/litellm/repositories/model_repository.py b/litellm/repositories/model_repository.py new file mode 100644 index 0000000000..893cf342d7 --- /dev/null +++ b/litellm/repositories/model_repository.py @@ -0,0 +1,171 @@ +""" +Model repository for database operations on LiteLLM_ProxyModelTable. 
+""" + +import json +from typing import Any, Dict, List, Optional, Type + +from litellm.models.model import LiteLLM_ProxyModelTable +from litellm.repositories.base_repository import BaseRepository +from litellm.proxy.common_utils.encrypt_decrypt_utils import ( + decrypt_value_helper, + encrypt_value_helper, +) + + +class ModelRepository(BaseRepository[LiteLLM_ProxyModelTable]): + """Repository for proxy model database operations with encryption support.""" + + def __init__(self, prisma_client: Any, encryption_key: Optional[str] = None): + super().__init__(prisma_client) + self._encryption_key = encryption_key + + @property + def table(self) -> Any: + return self.prisma_client.db.litellm_proxymodeltable + + @property + def model_class(self) -> Type[LiteLLM_ProxyModelTable]: + return LiteLLM_ProxyModelTable + + def _encrypt_litellm_params(self, litellm_params: Dict[str, Any]) -> Dict[str, Any]: + """Encrypt sensitive values in litellm_params.""" + encrypted = {} + for key, value in litellm_params.items(): + if isinstance(value, str): + encrypted[key] = encrypt_value_helper( + value, new_encryption_key=self._encryption_key + ) + else: + encrypted[key] = value + return encrypted + + def _decrypt_litellm_params(self, litellm_params: Dict[str, Any]) -> Dict[str, Any]: + """Decrypt sensitive values in litellm_params.""" + decrypted = {} + for key, value in litellm_params.items(): + if isinstance(value, str): + decrypted[key] = decrypt_value_helper( + value, key=key, exception_type="debug", return_original_value=True + ) + else: + decrypted[key] = value + return decrypted + + def _to_model(self, record: Any) -> Optional[LiteLLM_ProxyModelTable]: + """Convert a database record to a Model with decryption.""" + if record is None: + return None + + data = record.dict() if hasattr(record, "dict") else dict(record) + + if isinstance(data.get("litellm_params"), str): + data["litellm_params"] = json.loads(data["litellm_params"]) + if isinstance(data.get("model_info"), str): + 
data["model_info"] = json.loads(data["model_info"]) + + if data.get("litellm_params"): + data["litellm_params"] = self._decrypt_litellm_params( + data["litellm_params"] + ) + + return LiteLLM_ProxyModelTable(**data) + + async def find_by_id( + self, model_id: str, id_field: str = "model_id" + ) -> Optional[LiteLLM_ProxyModelTable]: + return await super().find_by_id(model_id, id_field) + + async def find_by_name(self, model_name: str) -> List[LiteLLM_ProxyModelTable]: + """Find models by name.""" + records = await self.table.find_many(where={"model_name": model_name}) + return self._to_model_list(records) + + async def find_all(self) -> List[LiteLLM_ProxyModelTable]: + """Find all models.""" + records = await self.table.find_many() + return self._to_model_list(records) + + async def find_unblocked(self) -> List[LiteLLM_ProxyModelTable]: + """Find all models that are not blocked.""" + records = await self.table.find_many(where={"blocked": False}) + return self._to_model_list(records) + + async def find_by_team_id(self, team_id: str) -> List[LiteLLM_ProxyModelTable]: + """Find models associated with a specific team. + + Note: This filters in-memory since team_id is stored within litellm_params + JSON. For large deployments with many models, consider adding a dedicated + team_id column with a database index. 
+ """ + all_models = await self.find_all() + return [m for m in all_models if m.team_id == team_id] + + async def create_model( + self, + model_name: str, + litellm_params: Dict[str, Any], + created_by: str, + model_id: Optional[str] = None, + model_info: Optional[Dict[str, Any]] = None, + blocked: bool = False, + ) -> LiteLLM_ProxyModelTable: + """Create a new model with encryption.""" + encrypted_params = self._encrypt_litellm_params(litellm_params) + + data: Dict[str, Any] = { + "model_name": model_name, + "litellm_params": json.dumps(encrypted_params), + "created_by": created_by, + "updated_by": created_by, + "blocked": blocked, + } + if model_id is not None: + data["model_id"] = model_id + if model_info is not None: + data["model_info"] = json.dumps(model_info) + + record = await self.table.create(data=data) + model = self._to_model(record) + assert model is not None + return model + + async def update_model( + self, + model_id: str, + updated_by: str, + model_name: Optional[str] = None, + litellm_params: Optional[Dict[str, Any]] = None, + model_info: Optional[Dict[str, Any]] = None, + blocked: Optional[bool] = None, + ) -> Optional[LiteLLM_ProxyModelTable]: + """Update a model with encryption.""" + data: Dict[str, Any] = {"updated_by": updated_by} + if model_name is not None: + data["model_name"] = model_name + if litellm_params is not None: + encrypted_params = self._encrypt_litellm_params(litellm_params) + data["litellm_params"] = json.dumps(encrypted_params) + if model_info is not None: + data["model_info"] = json.dumps(model_info) + if blocked is not None: + data["blocked"] = blocked + + record = await self.table.update(where={"model_id": model_id}, data=data) + return self._to_model(record) + + async def delete_model(self, model_id: str) -> Optional[LiteLLM_ProxyModelTable]: + """Delete a model.""" + return await self.delete(model_id, id_field="model_id") + + async def block_model( + self, model_id: str, updated_by: str + ) -> 
Optional[LiteLLM_ProxyModelTable]: + """Block a model.""" + return await self.update_model(model_id, updated_by, blocked=True) + + async def unblock_model( + self, model_id: str, updated_by: str + ) -> Optional[LiteLLM_ProxyModelTable]: + """Unblock a model.""" + return await self.update_model(model_id, updated_by, blocked=False) diff --git a/litellm/repositories/object_permission_repository.py b/litellm/repositories/object_permission_repository.py new file mode 100644 index 0000000000..f4d9a8bb90 --- /dev/null +++ b/litellm/repositories/object_permission_repository.py @@ -0,0 +1,110 @@ +""" +ObjectPermission repository for database operations on LiteLLM_ObjectPermissionTable. +""" + +from typing import Any, Dict, List, Optional, Type + +from litellm.models.object_permission import LiteLLM_ObjectPermissionTable +from litellm.repositories.base_repository import BaseRepository + + +class ObjectPermissionRepository(BaseRepository[LiteLLM_ObjectPermissionTable]): + """Repository for object permission database operations.""" + + @property + def table(self) -> Any: + return self.prisma_client.db.litellm_objectpermissiontable + + @property + def model_class(self) -> Type[LiteLLM_ObjectPermissionTable]: + return LiteLLM_ObjectPermissionTable + + async def find_by_id( + self, object_permission_id: str, id_field: str = "object_permission_id" + ) -> Optional[LiteLLM_ObjectPermissionTable]: + return await super().find_by_id(object_permission_id, id_field) + + async def create_permission( + self, + mcp_servers: Optional[List[str]] = None, + mcp_access_groups: Optional[List[str]] = None, + mcp_tool_permissions: Optional[Dict[str, List[str]]] = None, + vector_stores: Optional[List[str]] = None, + agents: Optional[List[str]] = None, + agent_access_groups: Optional[List[str]] = None, + models: Optional[List[str]] = None, + blocked_tools: Optional[List[str]] = None, + mcp_toolsets: Optional[List[str]] = None, + search_tools: Optional[List[str]] = None, + ) -> 
LiteLLM_ObjectPermissionTable: + """Create a new object permission record.""" + data: Dict[str, Any] = {} + if mcp_servers is not None: + data["mcp_servers"] = mcp_servers + if mcp_access_groups is not None: + data["mcp_access_groups"] = mcp_access_groups + if mcp_tool_permissions is not None: + data["mcp_tool_permissions"] = mcp_tool_permissions + if vector_stores is not None: + data["vector_stores"] = vector_stores + if agents is not None: + data["agents"] = agents + if agent_access_groups is not None: + data["agent_access_groups"] = agent_access_groups + if models is not None: + data["models"] = models + if blocked_tools is not None: + data["blocked_tools"] = blocked_tools + if mcp_toolsets is not None: + data["mcp_toolsets"] = mcp_toolsets + if search_tools is not None: + data["search_tools"] = search_tools + + return await self.create(data) + + async def update_permission( + self, + object_permission_id: str, + mcp_servers: Optional[List[str]] = None, + mcp_access_groups: Optional[List[str]] = None, + mcp_tool_permissions: Optional[Dict[str, List[str]]] = None, + vector_stores: Optional[List[str]] = None, + agents: Optional[List[str]] = None, + agent_access_groups: Optional[List[str]] = None, + models: Optional[List[str]] = None, + blocked_tools: Optional[List[str]] = None, + mcp_toolsets: Optional[List[str]] = None, + search_tools: Optional[List[str]] = None, + ) -> Optional[LiteLLM_ObjectPermissionTable]: + """Update an object permission record.""" + data: Dict[str, Any] = {} + if mcp_servers is not None: + data["mcp_servers"] = mcp_servers + if mcp_access_groups is not None: + data["mcp_access_groups"] = mcp_access_groups + if mcp_tool_permissions is not None: + data["mcp_tool_permissions"] = mcp_tool_permissions + if vector_stores is not None: + data["vector_stores"] = vector_stores + if agents is not None: + data["agents"] = agents + if agent_access_groups is not None: + data["agent_access_groups"] = agent_access_groups + if models is not None: + 
data["models"] = models + if blocked_tools is not None: + data["blocked_tools"] = blocked_tools + if mcp_toolsets is not None: + data["mcp_toolsets"] = mcp_toolsets + if search_tools is not None: + data["search_tools"] = search_tools + + return await self.update( + object_permission_id, data, id_field="object_permission_id" + ) + + async def delete_permission( + self, object_permission_id: str + ) -> Optional[LiteLLM_ObjectPermissionTable]: + """Delete an object permission record.""" + return await self.delete(object_permission_id, id_field="object_permission_id") diff --git a/litellm/repositories/organization_repository.py b/litellm/repositories/organization_repository.py new file mode 100644 index 0000000000..2d25a43e83 --- /dev/null +++ b/litellm/repositories/organization_repository.py @@ -0,0 +1,103 @@ +""" +Organization repository for database operations on LiteLLM_OrganizationTable. +""" + +from typing import Any, Dict, List, Optional, Type + +from litellm.models.organization import LiteLLM_OrganizationTable +from litellm.repositories.base_repository import BaseRepository + + +class OrganizationRepository(BaseRepository[LiteLLM_OrganizationTable]): + """Repository for organization database operations.""" + + @property + def table(self) -> Any: + return self.prisma_client.db.litellm_organizationtable + + @property + def model_class(self) -> Type[LiteLLM_OrganizationTable]: + return LiteLLM_OrganizationTable + + async def find_by_id( + self, organization_id: str, id_field: str = "organization_id" + ) -> Optional[LiteLLM_OrganizationTable]: + return await super().find_by_id(organization_id, id_field) + + async def find_by_alias( + self, organization_alias: str + ) -> Optional[LiteLLM_OrganizationTable]: + """Find an organization by alias.""" + records = await self.table.find_many( + where={"organization_alias": organization_alias} + ) + if records: + return self._to_model(records[0]) + return None + + async def create_organization( + self, + organization_alias: 
str, + budget_id: str, + created_by: str, + organization_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + models: Optional[List[str]] = None, + object_permission_id: Optional[str] = None, + ) -> LiteLLM_OrganizationTable: + """Create a new organization.""" + data: Dict[str, Any] = { + "organization_alias": organization_alias, + "budget_id": budget_id, + "created_by": created_by, + "updated_by": created_by, + } + if organization_id is not None: + data["organization_id"] = organization_id + if metadata is not None: + data["metadata"] = metadata + if models is not None: + data["models"] = models + if object_permission_id is not None: + data["object_permission_id"] = object_permission_id + + return await self.create(data) + + async def update_organization( + self, + organization_id: str, + updated_by: str, + organization_alias: Optional[str] = None, + budget_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + models: Optional[List[str]] = None, + object_permission_id: Optional[str] = None, + ) -> Optional[LiteLLM_OrganizationTable]: + """Update an organization.""" + data: Dict[str, Any] = {"updated_by": updated_by} + if organization_alias is not None: + data["organization_alias"] = organization_alias + if budget_id is not None: + data["budget_id"] = budget_id + if metadata is not None: + data["metadata"] = metadata + if models is not None: + data["models"] = models + if object_permission_id is not None: + data["object_permission_id"] = object_permission_id + + return await self.update(organization_id, data, id_field="organization_id") + + async def delete_organization( + self, organization_id: str + ) -> Optional[LiteLLM_OrganizationTable]: + """Delete an organization.""" + return await self.delete(organization_id, id_field="organization_id") + + async def update_spend( + self, organization_id: str, spend: float + ) -> Optional[LiteLLM_OrganizationTable]: + """Update organization spend.""" + return await self.update( + 
organization_id, {"spend": spend}, id_field="organization_id" + ) diff --git a/litellm/repositories/project_repository.py b/litellm/repositories/project_repository.py new file mode 100644 index 0000000000..86567dd05f --- /dev/null +++ b/litellm/repositories/project_repository.py @@ -0,0 +1,129 @@ +""" +Project repository for database operations on LiteLLM_ProjectTable. +""" + +from typing import Any, Dict, List, Optional, Type + +from litellm.models.project import LiteLLM_ProjectTable +from litellm.repositories.base_repository import BaseRepository + + +class ProjectRepository(BaseRepository[LiteLLM_ProjectTable]): + """Repository for project database operations.""" + + @property + def table(self) -> Any: + return self.prisma_client.db.litellm_projecttable + + @property + def model_class(self) -> Type[LiteLLM_ProjectTable]: + return LiteLLM_ProjectTable + + async def find_by_id( + self, project_id: str, id_field: str = "project_id" + ) -> Optional[LiteLLM_ProjectTable]: + return await super().find_by_id(project_id, id_field) + + async def find_by_alias(self, project_alias: str) -> Optional[LiteLLM_ProjectTable]: + """Find a project by alias.""" + records = await self.table.find_many(where={"project_alias": project_alias}) + if records: + return self._to_model(records[0]) + return None + + async def find_by_team_id(self, team_id: str) -> List[LiteLLM_ProjectTable]: + """Find all projects belonging to a team.""" + records = await self.table.find_many(where={"team_id": team_id}) + return self._to_model_list(records) + + async def create_project( + self, + created_by: str, + project_id: Optional[str] = None, + project_alias: Optional[str] = None, + description: Optional[str] = None, + team_id: Optional[str] = None, + budget_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + models: Optional[List[str]] = None, + model_rpm_limit: Optional[Dict[str, int]] = None, + model_tpm_limit: Optional[Dict[str, int]] = None, + object_permission_id: 
Optional[str] = None, + ) -> LiteLLM_ProjectTable: + """Create a new project.""" + data: Dict[str, Any] = { + "created_by": created_by, + "updated_by": created_by, + } + if project_id is not None: + data["project_id"] = project_id + if project_alias is not None: + data["project_alias"] = project_alias + if description is not None: + data["description"] = description + if team_id is not None: + data["team_id"] = team_id + if budget_id is not None: + data["budget_id"] = budget_id + if metadata is not None: + data["metadata"] = metadata + if models is not None: + data["models"] = models + if model_rpm_limit is not None: + data["model_rpm_limit"] = model_rpm_limit + if model_tpm_limit is not None: + data["model_tpm_limit"] = model_tpm_limit + if object_permission_id is not None: + data["object_permission_id"] = object_permission_id + + return await self.create(data) + + async def update_project( + self, + project_id: str, + updated_by: str, + project_alias: Optional[str] = None, + description: Optional[str] = None, + team_id: Optional[str] = None, + budget_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + models: Optional[List[str]] = None, + model_rpm_limit: Optional[Dict[str, int]] = None, + model_tpm_limit: Optional[Dict[str, int]] = None, + blocked: Optional[bool] = None, + object_permission_id: Optional[str] = None, + ) -> Optional[LiteLLM_ProjectTable]: + """Update a project.""" + data: Dict[str, Any] = {"updated_by": updated_by} + if project_alias is not None: + data["project_alias"] = project_alias + if description is not None: + data["description"] = description + if team_id is not None: + data["team_id"] = team_id + if budget_id is not None: + data["budget_id"] = budget_id + if metadata is not None: + data["metadata"] = metadata + if models is not None: + data["models"] = models + if model_rpm_limit is not None: + data["model_rpm_limit"] = model_rpm_limit + if model_tpm_limit is not None: + data["model_tpm_limit"] = model_tpm_limit + if 
blocked is not None: + data["blocked"] = blocked + if object_permission_id is not None: + data["object_permission_id"] = object_permission_id + + return await self.update(project_id, data, id_field="project_id") + + async def delete_project(self, project_id: str) -> Optional[LiteLLM_ProjectTable]: + """Delete a project.""" + return await self.delete(project_id, id_field="project_id") + + async def update_spend( + self, project_id: str, spend: float + ) -> Optional[LiteLLM_ProjectTable]: + """Update project spend.""" + return await self.update(project_id, {"spend": spend}, id_field="project_id") diff --git a/litellm/repositories/table_repositories.py b/litellm/repositories/table_repositories.py new file mode 100644 index 0000000000..47ea11c059 --- /dev/null +++ b/litellm/repositories/table_repositories.py @@ -0,0 +1,215 @@ +""" +Passthrough table repositories. + +Each repository centralizes access to a single Prisma table behind a ``table`` +property, making the repository the one place that names the underlying table. +These are thin wrappers for tables that do not (yet) need domain-specific query +methods; richer repositories live in their own modules. +""" + +from typing import Any + + +class PrismaTableRepository: + """Base for repositories that expose a single Prisma table.""" + + table_name: str + + def __init__(self, prisma_client: Any): + self._prisma_client = prisma_client + + @property + def prisma_client(self) -> Any: + if self._prisma_client is None: + raise RuntimeError( + "No DB Connected. 
See - https://docs.litellm.ai/docs/proxy/virtual_keys" + ) + return self._prisma_client + + @property + def table(self) -> Any: + return getattr(self.prisma_client.db, self.table_name) + + +class PolicyRepository(PrismaTableRepository): + table_name = "litellm_policytable" + + +class AgentsRepository(PrismaTableRepository): + table_name = "litellm_agentstable" + + +class GuardrailsRepository(PrismaTableRepository): + table_name = "litellm_guardrailstable" + + +class MCPServerRepository(PrismaTableRepository): + table_name = "litellm_mcpservertable" + + +class ManagedObjectRepository(PrismaTableRepository): + table_name = "litellm_managedobjecttable" + + +class OrganizationMembershipRepository(PrismaTableRepository): + table_name = "litellm_organizationmembership" + + +class SpendLogsRepository(PrismaTableRepository): + table_name = "litellm_spendlogs" + + +class ClaudeCodePluginRepository(PrismaTableRepository): + table_name = "litellm_claudecodeplugintable" + + +class TeamMembershipRepository(PrismaTableRepository): + table_name = "litellm_teammembership" + + +class EndUserRepository(PrismaTableRepository): + table_name = "litellm_endusertable" + + +class ManagedVectorStoresRepository(PrismaTableRepository): + table_name = "litellm_managedvectorstorestable" + + +class MCPUserCredentialsRepository(PrismaTableRepository): + table_name = "litellm_mcpusercredentials" + + +class PromptRepository(PrismaTableRepository): + table_name = "litellm_prompttable" + + +class TagRepository(PrismaTableRepository): + table_name = "litellm_tagtable" + + +class InvitationLinkRepository(PrismaTableRepository): + table_name = "litellm_invitationlink" + + +class JWTKeyMappingRepository(PrismaTableRepository): + table_name = "litellm_jwtkeymapping" + + +class ManagedFileRepository(PrismaTableRepository): + table_name = "litellm_managedfiletable" + + +class MemoryRepository(PrismaTableRepository): + table_name = "litellm_memorytable" + + +class 
SearchToolsRepository(PrismaTableRepository): + table_name = "litellm_searchtoolstable" + + +class ConfigOverridesRepository(PrismaTableRepository): + table_name = "litellm_configoverrides" + + +class MCPToolsetRepository(PrismaTableRepository): + table_name = "litellm_mcptoolsettable" + + +class ToolRepository(PrismaTableRepository): + table_name = "litellm_tooltable" + + +class DeletedVerificationTokenRepository(PrismaTableRepository): + table_name = "litellm_deletedverificationtoken" + + +class WorkflowRunRepository(PrismaTableRepository): + table_name = "litellm_workflowrun" + + +class ModelTableRepository(PrismaTableRepository): + table_name = "litellm_modeltable" + + +class AccessGroupRepository(PrismaTableRepository): + table_name = "litellm_accessgrouptable" + + +class SSOConfigRepository(PrismaTableRepository): + table_name = "litellm_ssoconfig" + + +class UISettingsRepository(PrismaTableRepository): + table_name = "litellm_uisettings" + + +class DailyGuardrailMetricsRepository(PrismaTableRepository): + table_name = "litellm_dailyguardrailmetrics" + + +class PolicyAttachmentRepository(PrismaTableRepository): + table_name = "litellm_policyattachmenttable" + + +class DeletedTeamRepository(PrismaTableRepository): + table_name = "litellm_deletedteamtable" + + +class SkillsRepository(PrismaTableRepository): + table_name = "litellm_skillstable" + + +class CacheConfigRepository(PrismaTableRepository): + table_name = "litellm_cacheconfig" + + +class ManagedVectorStoreIndexRepository(PrismaTableRepository): + table_name = "litellm_managedvectorstoreindextable" + + +class WorkflowMessageRepository(PrismaTableRepository): + table_name = "litellm_workflowmessage" + + +class DailyTagSpendRepository(PrismaTableRepository): + table_name = "litellm_dailytagspend" + + +class SpendLogToolIndexRepository(PrismaTableRepository): + table_name = "litellm_spendlogtoolindex" + + +class SpendLogGuardrailIndexRepository(PrismaTableRepository): + table_name = 
"litellm_spendlogguardrailindex" + + +class UserNotificationsRepository(PrismaTableRepository): + table_name = "litellm_usernotifications" + + +class HealthCheckRepository(PrismaTableRepository): + table_name = "litellm_healthchecktable" + + +class DeprecatedVerificationTokenRepository(PrismaTableRepository): + table_name = "litellm_deprecatedverificationtoken" + + +class WorkflowEventRepository(PrismaTableRepository): + table_name = "litellm_workflowevent" + + +class DailyPolicyMetricsRepository(PrismaTableRepository): + table_name = "litellm_dailypolicymetrics" + + +class AdaptiveRouterStateRepository(PrismaTableRepository): + table_name = "litellm_adaptiverouterstate" + + +class AuditLogRepository(PrismaTableRepository): + table_name = "litellm_auditlog" + + +class AdaptiveRouterSessionRepository(PrismaTableRepository): + table_name = "litellm_adaptiveroutersession" diff --git a/litellm/repositories/team_repository.py b/litellm/repositories/team_repository.py new file mode 100644 index 0000000000..2ae6647060 --- /dev/null +++ b/litellm/repositories/team_repository.py @@ -0,0 +1,351 @@ +""" +Team repository for database operations on LiteLLM_TeamTable. 
+""" + +import json +from datetime import datetime +from typing import Any, Dict, List, Optional, Type + +from litellm.models.team import LiteLLM_TeamTable +from litellm.repositories.base_repository import BaseRepository + + +class TeamRepository(BaseRepository[LiteLLM_TeamTable]): + """Repository for team database operations.""" + + @property + def table(self) -> Any: + return self.prisma_client.db.litellm_teamtable + + @property + def deleted_table(self) -> Any: + return self.prisma_client.db.litellm_deletedteamtable + + @property + def model_class(self) -> Type[LiteLLM_TeamTable]: + return LiteLLM_TeamTable + + def _to_model(self, record: Any) -> Optional[LiteLLM_TeamTable]: + """Convert a database record to a Team model.""" + if record is None: + return None + + data = record.dict() if hasattr(record, "dict") else dict(record) + + json_fields = [ + "metadata", + "model_spend", + "model_max_budget", + "router_settings", + "budget_limits", + "members_with_roles", + ] + for field in json_fields: + if isinstance(data.get(field), str): + data[field] = json.loads(data[field]) + + return LiteLLM_TeamTable(**data) + + async def find_by_id( + self, team_id: str, id_field: str = "team_id" + ) -> Optional[LiteLLM_TeamTable]: + return await super().find_by_id(team_id, id_field) + + async def find_by_alias(self, team_alias: str) -> Optional[LiteLLM_TeamTable]: + """Find a team by alias.""" + records = await self.table.find_many(where={"team_alias": team_alias}) + if records: + return self._to_model(records[0]) + return None + + async def find_by_organization_id( + self, organization_id: str + ) -> List[LiteLLM_TeamTable]: + """Find all teams belonging to an organization.""" + records = await self.table.find_many(where={"organization_id": organization_id}) + return self._to_model_list(records) + + async def find_by_member(self, user_id: str) -> List[LiteLLM_TeamTable]: + """Find all teams where user is a member.""" + records = await self.table.find_many(where={"members": 
{"has": user_id}}) + return self._to_model_list(records) + + async def find_by_admin(self, user_id: str) -> List[LiteLLM_TeamTable]: + """Find all teams where user is an admin.""" + records = await self.table.find_many(where={"admins": {"has": user_id}}) + return self._to_model_list(records) + + async def create_team( + self, + team_id: str, + team_alias: Optional[str] = None, + organization_id: Optional[str] = None, + admins: Optional[List[str]] = None, + members: Optional[List[str]] = None, + members_with_roles: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, Any]] = None, + max_budget: Optional[float] = None, + soft_budget: Optional[float] = None, + models: Optional[List[str]] = None, + max_parallel_requests: Optional[int] = None, + tpm_limit: Optional[int] = None, + rpm_limit: Optional[int] = None, + budget_duration: Optional[str] = None, + object_permission_id: Optional[str] = None, + ) -> LiteLLM_TeamTable: + """Create a new team.""" + data: Dict[str, Any] = {"team_id": team_id} + if team_alias is not None: + data["team_alias"] = team_alias + if organization_id is not None: + data["organization_id"] = organization_id + if admins is not None: + data["admins"] = admins + if members is not None: + data["members"] = members + if members_with_roles is not None: + data["members_with_roles"] = json.dumps(members_with_roles) + if metadata is not None: + data["metadata"] = json.dumps(metadata) + if max_budget is not None: + data["max_budget"] = max_budget + if soft_budget is not None: + data["soft_budget"] = soft_budget + if models is not None: + data["models"] = models + if max_parallel_requests is not None: + data["max_parallel_requests"] = max_parallel_requests + if tpm_limit is not None: + data["tpm_limit"] = tpm_limit + if rpm_limit is not None: + data["rpm_limit"] = rpm_limit + if budget_duration is not None: + data["budget_duration"] = budget_duration + if object_permission_id is not None: + data["object_permission_id"] = object_permission_id + 
+ return await self.create(data) + + async def update_team( + self, + team_id: str, + team_alias: Optional[str] = None, + organization_id: Optional[str] = None, + admins: Optional[List[str]] = None, + members: Optional[List[str]] = None, + members_with_roles: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, Any]] = None, + max_budget: Optional[float] = None, + soft_budget: Optional[float] = None, + models: Optional[List[str]] = None, + max_parallel_requests: Optional[int] = None, + tpm_limit: Optional[int] = None, + rpm_limit: Optional[int] = None, + budget_duration: Optional[str] = None, + blocked: Optional[bool] = None, + object_permission_id: Optional[str] = None, + ) -> Optional[LiteLLM_TeamTable]: + """Update a team.""" + data: Dict[str, Any] = {} + if team_alias is not None: + data["team_alias"] = team_alias + if organization_id is not None: + data["organization_id"] = organization_id + if admins is not None: + data["admins"] = admins + if members is not None: + data["members"] = members + if members_with_roles is not None: + data["members_with_roles"] = json.dumps(members_with_roles) + if metadata is not None: + data["metadata"] = json.dumps(metadata) + if max_budget is not None: + data["max_budget"] = max_budget + if soft_budget is not None: + data["soft_budget"] = soft_budget + if models is not None: + data["models"] = models + if max_parallel_requests is not None: + data["max_parallel_requests"] = max_parallel_requests + if tpm_limit is not None: + data["tpm_limit"] = tpm_limit + if rpm_limit is not None: + data["rpm_limit"] = rpm_limit + if budget_duration is not None: + data["budget_duration"] = budget_duration + if blocked is not None: + data["blocked"] = blocked + if object_permission_id is not None: + data["object_permission_id"] = object_permission_id + + return await self.update(team_id, data, id_field="team_id") + + async def delete_team( + self, + team_id: str, + deleted_by: Optional[str] = None, + deleted_by_api_key: Optional[str] 
= None, + litellm_changed_by: Optional[str] = None, + ) -> Optional[LiteLLM_TeamTable]: + """Delete a team and archive it to the deleted teams table. + + Uses a transaction to ensure atomicity of the archive-then-delete operation. + """ + team = await self.find_by_id(team_id) + if team is None: + return None + + archive_data = self._build_archive_data(team) + archive_data["deleted_by"] = deleted_by + archive_data["deleted_by_api_key"] = deleted_by_api_key + archive_data["litellm_changed_by"] = litellm_changed_by + archive_data["deleted_at"] = datetime.utcnow() + + async with self.prisma_client.db.tx() as tx: + await tx.litellm_deletedteamtable.create(data=archive_data) + await tx.litellm_teamtable.delete(where={"team_id": team_id}) + + return team + + def _build_archive_data(self, team: LiteLLM_TeamTable) -> Dict[str, Any]: + """Build archive data dict with only columns that exist in LiteLLM_DeletedTeamTable.""" + data: Dict[str, Any] = {"team_id": team.team_id} + if team.team_alias is not None: + data["team_alias"] = team.team_alias + if team.organization_id is not None: + data["organization_id"] = team.organization_id + if team.object_permission_id is not None: + data["object_permission_id"] = team.object_permission_id + data["admins"] = team.admins + data["members"] = team.members + if team.members_with_roles: + data["members_with_roles"] = json.dumps( + [m.model_dump() for m in team.members_with_roles] + ) + if team.metadata: + data["metadata"] = json.dumps(team.metadata) + if team.max_budget is not None: + data["max_budget"] = team.max_budget + if team.soft_budget is not None: + data["soft_budget"] = team.soft_budget + data["spend"] = team.spend if team.spend is not None else 0.0 + data["models"] = team.models + if team.max_parallel_requests is not None: + data["max_parallel_requests"] = team.max_parallel_requests + if team.tpm_limit is not None: + data["tpm_limit"] = team.tpm_limit + if team.rpm_limit is not None: + data["rpm_limit"] = team.rpm_limit + if 
team.budget_duration is not None: + data["budget_duration"] = team.budget_duration + if team.budget_reset_at is not None: + data["budget_reset_at"] = team.budget_reset_at + data["blocked"] = team.blocked + if team.model_spend: + data["model_spend"] = json.dumps(team.model_spend) + if team.model_max_budget: + data["model_max_budget"] = json.dumps(team.model_max_budget) + if team.router_settings is not None: + data["router_settings"] = json.dumps(team.router_settings) + data["team_member_permissions"] = team.team_member_permissions or [] + data["access_group_ids"] = team.access_group_ids or [] + data["policies"] = team.policies or [] + if team.model_id is not None: + data["model_id"] = team.model_id + data["allow_team_guardrail_config"] = team.allow_team_guardrail_config + return data + + async def update_spend( + self, team_id: str, spend: float + ) -> Optional[LiteLLM_TeamTable]: + """Update team spend.""" + return await self.update(team_id, {"spend": spend}, id_field="team_id") + + async def add_member( + self, team_id: str, user_id: str + ) -> Optional[LiteLLM_TeamTable]: + """Add a member to a team using atomic array push operation.""" + if not await self.exists(team_id, id_field="team_id"): + return None + + record = await self.table.update( + where={"team_id": team_id}, + data={"members": {"push": user_id}}, + ) + return self._to_model(record) + + async def remove_member( + self, team_id: str, user_id: str + ) -> Optional[LiteLLM_TeamTable]: + """Remove a member from a team. + + Note: Prisma doesn't support atomic array removal, so we use a + read-modify-write pattern here. For high-concurrency scenarios, + consider using raw SQL with array_remove(). 
+ """ + team = await self.find_by_id(team_id) + if team is None: + return None + + members = [m for m in team.members if m != user_id] + return await self.update(team_id, {"members": members}, id_field="team_id") + + async def add_admin( + self, team_id: str, user_id: str + ) -> Optional[LiteLLM_TeamTable]: + """Add an admin to a team using atomic array push operation.""" + if not await self.exists(team_id, id_field="team_id"): + return None + + record = await self.table.update( + where={"team_id": team_id}, + data={"admins": {"push": user_id}}, + ) + return self._to_model(record) + + async def remove_admin( + self, team_id: str, user_id: str + ) -> Optional[LiteLLM_TeamTable]: + """Remove an admin from a team. + + Note: Prisma doesn't support atomic array removal, so we use a + read-modify-write pattern here. For high-concurrency scenarios, + consider using raw SQL with array_remove(). + """ + team = await self.find_by_id(team_id) + if team is None: + return None + + admins = [a for a in team.admins if a != user_id] + return await self.update(team_id, {"admins": admins}, id_field="team_id") + + async def add_models( + self, team_id: str, models: List[str] + ) -> Optional[LiteLLM_TeamTable]: + """Add models to a team's allowed models list using atomic array push.""" + if not await self.exists(team_id, id_field="team_id"): + return None + + record = await self.table.update( + where={"team_id": team_id}, + data={"models": {"push": models}}, + ) + return self._to_model(record) + + async def remove_models( + self, team_id: str, models: List[str] + ) -> Optional[LiteLLM_TeamTable]: + """Remove models from a team's allowed models list. + + Note: Prisma doesn't support atomic array removal, so we use a + read-modify-write pattern here. For high-concurrency scenarios, + consider using raw SQL with array_remove(). 
+ """ + team = await self.find_by_id(team_id) + if team is None: + return None + + current_models = [m for m in team.models if m not in models] + return await self.update( + team_id, {"models": current_models}, id_field="team_id" + ) diff --git a/litellm/repositories/user_repository.py b/litellm/repositories/user_repository.py new file mode 100644 index 0000000000..4d28b58f0a --- /dev/null +++ b/litellm/repositories/user_repository.py @@ -0,0 +1,229 @@ +""" +User repository for database operations on LiteLLM_UserTable. +""" + +import json +from typing import Any, Dict, List, Optional, Type + +from litellm.models.user import LiteLLM_UserTable +from litellm.repositories.base_repository import BaseRepository + + +class UserRepository(BaseRepository[LiteLLM_UserTable]): + """Repository for user database operations.""" + + @property + def table(self) -> Any: + return self.prisma_client.db.litellm_usertable + + @property + def model_class(self) -> Type[LiteLLM_UserTable]: + return LiteLLM_UserTable + + def _to_model(self, record: Any) -> Optional[LiteLLM_UserTable]: + """Convert a database record to a User model.""" + if record is None: + return None + + data = record.dict() if hasattr(record, "dict") else dict(record) + + json_fields = ["metadata", "model_spend", "model_max_budget"] + for field in json_fields: + if isinstance(data.get(field), str): + data[field] = json.loads(data[field]) + + return LiteLLM_UserTable(**data) + + async def find_by_id( + self, user_id: str, id_field: str = "user_id" + ) -> Optional[LiteLLM_UserTable]: + return await super().find_by_id(user_id, id_field) + + async def find_by_email(self, user_email: str) -> Optional[LiteLLM_UserTable]: + """Find a user by email.""" + records = await self.table.find_many(where={"user_email": user_email}) + if records: + return self._to_model(records[0]) + return None + + async def find_by_sso_id(self, sso_user_id: str) -> Optional[LiteLLM_UserTable]: + """Find a user by SSO ID.""" + record = await 
self.table.find_unique(where={"sso_user_id": sso_user_id}) + return self._to_model(record) + + async def find_by_organization_id( + self, organization_id: str + ) -> List[LiteLLM_UserTable]: + """Find all users in an organization.""" + records = await self.table.find_many(where={"organization_id": organization_id}) + return self._to_model_list(records) + + async def find_by_team_id(self, team_id: str) -> List[LiteLLM_UserTable]: + """Find all users in a team.""" + records = await self.table.find_many(where={"teams": {"has": team_id}}) + return self._to_model_list(records) + + async def create_user( + self, + user_id: str, + user_alias: Optional[str] = None, + team_id: Optional[str] = None, + sso_user_id: Optional[str] = None, + organization_id: Optional[str] = None, + password: Optional[str] = None, + teams: Optional[List[str]] = None, + user_role: Optional[str] = None, + max_budget: Optional[float] = None, + user_email: Optional[str] = None, + models: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + max_parallel_requests: Optional[int] = None, + tpm_limit: Optional[int] = None, + rpm_limit: Optional[int] = None, + budget_duration: Optional[str] = None, + allowed_cache_controls: Optional[List[str]] = None, + policies: Optional[List[str]] = None, + object_permission_id: Optional[str] = None, + ) -> LiteLLM_UserTable: + """Create a new user.""" + data: Dict[str, Any] = {"user_id": user_id} + if user_alias is not None: + data["user_alias"] = user_alias + if team_id is not None: + data["team_id"] = team_id + if sso_user_id is not None: + data["sso_user_id"] = sso_user_id + if organization_id is not None: + data["organization_id"] = organization_id + if password is not None: + data["password"] = password + if teams is not None: + data["teams"] = teams + if user_role is not None: + data["user_role"] = user_role + if max_budget is not None: + data["max_budget"] = max_budget + if user_email is not None: + data["user_email"] = user_email + if 
models is not None: + data["models"] = models + if metadata is not None: + data["metadata"] = json.dumps(metadata) + if max_parallel_requests is not None: + data["max_parallel_requests"] = max_parallel_requests + if tpm_limit is not None: + data["tpm_limit"] = tpm_limit + if rpm_limit is not None: + data["rpm_limit"] = rpm_limit + if budget_duration is not None: + data["budget_duration"] = budget_duration + if allowed_cache_controls is not None: + data["allowed_cache_controls"] = allowed_cache_controls + if policies is not None: + data["policies"] = policies + if object_permission_id is not None: + data["object_permission_id"] = object_permission_id + + return await self.create(data) + + async def update_user( + self, + user_id: str, + user_alias: Optional[str] = None, + team_id: Optional[str] = None, + sso_user_id: Optional[str] = None, + organization_id: Optional[str] = None, + password: Optional[str] = None, + teams: Optional[List[str]] = None, + user_role: Optional[str] = None, + max_budget: Optional[float] = None, + user_email: Optional[str] = None, + models: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + max_parallel_requests: Optional[int] = None, + tpm_limit: Optional[int] = None, + rpm_limit: Optional[int] = None, + budget_duration: Optional[str] = None, + allowed_cache_controls: Optional[List[str]] = None, + policies: Optional[List[str]] = None, + object_permission_id: Optional[str] = None, + ) -> Optional[LiteLLM_UserTable]: + """Update a user.""" + data: Dict[str, Any] = {} + if user_alias is not None: + data["user_alias"] = user_alias + if team_id is not None: + data["team_id"] = team_id + if sso_user_id is not None: + data["sso_user_id"] = sso_user_id + if organization_id is not None: + data["organization_id"] = organization_id + if password is not None: + data["password"] = password + if teams is not None: + data["teams"] = teams + if user_role is not None: + data["user_role"] = user_role + if max_budget is not None: + 
data["max_budget"] = max_budget + if user_email is not None: + data["user_email"] = user_email + if models is not None: + data["models"] = models + if metadata is not None: + data["metadata"] = json.dumps(metadata) + if max_parallel_requests is not None: + data["max_parallel_requests"] = max_parallel_requests + if tpm_limit is not None: + data["tpm_limit"] = tpm_limit + if rpm_limit is not None: + data["rpm_limit"] = rpm_limit + if budget_duration is not None: + data["budget_duration"] = budget_duration + if allowed_cache_controls is not None: + data["allowed_cache_controls"] = allowed_cache_controls + if policies is not None: + data["policies"] = policies + if object_permission_id is not None: + data["object_permission_id"] = object_permission_id + + return await self.update(user_id, data, id_field="user_id") + + async def delete_user(self, user_id: str) -> Optional[LiteLLM_UserTable]: + """Delete a user.""" + return await self.delete(user_id, id_field="user_id") + + async def update_spend( + self, user_id: str, spend: float + ) -> Optional[LiteLLM_UserTable]: + """Update user spend.""" + return await self.update(user_id, {"spend": spend}, id_field="user_id") + + async def add_to_team( + self, user_id: str, team_id: str + ) -> Optional[LiteLLM_UserTable]: + """Add a user to a team using atomic array push operation.""" + if not await self.exists(user_id, id_field="user_id"): + return None + + record = await self.table.update( + where={"user_id": user_id}, + data={"teams": {"push": team_id}}, + ) + return self._to_model(record) + + async def remove_from_team( + self, user_id: str, team_id: str + ) -> Optional[LiteLLM_UserTable]: + """Remove a user from a team. + + Note: Prisma doesn't support atomic array removal, so we use a + read-modify-write pattern here. For high-concurrency scenarios, + consider using raw SQL with array_remove(). 
+ """ + user = await self.find_by_id(user_id) + if user is None: + return None + + teams = [t for t in user.teams if t != team_id] + return await self.update(user_id, {"teams": teams}, id_field="user_id") diff --git a/litellm/repositories/verification_token_repository.py b/litellm/repositories/verification_token_repository.py new file mode 100644 index 0000000000..56c3e0714a --- /dev/null +++ b/litellm/repositories/verification_token_repository.py @@ -0,0 +1,375 @@ +""" +VerificationToken repository for database operations on LiteLLM_VerificationToken. +""" + +import json +from datetime import datetime +from typing import Any, Dict, List, Optional, Type + +from litellm.models.verification_token import ( + LiteLLM_VerificationToken, +) +from litellm.repositories.base_repository import BaseRepository + + +class VerificationTokenRepository(BaseRepository[LiteLLM_VerificationToken]): + """Repository for verification token (API key) database operations.""" + + @property + def table(self) -> Any: + return self.prisma_client.db.litellm_verificationtoken + + @property + def deleted_table(self) -> Any: + return self.prisma_client.db.litellm_deletedverificationtoken + + @property + def model_class(self) -> Type[LiteLLM_VerificationToken]: + return LiteLLM_VerificationToken + + def _to_model(self, record: Any) -> Optional[LiteLLM_VerificationToken]: + """Convert a database record to a VerificationToken model.""" + if record is None: + return None + + data = record.dict() if hasattr(record, "dict") else dict(record) + + json_fields = [ + "aliases", + "config", + "permissions", + "metadata", + "model_spend", + "model_max_budget", + "router_settings", + "budget_limits", + "litellm_budget_table", + ] + for field in json_fields: + if isinstance(data.get(field), str): + data[field] = json.loads(data[field]) + + if data.get("org_id") is None and data.get("organization_id") is not None: + data["org_id"] = data["organization_id"] + + return LiteLLM_VerificationToken(**data) + + async 
def find_by_id( + self, token: str, id_field: str = "token" + ) -> Optional[LiteLLM_VerificationToken]: + return await super().find_by_id(token, id_field) + + async def find_by_alias( + self, key_alias: str + ) -> Optional[LiteLLM_VerificationToken]: + """Find a token by key alias.""" + records = await self.table.find_many(where={"key_alias": key_alias}) + if records: + return self._to_model(records[0]) + return None + + async def find_by_user_id(self, user_id: str) -> List[LiteLLM_VerificationToken]: + """Find all tokens belonging to a user.""" + records = await self.table.find_many(where={"user_id": user_id}) + return self._to_model_list(records) + + async def find_by_team_id(self, team_id: str) -> List[LiteLLM_VerificationToken]: + """Find all tokens belonging to a team.""" + records = await self.table.find_many(where={"team_id": team_id}) + return self._to_model_list(records) + + async def find_by_project_id( + self, project_id: str + ) -> List[LiteLLM_VerificationToken]: + """Find all tokens belonging to a project.""" + records = await self.table.find_many(where={"project_id": project_id}) + return self._to_model_list(records) + + async def find_active_tokens(self) -> List[LiteLLM_VerificationToken]: + """Find all active (non-expired, non-blocked) tokens.""" + records = await self.table.find_many( + where={ + "blocked": {"not": True}, + "OR": [{"expires": None}, {"expires": {"gt": datetime.utcnow()}}], + } + ) + return self._to_model_list(records) + + def _build_token_data( + self, + token: str, + key_name: Optional[str] = None, + key_alias: Optional[str] = None, + max_budget: Optional[float] = None, + expires: Optional[datetime] = None, + models: Optional[List[str]] = None, + aliases: Optional[Dict[str, str]] = None, + config: Optional[Dict[str, Any]] = None, + user_id: Optional[str] = None, + team_id: Optional[str] = None, + agent_id: Optional[str] = None, + project_id: Optional[str] = None, + max_parallel_requests: Optional[int] = None, + metadata: 
Optional[Dict[str, Any]] = None, + tpm_limit: Optional[int] = None, + rpm_limit: Optional[int] = None, + budget_duration: Optional[str] = None, + allowed_cache_controls: Optional[List[str]] = None, + allowed_routes: Optional[List[str]] = None, + permissions: Optional[Dict[str, Any]] = None, + org_id: Optional[str] = None, + created_by: Optional[str] = None, + object_permission_id: Optional[str] = None, + access_group_ids: Optional[List[str]] = None, + budget_id: Optional[str] = None, + ) -> Dict[str, Any]: + """Build data dictionary for token creation.""" + json_fields = { + "aliases": aliases, + "config": config, + "metadata": metadata, + "permissions": permissions, + } + simple_fields = { + "token": token, + "key_name": key_name, + "key_alias": key_alias, + "max_budget": max_budget, + "expires": expires, + "models": models, + "user_id": user_id, + "team_id": team_id, + "agent_id": agent_id, + "project_id": project_id, + "max_parallel_requests": max_parallel_requests, + "tpm_limit": tpm_limit, + "rpm_limit": rpm_limit, + "budget_duration": budget_duration, + "allowed_cache_controls": allowed_cache_controls, + "allowed_routes": allowed_routes, + "object_permission_id": object_permission_id, + "access_group_ids": access_group_ids, + "budget_id": budget_id, + } + data: Dict[str, Any] = {k: v for k, v in simple_fields.items() if v is not None} + for key, val in json_fields.items(): + if val is not None: + data[key] = json.dumps(val) + if org_id is not None: + data["organization_id"] = org_id + if created_by is not None: + data["created_by"] = created_by + data["updated_by"] = created_by + return data + + async def create_token( + self, + token: str, + key_name: Optional[str] = None, + key_alias: Optional[str] = None, + max_budget: Optional[float] = None, + expires: Optional[datetime] = None, + models: Optional[List[str]] = None, + aliases: Optional[Dict[str, str]] = None, + config: Optional[Dict[str, Any]] = None, + user_id: Optional[str] = None, + team_id: 
Optional[str] = None, + agent_id: Optional[str] = None, + project_id: Optional[str] = None, + max_parallel_requests: Optional[int] = None, + metadata: Optional[Dict[str, Any]] = None, + tpm_limit: Optional[int] = None, + rpm_limit: Optional[int] = None, + budget_duration: Optional[str] = None, + allowed_cache_controls: Optional[List[str]] = None, + allowed_routes: Optional[List[str]] = None, + permissions: Optional[Dict[str, Any]] = None, + org_id: Optional[str] = None, + created_by: Optional[str] = None, + object_permission_id: Optional[str] = None, + access_group_ids: Optional[List[str]] = None, + budget_id: Optional[str] = None, + ) -> LiteLLM_VerificationToken: + """Create a new verification token.""" + data = self._build_token_data( + token=token, + key_name=key_name, + key_alias=key_alias, + max_budget=max_budget, + expires=expires, + models=models, + aliases=aliases, + config=config, + user_id=user_id, + team_id=team_id, + agent_id=agent_id, + project_id=project_id, + max_parallel_requests=max_parallel_requests, + metadata=metadata, + tpm_limit=tpm_limit, + rpm_limit=rpm_limit, + budget_duration=budget_duration, + allowed_cache_controls=allowed_cache_controls, + allowed_routes=allowed_routes, + permissions=permissions, + org_id=org_id, + created_by=created_by, + object_permission_id=object_permission_id, + access_group_ids=access_group_ids, + budget_id=budget_id, + ) + return await self.create(data) + + async def update_token( + self, + token: str, + updated_by: Optional[str] = None, + key_name: Optional[str] = None, + key_alias: Optional[str] = None, + max_budget: Optional[float] = None, + expires: Optional[datetime] = None, + models: Optional[List[str]] = None, + aliases: Optional[Dict[str, str]] = None, + config: Optional[Dict[str, Any]] = None, + max_parallel_requests: Optional[int] = None, + metadata: Optional[Dict[str, Any]] = None, + tpm_limit: Optional[int] = None, + rpm_limit: Optional[int] = None, + budget_duration: Optional[str] = None, + 
allowed_cache_controls: Optional[List[str]] = None, + allowed_routes: Optional[List[str]] = None, + permissions: Optional[Dict[str, Any]] = None, + blocked: Optional[bool] = None, + object_permission_id: Optional[str] = None, + access_group_ids: Optional[List[str]] = None, + ) -> Optional[LiteLLM_VerificationToken]: + """Update a verification token.""" + data: Dict[str, Any] = {} + if updated_by is not None: + data["updated_by"] = updated_by + if key_name is not None: + data["key_name"] = key_name + if key_alias is not None: + data["key_alias"] = key_alias + if max_budget is not None: + data["max_budget"] = max_budget + if expires is not None: + data["expires"] = expires + if models is not None: + data["models"] = models + if aliases is not None: + data["aliases"] = json.dumps(aliases) + if config is not None: + data["config"] = json.dumps(config) + if max_parallel_requests is not None: + data["max_parallel_requests"] = max_parallel_requests + if metadata is not None: + data["metadata"] = json.dumps(metadata) + if tpm_limit is not None: + data["tpm_limit"] = tpm_limit + if rpm_limit is not None: + data["rpm_limit"] = rpm_limit + if budget_duration is not None: + data["budget_duration"] = budget_duration + if allowed_cache_controls is not None: + data["allowed_cache_controls"] = allowed_cache_controls + if allowed_routes is not None: + data["allowed_routes"] = allowed_routes + if permissions is not None: + data["permissions"] = json.dumps(permissions) + if blocked is not None: + data["blocked"] = blocked + if object_permission_id is not None: + data["object_permission_id"] = object_permission_id + if access_group_ids is not None: + data["access_group_ids"] = access_group_ids + + return await self.update(token, data, id_field="token") + + async def delete_token( + self, + token: str, + deleted_by: Optional[str] = None, + deleted_by_api_key: Optional[str] = None, + litellm_changed_by: Optional[str] = None, + ) -> Optional[LiteLLM_VerificationToken]: + """Delete a 
token and archive it to the deleted tokens table. + + Uses a transaction to ensure atomicity of the archive-then-delete operation. + """ + token_record = await self.find_by_id(token) + if token_record is None: + return None + + archive_data = self._build_archive_data(token_record) + archive_data["deleted_by"] = deleted_by + archive_data["deleted_by_api_key"] = deleted_by_api_key + archive_data["litellm_changed_by"] = litellm_changed_by + archive_data["deleted_at"] = datetime.utcnow() + + async with self.prisma_client.db.tx() as tx: + await tx.litellm_deletedverificationtoken.create(data=archive_data) + await tx.litellm_verificationtoken.delete(where={"token": token}) + + return token_record + + def _build_archive_data(self, token: LiteLLM_VerificationToken) -> Dict[str, Any]: + """Build archive data with only columns present in LiteLLM_DeletedVerificationToken. + + Serializes JSON columns to strings (the archive table stores them as JSON + columns the same way the live table does) and maps ``org_id`` onto the + ``organization_id`` column so the foreign key is preserved. 
+ """ + data = token.model_dump(exclude_none=True) + for field in ("object_permission", "litellm_budget_table", "budget_limits"): + data.pop(field, None) + + org_id = data.pop("org_id", None) + if org_id is not None: + data["organization_id"] = org_id + + json_fields = [ + "aliases", + "config", + "permissions", + "metadata", + "model_spend", + "model_max_budget", + "router_settings", + ] + for field in json_fields: + if field in data: + data[field] = json.dumps(data[field]) + return data + + async def update_spend( + self, token: str, spend: float + ) -> Optional[LiteLLM_VerificationToken]: + """Update token spend.""" + return await self.update(token, {"spend": spend}, id_field="token") + + async def update_last_active( + self, token: str + ) -> Optional[LiteLLM_VerificationToken]: + """Update the last_active timestamp.""" + return await self.update( + token, {"last_active": datetime.utcnow()}, id_field="token" + ) + + async def block_token( + self, token: str, updated_by: Optional[str] = None + ) -> Optional[LiteLLM_VerificationToken]: + """Block a token.""" + data: Dict[str, Any] = {"blocked": True} + if updated_by is not None: + data["updated_by"] = updated_by + return await self.update(token, data, id_field="token") + + async def unblock_token( + self, token: str, updated_by: Optional[str] = None + ) -> Optional[LiteLLM_VerificationToken]: + """Unblock a token.""" + data: Dict[str, Any] = {"blocked": False} + if updated_by is not None: + data["updated_by"] = updated_by + return await self.update(token, data, id_field="token") diff --git a/litellm/router_strategy/adaptive_router/adaptive_router.py b/litellm/router_strategy/adaptive_router/adaptive_router.py index 3bccef36e6..4856d7ff4c 100644 --- a/litellm/router_strategy/adaptive_router/adaptive_router.py +++ b/litellm/router_strategy/adaptive_router/adaptive_router.py @@ -55,6 +55,7 @@ from litellm.router_strategy.adaptive_router.update_queue import ( _SESSION_STATE_SWEEP_THRESHOLD: int = 1024 # Same pattern 
for the owner cache. _OWNER_CACHE_SWEEP_THRESHOLD: int = 1024 +from litellm.repositories.table_repositories import AdaptiveRouterStateRepository from litellm.types.llms.openai import AllMessageValues from litellm.types.router import ( AdaptiveRouterConfig, @@ -113,7 +114,7 @@ class AdaptiveRouter: if prisma_client is None: return try: - rows = await prisma_client.db.litellm_adaptiverouterstate.find_many( + rows = await AdaptiveRouterStateRepository(prisma_client).table.find_many( where={"router_name": self.router_name} ) loaded = 0 diff --git a/litellm/router_strategy/adaptive_router/update_queue.py b/litellm/router_strategy/adaptive_router/update_queue.py index b667f3a53a..1d87feddd8 100644 --- a/litellm/router_strategy/adaptive_router/update_queue.py +++ b/litellm/router_strategy/adaptive_router/update_queue.py @@ -22,6 +22,10 @@ import asyncio from typing import Any, Dict, Tuple from litellm._logging import verbose_router_logger +from litellm.repositories.table_repositories import ( + AdaptiveRouterSessionRepository, + AdaptiveRouterStateRepository, +) StateKey = Tuple[str, str, str] # (router_name, request_type, model_name) SessionKey = Tuple[str, str, str] # (session_id, router_name, model_name) @@ -112,7 +116,7 @@ class AdaptiveRouterUpdateQueue: # other. The upsert creates the row with the delta as the # initial value on first write, then increments on subsequent # writes — no read-modify-write race. 
- await prisma_client.db.litellm_adaptiverouterstate.upsert( + await AdaptiveRouterStateRepository(prisma_client).table.upsert( where={ "router_name_request_type_model_name": { "router_name": router, @@ -174,7 +178,7 @@ class AdaptiveRouterUpdateQueue: for k, v in payload.items() if k not in ("session_id", "router_name", "model_name") } - await prisma_client.db.litellm_adaptiveroutersession.upsert( + await AdaptiveRouterSessionRepository(prisma_client).table.upsert( where={ "session_id_router_name_model_name": { "session_id": session_id, diff --git a/litellm/types/mcp_server/mcp_server_manager.py b/litellm/types/mcp_server/mcp_server_manager.py index 92ca027c5b..809da6418d 100644 --- a/litellm/types/mcp_server/mcp_server_manager.py +++ b/litellm/types/mcp_server/mcp_server_manager.py @@ -3,8 +3,7 @@ from typing import Any, Dict, List, Literal, Optional from pydantic import BaseModel, ConfigDict -from litellm.proxy._types import MCPAuthType, MCPTransportType -from litellm.types.mcp import MCPAuth +from litellm.types.mcp import MCPAuth, MCPAuthType, MCPTransportType # MCPInfo now allows arbitrary additional fields for custom metadata MCPInfo = Dict[str, Any] diff --git a/litellm/types/utils.py b/litellm/types/utils.py index b76ae1f5d8..9633cecf96 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -38,7 +38,6 @@ from pydantic import ( Field, PrivateAttr, field_validator, - model_validator, ) from typing_extensions import Required, TypedDict @@ -3577,25 +3576,11 @@ class RawRequestTypedDict(TypedDict, total=False): error: Optional[str] -class CredentialBase(BaseModel): - credential_name: str - credential_info: dict - - -class CredentialItem(CredentialBase): - credential_values: dict - - -class CreateCredentialItem(CredentialBase): - credential_values: Optional[dict] = None - model_id: Optional[str] = None - - @model_validator(mode="before") - @classmethod - def check_credential_params(cls, values): - if not values.get("credential_values") and not 
values.get("model_id"): - raise ValueError("Either credential_values or model_id must be set") - return values +from litellm.models.credentials import CredentialBase as CredentialBase # noqa: E402 +from litellm.models.credentials import CredentialItem as CredentialItem # noqa: E402 +from litellm.models.credentials import ( # noqa: E402 + CreateCredentialItem as CreateCredentialItem, +) class ExtractedFileData(TypedDict): diff --git a/litellm/vector_stores/vector_store_registry.py b/litellm/vector_stores/vector_store_registry.py index 1fd95b1630..94f0483e1c 100644 --- a/litellm/vector_stores/vector_store_registry.py +++ b/litellm/vector_stores/vector_store_registry.py @@ -5,6 +5,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, get_args from litellm._logging import verbose_logger from litellm.litellm_core_utils.core_helpers import remove_items_at_indices +from litellm.repositories.table_repositories import ( + ManagedVectorStoreIndexRepository, + ManagedVectorStoresRepository, +) from litellm.types.vector_stores import ( VECTOR_STORE_OPENAI_PARAMS, LiteLLM_ManagedVectorStore, @@ -91,10 +95,10 @@ class VectorStoreIndexRegistry: """ vector_stores_from_db: List[LiteLLM_ManagedVectorStoreIndex] = [] if prisma_client is not None: - _vector_stores_from_db = ( - await prisma_client.db.litellm_managedvectorstoreindextable.find_many( - order={"created_at": "desc"}, - ) + _vector_stores_from_db = await ManagedVectorStoreIndexRepository( + prisma_client + ).table.find_many( + order={"created_at": "desc"}, ) for vector_store in _vector_stores_from_db: _dict_vector_store = dict(vector_store) @@ -374,9 +378,9 @@ class VectorStoreRegistry: if vector_store is not None and prisma_client is not None: try: # Check if it still exists in database - db_vector_store = await prisma_client.db.litellm_managedvectorstorestable.find_unique( - where={"vector_store_id": vector_store_id} - ) + db_vector_store = await ManagedVectorStoresRepository( + prisma_client + 
).table.find_unique(where={"vector_store_id": vector_store_id}) if db_vector_store is None: # Vector store was deleted from database, remove from cache verbose_logger.debug( @@ -541,10 +545,10 @@ class VectorStoreRegistry: """ vector_stores_from_db: List[LiteLLM_ManagedVectorStore] = [] if prisma_client is not None: - _vector_stores_from_db = ( - await prisma_client.db.litellm_managedvectorstorestable.find_many( - order={"created_at": "desc"}, - ) + _vector_stores_from_db = await ManagedVectorStoresRepository( + prisma_client + ).table.find_many( + order={"created_at": "desc"}, ) for vector_store in _vector_stores_from_db: _dict_vector_store = dict(vector_store) diff --git a/tests/test_litellm/models/test_models.py b/tests/test_litellm/models/test_models.py new file mode 100644 index 0000000000..786f624493 --- /dev/null +++ b/tests/test_litellm/models/test_models.py @@ -0,0 +1,542 @@ +""" +Tests for backend domain models. +""" + +from datetime import datetime + +import pytest + +from litellm.models.access_group import LiteLLM_AccessGroupTable +from litellm.models.budget import ( + LiteLLM_BudgetTable, + LiteLLM_BudgetTableFull, + LiteLLM_TeamMemberTable, +) +from litellm.models.config import LiteLLM_Config +from litellm.models.credentials import CreateCredentialItem, CredentialItem +from litellm.models.end_user import LiteLLM_EndUserTable +from litellm.models.managed_files import ( + LiteLLM_ManagedFileTable, + LiteLLM_ManagedObjectTable, + LiteLLM_ManagedVectorStoresTable, +) +from litellm.models.mcp_server import LiteLLM_MCPServerTable +from litellm.models.model import LiteLLM_ProxyModelTable +from litellm.models.object_permission import LiteLLM_ObjectPermissionTable +from litellm.models.organization import LiteLLM_OrganizationTable +from litellm.models.project import LiteLLM_ProjectTable +from litellm.models.skills import LiteLLM_SkillsTable +from litellm.models.spend_logs import LiteLLM_ErrorLogs, LiteLLM_SpendLogs +from litellm.models.tag import 
LiteLLM_TagTable +from litellm.models.team import ( + LiteLLM_DeletedTeamTable, + LiteLLM_TeamTable, + LiteLLM_TeamTableCachedObj, +) +from litellm.models.team_membership import LiteLLM_TeamMembership +from litellm.models.user import LiteLLM_UserTable +from litellm.models.verification_token import ( + LiteLLM_DeletedVerificationToken, + LiteLLM_VerificationToken, +) + + +class TestBudget: + def test_budget_creation(self): + budget = LiteLLM_BudgetTable( + budget_id="test-budget-id", + max_budget=100.0, + soft_budget=80.0, + tpm_limit=1000, + rpm_limit=100, + model_max_budget={"gpt-4": 50.0}, + budget_duration="monthly", + allowed_models=["gpt-4"], + ) + assert budget.budget_id == "test-budget-id" + assert budget.max_budget == 100.0 + assert budget.soft_budget == 80.0 + assert budget.tpm_limit == 1000 + assert budget.rpm_limit == 100 + assert budget.model_max_budget == {"gpt-4": 50.0} + assert budget.budget_duration == "monthly" + assert budget.allowed_models == ["gpt-4"] + + def test_budget_defaults(self): + budget = LiteLLM_BudgetTable() + assert budget.budget_id is None + assert budget.max_budget is None + assert budget.allowed_models is None + + +class TestCredentials: + def test_credentials_creation(self): + creds = CredentialItem( + credential_name="test-cred", + credential_values={"api_key": "secret123"}, + credential_info={"provider": "openai"}, + ) + assert creds.credential_name == "test-cred" + assert creds.credential_values["api_key"] == "secret123" + assert creds.credential_info["provider"] == "openai" + + def test_create_credential_item_accepts_model_id(self): + item = CreateCredentialItem( + credential_name="from-model", + credential_info={}, + model_id="model-123", + ) + assert item.model_id == "model-123" + assert item.credential_values is None + + def test_create_credential_item_requires_values_or_model_id(self): + with pytest.raises( + ValueError, match="Either credential_values or model_id must be set" + ): + 
CreateCredentialItem(credential_name="bad", credential_info={}) + + +class TestModel: + def test_model_creation(self): + model = LiteLLM_ProxyModelTable( + model_id="test-model-id", + model_name="gpt-4", + litellm_params={"model": "gpt-4", "api_key": "test"}, + model_info={"team_id": "team-123", "team_public_model_name": "my-gpt4"}, + ) + assert model.model_id == "test-model-id" + assert model.model_name == "gpt-4" + assert model.team_id == "team-123" + assert model.team_public_model_name == "my-gpt4" + + def test_is_blocked(self): + model_blocked = LiteLLM_ProxyModelTable( + model_id="m1", model_name="test", litellm_params={}, blocked=True + ) + model_unblocked = LiteLLM_ProxyModelTable( + model_id="m2", model_name="test", litellm_params={}, blocked=False + ) + assert model_blocked.is_blocked + assert not model_unblocked.is_blocked + + def test_parses_json_string_fields(self): + model = LiteLLM_ProxyModelTable( + model_id="m1", + model_name="gpt-4", + litellm_params='{"model": "gpt-4"}', + model_info='{"team_id": "t1"}', + ) + assert model.litellm_params == {"model": "gpt-4"} + assert model.model_info == {"team_id": "t1"} + + def test_team_helpers_none_when_no_model_info(self): + model = LiteLLM_ProxyModelTable( + model_id="m1", model_name="gpt-4", litellm_params={}, model_info=None + ) + assert model.team_id is None + assert model.team_public_model_name is None + + +class TestObjectPermission: + def test_object_permission_creation(self): + perm = LiteLLM_ObjectPermissionTable( + object_permission_id="test-perm-id", + mcp_servers=["server1", "server2"], + vector_stores=["vs1"], + agents=["agent1"], + models=["gpt-4"], + blocked_tools=["dangerous_tool"], + ) + assert perm.object_permission_id == "test-perm-id" + assert len(perm.mcp_servers) == 2 + assert perm.vector_stores == ["vs1"] + assert perm.agents == ["agent1"] + assert perm.models == ["gpt-4"] + assert perm.blocked_tools == ["dangerous_tool"] + + def test_object_permission_tool_permissions(self): + perm = 
LiteLLM_ObjectPermissionTable( + object_permission_id="perm-tools", + mcp_tool_permissions={"server1": ["tool1", "tool2"]}, + ) + assert perm.mcp_tool_permissions == {"server1": ["tool1", "tool2"]} + + +class TestOrganization: + def test_organization_creation(self): + org = LiteLLM_OrganizationTable( + organization_id="org-123", + organization_alias="My Org", + budget_id="budget-123", + models=["gpt-4", "claude-3"], + spend=50.0, + created_by="admin", + updated_by="admin", + ) + assert org.organization_id == "org-123" + assert org.organization_alias == "My Org" + assert len(org.models) == 2 + + +class TestProject: + def test_project_creation(self): + project = LiteLLM_ProjectTable( + project_id="proj-123", + project_alias="My Project", + team_id="team-123", + blocked=False, + ) + assert project.project_id == "proj-123" + assert not project.is_blocked + + +class TestTeam: + def test_team_creation(self): + team = LiteLLM_TeamTable( + team_id="team-123", + team_alias="Engineering", + admins=["user1"], + members=["user2", "user3"], + models=["gpt-4"], + max_budget=1000.0, + spend=100.0, + ) + assert team.team_id == "team-123" + assert team.team_alias == "Engineering" + assert team.admins == ["user1"] + assert team.members == ["user2", "user3"] + assert team.models == ["gpt-4"] + assert team.max_budget == 1000.0 + + def test_members_with_roles_parsing(self): + team = LiteLLM_TeamTable( + team_id="t2", + members_with_roles=[ + {"user_id": "user1", "role": "admin"}, + {"user_id": "user2", "role": "user"}, + ], + ) + assert len(team.members_with_roles) == 2 + assert team.members_with_roles[0].user_id == "user1" + assert team.members_with_roles[0].role == "admin" + + def test_members_with_roles_empty_dict_coerced(self): + team = LiteLLM_TeamTable(team_id="t3", members_with_roles={}) + assert team.members_with_roles == [] + + def test_json_string_fields_parsed(self): + team = LiteLLM_TeamTable( + team_id="t4", + metadata='{"k": "v"}', + model_max_budget='{"gpt-4": 5.0}', + ) 
+ assert team.metadata == {"k": "v"} + assert team.model_max_budget == {"gpt-4": 5.0} + + def test_cached_team(self): + cached = LiteLLM_TeamTableCachedObj( + team_id="t1", last_refreshed_at=1234567890.0 + ) + assert cached.last_refreshed_at == 1234567890.0 + + def test_deleted_team(self): + deleted = LiteLLM_DeletedTeamTable( + team_id="t1", + deleted_by="admin", + deleted_at=datetime.utcnow(), + ) + assert deleted.deleted_by == "admin" + assert deleted.deleted_at is not None + + +class TestUser: + def test_user_creation(self): + user = LiteLLM_UserTable( + user_id="user-123", + user_email="test@example.com", + teams=["team1", "team2"], + max_budget=100.0, + spend=25.0, + ) + assert user.user_id == "user-123" + assert user.user_email == "test@example.com" + assert len(user.teams) == 2 + + def test_is_over_budget(self): + user = LiteLLM_UserTable(user_id="u1", max_budget=100.0, spend=150.0) + user_no_budget = LiteLLM_UserTable(user_id="u2", spend=1000.0) + + assert user.is_over_budget() + assert not user_no_budget.is_over_budget() + + def test_has_model_access(self): + user_with_models = LiteLLM_UserTable(user_id="u1", models=["gpt-4"]) + user_no_models = LiteLLM_UserTable(user_id="u2", models=[]) + + assert user_with_models.has_model_access("gpt-4") + assert not user_with_models.has_model_access("gpt-3") + assert user_no_models.has_model_access("any-model") + + def test_password_hash_excluded_from_serialization(self): + from litellm.proxy._types import LiteLLM_UserTableWithKeyCount + + secret = "$2b$12$abcdefghijklmnopqrstuv" + user = LiteLLM_UserTable(user_id="u1", user_email="a@b.c", password=secret) + + assert user.password == secret + assert "password" not in user.model_dump() + assert "password" not in user.model_dump_json() + + with_keys = LiteLLM_UserTableWithKeyCount( + user_id="u1", user_email="a@b.c", password=secret, key_count=2 + ) + assert with_keys.password == secret + assert "password" not in with_keys.model_dump() + assert "password" not in 
with_keys.model_dump_json() + + +class TestVerificationToken: + def test_verification_token_creation(self): + token = LiteLLM_VerificationToken( + token="sk-test123", + key_name="Test Key", + user_id="user-123", + team_id="team-123", + max_budget=100.0, + spend=25.0, + models=["gpt-4"], + blocked=True, + allowed_routes=["/chat/completions"], + ) + assert token.token == "sk-test123" + assert token.key_name == "Test Key" + assert token.user_id == "user-123" + assert token.team_id == "team-123" + assert token.blocked is True + assert token.models == ["gpt-4"] + assert token.allowed_routes == ["/chat/completions"] + + def test_expires_accepts_string_and_datetime(self): + as_str = LiteLLM_VerificationToken(token="t1", expires="2024-12-31T23:59:59Z") + as_dt = LiteLLM_VerificationToken(token="t2", expires=datetime.utcnow()) + assert as_str.expires == "2024-12-31T23:59:59Z" + assert isinstance(as_dt.expires, datetime) + + def test_deleted_verification_token(self): + deleted = LiteLLM_DeletedVerificationToken( + token="t1", + deleted_by="admin", + deleted_at=datetime.utcnow(), + ) + assert deleted.deleted_by == "admin" + assert deleted.deleted_at is not None + assert deleted.token == "t1" + + +class TestConfigTable: + def test_config_creation(self): + cfg = LiteLLM_Config(param_name="general_settings", param_value={"k": "v"}) + assert cfg.param_name == "general_settings" + assert cfg.param_value == {"k": "v"} + + +class TestSkillsTable: + def test_skills_creation(self): + skill = LiteLLM_SkillsTable( + skill_id="s1", + display_title="My Skill", + source="custom", + file_content=b"zipbytes", + file_name="skill.zip", + ) + assert skill.skill_id == "s1" + assert skill.display_title == "My Skill" + assert skill.file_content == b"zipbytes" + + def test_skills_defaults(self): + skill = LiteLLM_SkillsTable(skill_id="s2") + assert skill.source == "custom" + assert skill.metadata is None + + +class TestAccessGroupTable: + def test_access_group_creation(self): + ag = 
LiteLLM_AccessGroupTable( + access_group_id="ag1", + access_group_name="group-a", + access_model_names=["gpt-4"], + assigned_team_ids=["t1"], + ) + assert ag.access_group_id == "ag1" + assert ag.access_model_names == ["gpt-4"] + assert ag.assigned_team_ids == ["t1"] + assert ag.access_agent_ids == [] + + +class TestTagTable: + def test_tag_creation(self): + tag = LiteLLM_TagTable( + tag_name="prod", + models=["gpt-4"], + spend=12.5, + budget_id="b1", + ) + assert tag.tag_name == "prod" + assert tag.models == ["gpt-4"] + assert tag.spend == 12.5 + + def test_tag_set_model_info_coerces_none(self): + tag = LiteLLM_TagTable(tag_name="t", spend=None, models=None) + assert tag.spend == 0.0 + assert tag.models == [] + + +class TestEndUserTable: + def test_end_user_creation(self): + eu = LiteLLM_EndUserTable( + user_id="eu1", + blocked=False, + spend=5.0, + allowed_model_region="eu", + default_model="gpt-4", + ) + assert eu.user_id == "eu1" + assert eu.blocked is False + assert eu.allowed_model_region == "eu" + assert eu.default_model == "gpt-4" + + def test_end_user_spend_coerced_when_none(self): + eu = LiteLLM_EndUserTable(user_id="eu2", blocked=True, spend=None) + assert eu.spend == 0.0 + + +class TestBudgetTableFull: + def test_full_adds_server_managed_fields(self): + now = datetime.now() + budget = LiteLLM_BudgetTableFull( + budget_id="b1", max_budget=10.0, created_at=now, budget_reset_at=now + ) + assert budget.created_at == now + assert budget.budget_reset_at == now + assert budget.max_budget == 10.0 + + def test_full_requires_created_at(self): + with pytest.raises(Exception): + LiteLLM_BudgetTableFull(budget_id="b1") + + +class TestTeamMemberTable: + def test_tracks_user_within_team(self): + member = LiteLLM_TeamMemberTable( + user_id="u1", team_id="t1", spend=3.0, budget_id="b1", max_budget=5.0 + ) + assert member.user_id == "u1" + assert member.team_id == "t1" + assert member.spend == 3.0 + assert member.max_budget == 5.0 + + +class TestTeamMembership: + def 
test_safe_get_limits_with_budget_table(self): + membership = LiteLLM_TeamMembership( + user_id="u1", + team_id="t1", + litellm_budget_table=LiteLLM_BudgetTable(rpm_limit=100, tpm_limit=2000), + ) + assert membership.safe_get_team_member_rpm_limit() == 100 + assert membership.safe_get_team_member_tpm_limit() == 2000 + + def test_safe_get_limits_without_budget_table(self): + membership = LiteLLM_TeamMembership(user_id="u1", team_id="t1") + assert membership.safe_get_team_member_rpm_limit() is None + assert membership.safe_get_team_member_tpm_limit() is None + + def test_full_budget_variant_parsed_for_server_fields(self): + now = datetime.now() + membership = LiteLLM_TeamMembership( + user_id="u1", + team_id="t1", + litellm_budget_table={ + "budget_id": "b1", + "rpm_limit": 7, + "created_at": now, + "budget_reset_at": now, + }, + ) + assert isinstance(membership.litellm_budget_table, LiteLLM_BudgetTableFull) + assert membership.safe_get_team_member_rpm_limit() == 7 + + +class TestMCPServerTable: + def test_mcp_server_defaults(self): + server = LiteLLM_MCPServerTable(server_id="s1", transport="sse") + assert server.server_id == "s1" + assert server.transport == "sse" + assert server.status == "unknown" + assert server.approval_status == "active" + assert server.allow_all_keys is False + assert server.available_on_public_internet is True + assert server.teams == [] + assert server.env == {} + + def test_mcp_server_requires_transport(self): + with pytest.raises(Exception): + LiteLLM_MCPServerTable(server_id="s1") + + +class TestSpendLogs: + def test_spend_logs_creation(self): + log = LiteLLM_SpendLogs( + request_id="r1", + api_key="sk-1", + call_type="completion", + startTime=None, + endTime=None, + messages=None, + response=None, + ) + assert log.request_id == "r1" + assert log.spend == 0.0 + assert log.cache_hit == "False" + + def test_error_logs_creation(self): + log = LiteLLM_ErrorLogs( + request_id="r1", startTime=None, endTime=None, status_code="500" + ) + assert 
log.request_id == "r1" + assert log.status_code == "500" + + +class TestManagedTables: + def test_managed_file_table(self): + table = LiteLLM_ManagedFileTable( + unified_file_id="f1", + model_mappings={"gpt-4": "file-abc"}, + flat_model_file_ids=["file-abc"], + ) + assert table.unified_file_id == "f1" + assert table.model_mappings == {"gpt-4": "file-abc"} + assert table.flat_model_file_ids == ["file-abc"] + + def test_managed_object_table_requires_purpose(self): + with pytest.raises(Exception): + LiteLLM_ManagedObjectTable( + unified_object_id="o1", model_object_id="m1", file_object={} + ) + + def test_managed_vector_stores_table(self): + table = LiteLLM_ManagedVectorStoresTable( + vector_store_id="vs1", + custom_llm_provider="openai", + vector_store_name=None, + vector_store_description=None, + vector_store_metadata=None, + created_at=None, + updated_at=None, + litellm_credential_name=None, + litellm_params=None, + team_id=None, + user_id=None, + ) + assert table.vector_store_id == "vs1" + assert table.custom_llm_provider == "openai" diff --git a/tests/test_litellm/repositories/test_repositories.py b/tests/test_litellm/repositories/test_repositories.py new file mode 100644 index 0000000000..f22debbae3 --- /dev/null +++ b/tests/test_litellm/repositories/test_repositories.py @@ -0,0 +1,2184 @@ +""" +Tests for gateway repository layer. 
+""" + +import json +from datetime import datetime +from typing import Any, Dict, List, Optional +from unittest.mock import MagicMock, patch + +import pytest + +from litellm.models.base import DomainModel +from litellm.models.budget import LiteLLM_BudgetTable +from litellm.models.credentials import CredentialItem +from litellm.models.team import LiteLLM_TeamTable +from litellm.repositories.base_repository import BaseRepository +from litellm.repositories.budget_repository import BudgetRepository +from litellm.repositories.config_repository import ConfigRepository +from litellm.repositories.credentials_repository import CredentialsRepository +from litellm.repositories.model_repository import ModelRepository +from litellm.repositories.object_permission_repository import ( + ObjectPermissionRepository, +) +from litellm.repositories.organization_repository import OrganizationRepository +from litellm.repositories.project_repository import ProjectRepository +from litellm.repositories.team_repository import TeamRepository +from litellm.repositories.user_repository import UserRepository +from litellm.repositories.verification_token_repository import ( + VerificationTokenRepository, +) + + +class MockRecord: + """Mock database record for testing.""" + + def __init__(self, data: Dict[str, Any]): + self._data = data if data is not None else {} + + def dict(self) -> Dict[str, Any]: + return self._data.copy() + + def model_dump(self) -> Dict[str, Any]: + return self._data.copy() + + def __getattr__(self, name: str) -> Any: + if name.startswith("_"): + raise AttributeError(name) + return self._data.get(name) + + +class MockTable: + """Mock Prisma table for testing.""" + + def __init__(self, pk_field: Optional[str] = None): + self._records: Dict[str, Dict[str, Any]] = {} + self._pk_field = pk_field + + async def find_unique(self, where: Dict[str, Any]) -> Optional[MockRecord]: + key_field = list(where.keys())[0] + key_value = where[key_field] + data = self._records.get(key_value) 
+ return MockRecord(data) if data else None + + async def find_many( + self, + where: Optional[Dict[str, Any]] = None, + skip: Optional[int] = None, + take: Optional[int] = None, + order: Optional[Dict[str, str]] = None, + ) -> List[MockRecord]: + records = list(self._records.values()) + return [MockRecord(r) for r in records] + + async def create(self, data: Dict[str, Any]) -> MockRecord: + record_data = dict(data) + if self._pk_field and self._pk_field not in record_data: + record_data[self._pk_field] = f"{self._pk_field}-{len(self._records)}" + key = ( + record_data.get(self._pk_field) + if self._pk_field + else record_data.get("id", str(len(self._records))) + ) + self._records[key] = record_data + return MockRecord(record_data) + + async def update( + self, where: Dict[str, Any], data: Dict[str, Any] + ) -> Optional[MockRecord]: + key_field = list(where.keys())[0] + key_value = where[key_field] + if key_value in self._records: + for field, value in data.items(): + if isinstance(value, dict) and "push" in value: + current = self._records[key_value].get(field, []) + push_val = value["push"] + if isinstance(push_val, list): + current.extend(push_val) + else: + current.append(push_val) + self._records[key_value][field] = current + else: + self._records[key_value][field] = value + return MockRecord(self._records[key_value]) + return None + + async def delete(self, where: Dict[str, Any]) -> Optional[MockRecord]: + key_field = list(where.keys())[0] + key_value = where[key_field] + data = self._records.pop(key_value, None) + return MockRecord(data) if data else None + + async def count(self, where: Optional[Dict[str, Any]] = None) -> int: + return len(self._records) + + async def upsert(self, where: Dict[str, Any], data: Dict[str, Any]) -> MockRecord: + key_field = list(where.keys())[0] + key_value = where[key_field] + if key_value in self._records: + self._records[key_value].update(data.get("update", {})) + else: + self._records[key_value] = data.get("create", {}) + 
return MockRecord(self._records[key_value]) + + +class MockPrismaClient: + """Mock Prisma client for testing.""" + + def __init__(self): + self.db = MagicMock() + self.db.litellm_budgettable = MockTable() + self.db.litellm_proxymodeltable = MockTable(pk_field="model_id") + self.db.litellm_teamtable = MockTable() + self.db.litellm_deletedteamtable = MockTable() + self.db.litellm_usertable = MockTable() + self.db.litellm_verificationtoken = MockTable() + self.db.litellm_deletedverificationtoken = MockTable() + self.db.litellm_config = MockTable() + self.db.litellm_organizationtable = MockTable() + self.db.litellm_projecttable = MockTable(pk_field="project_id") + self.db.litellm_objectpermissiontable = MockTable( + pk_field="object_permission_id" + ) + self.db.litellm_credentialstable = MockTable() + + +class TestBaseRepository: + @pytest.fixture + def prisma_client(self): + return MockPrismaClient() + + def test_prisma_client_none_raises(self): + class TestRepo(BaseRepository[LiteLLM_BudgetTable]): + @property + def table(self): + return None + + @property + def model_class(self): + return LiteLLM_BudgetTable + + repo = TestRepo(None) + with pytest.raises(RuntimeError, match="No DB Connected"): + _ = repo.prisma_client + + @pytest.mark.asyncio + async def test_find_many(self, prisma_client): + repo = BudgetRepository(prisma_client) + prisma_client.db.litellm_budgettable._records = { + "b1": {"budget_id": "b1", "max_budget": 100.0}, + "b2": {"budget_id": "b2", "max_budget": 200.0}, + } + budgets = await repo.find_many() + assert len(budgets) == 2 + + @pytest.mark.asyncio + async def test_count(self, prisma_client): + repo = BudgetRepository(prisma_client) + prisma_client.db.litellm_budgettable._records = { + "b1": {"budget_id": "b1"}, + "b2": {"budget_id": "b2"}, + } + count = await repo.count() + assert count == 2 + + @pytest.mark.asyncio + async def test_exists(self, prisma_client): + repo = BudgetRepository(prisma_client) + 
prisma_client.db.litellm_budgettable._records = { + "b1": {"budget_id": "b1"}, + } + assert await repo.exists("b1", id_field="budget_id") + assert not await repo.exists("nonexistent", id_field="budget_id") + + @pytest.mark.asyncio + async def test_find_many_with_all_kwargs(self, prisma_client): + repo = BudgetRepository(prisma_client) + prisma_client.db.litellm_budgettable._records = { + "b1": {"budget_id": "b1", "max_budget": 100.0}, + } + budgets = await repo.find_many( + where={"budget_id": "b1"}, skip=0, take=10, order={"budget_id": "asc"} + ) + assert len(budgets) == 1 + + def test_record_to_dict_branches(self): + from litellm.repositories.base_repository import _record_to_dict + + assert _record_to_dict({"a": 1}) == {"a": 1} + + class WithModelDump: + def model_dump(self): + return {"src": "model_dump"} + + assert _record_to_dict(WithModelDump()) == {"src": "model_dump"} + + class WithDict: + def dict(self): + return {"src": "dict"} + + assert _record_to_dict(WithDict()) == {"src": "dict"} + + assert _record_to_dict([("k", "v")]) == {"k": "v"} + + +class TestBudgetRepository: + @pytest.fixture + def repo(self): + client = MockPrismaClient() + return BudgetRepository(client) + + @pytest.mark.asyncio + async def test_create_budget(self, repo): + budget = await repo.create_budget( + created_by="test-user", + max_budget=100.0, + soft_budget=80.0, + tpm_limit=1000, + ) + assert budget.max_budget == 100.0 + assert budget.soft_budget == 80.0 + assert budget.tpm_limit == 1000 + + @pytest.mark.asyncio + async def test_create_budget_all_fields(self, repo): + budget = await repo.create_budget( + created_by="test-user", + max_budget=100.0, + soft_budget=80.0, + max_parallel_requests=10, + tpm_limit=1000, + rpm_limit=100, + model_max_budget={"gpt-4": 50.0}, + budget_duration="monthly", + allowed_models=["gpt-4", "gpt-3.5-turbo"], + ) + assert budget.max_budget == 100.0 + assert budget.max_parallel_requests == 10 + + @pytest.mark.asyncio + async def 
test_update_budget(self, repo): + await repo.create_budget(created_by="test-user", max_budget=100.0) + repo._prisma_client.db.litellm_budgettable._records["budget-1"] = { + "budget_id": "budget-1", + "max_budget": 100.0, + } + + updated = await repo.update_budget( + budget_id="budget-1", + updated_by="test-user", + max_budget=200.0, + ) + assert updated.max_budget == 200.0 + + @pytest.mark.asyncio + async def test_delete_budget(self, repo): + repo._prisma_client.db.litellm_budgettable._records["budget-1"] = { + "budget_id": "budget-1", + "max_budget": 100.0, + } + deleted = await repo.delete_budget("budget-1") + assert deleted is not None + assert "budget-1" not in repo._prisma_client.db.litellm_budgettable._records + + @pytest.mark.asyncio + async def test_find_by_id(self, repo): + repo._prisma_client.db.litellm_budgettable._records["budget-1"] = { + "budget_id": "budget-1", + "max_budget": 100.0, + } + budget = await repo.find_by_id("budget-1") + assert budget is not None + assert budget.budget_id == "budget-1" + + +class TestModelRepository: + @pytest.fixture + def repo(self): + client = MockPrismaClient() + return ModelRepository(client) + + @pytest.mark.asyncio + @patch( + "litellm.repositories.model_repository.encrypt_value_helper", + side_effect=lambda v, **kw: f"encrypted_{v}", + ) + @patch( + "litellm.repositories.model_repository.decrypt_value_helper", + side_effect=lambda v, **kw: v, + ) + async def test_create_model_encrypts_params(self, mock_decrypt, mock_encrypt, repo): + model = await repo.create_model( + model_name="gpt-4", + litellm_params={"api_key": "sk-secret"}, + created_by="test-user", + ) + assert model is not None + mock_encrypt.assert_called() + + @pytest.mark.asyncio + @patch( + "litellm.repositories.model_repository.encrypt_value_helper", + side_effect=lambda v, **kw: f"encrypted_{v}", + ) + @patch( + "litellm.repositories.model_repository.decrypt_value_helper", + side_effect=lambda v, **kw: v, + ) + async def 
test_create_model_all_fields(self, mock_decrypt, mock_encrypt, repo): + model = await repo.create_model( + model_name="gpt-4-turbo", + litellm_params={ + "api_key": "sk-secret", + "api_base": "https://api.openai.com", + }, + created_by="admin", + model_id="custom-model-id", + model_info={"team_id": "team-1", "description": "GPT-4 Turbo model"}, + blocked=True, + ) + assert model is not None + assert model.model_name == "gpt-4-turbo" + + @pytest.mark.asyncio + @patch( + "litellm.repositories.model_repository.encrypt_value_helper", + side_effect=lambda v, **kw: f"encrypted_{v}", + ) + @patch( + "litellm.repositories.model_repository.decrypt_value_helper", + side_effect=lambda v, **kw: v, + ) + async def test_update_model_all_fields(self, mock_decrypt, mock_encrypt, repo): + repo._prisma_client.db.litellm_proxymodeltable._records["model-full"] = { + "model_id": "model-full", + "model_name": "old-name", + "litellm_params": '{"api_key": "old"}', + "blocked": False, + } + updated = await repo.update_model( + model_id="model-full", + updated_by="admin", + model_name="new-name", + litellm_params={"api_key": "new-key"}, + model_info={"updated": True}, + blocked=True, + ) + assert updated.model_name == "new-name" + + @pytest.mark.asyncio + @patch( + "litellm.repositories.model_repository.decrypt_value_helper", + side_effect=lambda v, **kw: v, + ) + async def test_find_all(self, mock_decrypt, repo): + repo._prisma_client.db.litellm_proxymodeltable._records = { + "m1": { + "model_id": "m1", + "model_name": "gpt-4", + "litellm_params": '{"model": "gpt-4"}', + "blocked": False, + }, + "m2": { + "model_id": "m2", + "model_name": "claude-3", + "litellm_params": '{"model": "claude-3"}', + "blocked": False, + }, + } + models = await repo.find_all() + assert len(models) == 2 + + @pytest.mark.asyncio + @patch( + "litellm.repositories.model_repository.decrypt_value_helper", + side_effect=lambda v, **kw: v, + ) + async def test_find_unblocked(self, mock_decrypt, repo): + 
repo._prisma_client.db.litellm_proxymodeltable._records = { + "m1": { + "model_id": "m1", + "model_name": "gpt-4", + "litellm_params": '{"model": "gpt-4"}', + "blocked": False, + }, + } + models = await repo.find_unblocked() + assert len(models) == 1 + + @pytest.mark.asyncio + @patch( + "litellm.repositories.model_repository.decrypt_value_helper", + side_effect=lambda v, **kw: v, + ) + async def test_find_by_name(self, mock_decrypt, repo): + repo._prisma_client.db.litellm_proxymodeltable._records = { + "m1": { + "model_id": "m1", + "model_name": "gpt-4", + "litellm_params": '{"model": "gpt-4"}', + }, + } + models = await repo.find_by_name("gpt-4") + assert len(models) == 1 + + @pytest.mark.asyncio + @patch( + "litellm.repositories.model_repository.encrypt_value_helper", + side_effect=lambda v, **kw: v, + ) + @patch( + "litellm.repositories.model_repository.decrypt_value_helper", + side_effect=lambda v, **kw: v, + ) + async def test_update_model(self, mock_decrypt, mock_encrypt, repo): + repo._prisma_client.db.litellm_proxymodeltable._records["m1"] = { + "model_id": "m1", + "model_name": "gpt-4", + "litellm_params": '{"model": "gpt-4"}', + "blocked": False, + } + updated = await repo.update_model( + model_id="m1", + updated_by="test-user", + blocked=True, + ) + assert updated.blocked is True + + @pytest.mark.asyncio + @patch( + "litellm.repositories.model_repository.decrypt_value_helper", + side_effect=lambda v, **kw: v, + ) + async def test_delete_model(self, mock_decrypt, repo): + repo._prisma_client.db.litellm_proxymodeltable._records["m1"] = { + "model_id": "m1", + "model_name": "gpt-4", + "litellm_params": '{"model": "gpt-4"}', + } + deleted = await repo.delete_model("m1") + assert deleted is not None + + @pytest.mark.asyncio + @patch( + "litellm.repositories.model_repository.encrypt_value_helper", + side_effect=lambda v, **kw: v, + ) + @patch( + "litellm.repositories.model_repository.decrypt_value_helper", + side_effect=lambda v, **kw: v, + ) + async def 
test_block_unblock_model(self, mock_decrypt, mock_encrypt, repo): + repo._prisma_client.db.litellm_proxymodeltable._records["m1"] = { + "model_id": "m1", + "model_name": "gpt-4", + "litellm_params": '{"model": "gpt-4"}', + "blocked": False, + } + blocked = await repo.block_model("m1", "admin") + assert blocked.blocked is True + + unblocked = await repo.unblock_model("m1", "admin") + assert unblocked.blocked is False + + +class TestTeamRepository: + @pytest.fixture + def repo(self): + client = MockPrismaClient() + return TeamRepository(client) + + @pytest.mark.asyncio + async def test_create_team(self, repo): + team = await repo.create_team( + team_id="team-123", + team_alias="Engineering", + admins=["user1"], + members=["user2", "user3"], + ) + assert team.team_id == "team-123" + assert team.team_alias == "Engineering" + + @pytest.mark.asyncio + async def test_create_team_all_fields(self, repo): + team = await repo.create_team( + team_id="team-123", + team_alias="Engineering", + organization_id="org-1", + admins=["admin1"], + members=["user1"], + members_with_roles=[{"user_id": "user1", "role": "user"}], + metadata={"dept": "engineering"}, + max_budget=1000.0, + soft_budget=800.0, + models=["gpt-4"], + max_parallel_requests=10, + tpm_limit=50000, + rpm_limit=500, + budget_duration="monthly", + object_permission_id="perm-1", + ) + assert team.team_id == "team-123" + assert team.organization_id == "org-1" + + @pytest.mark.asyncio + async def test_update_team(self, repo): + repo._prisma_client.db.litellm_teamtable._records["team-1"] = { + "team_id": "team-1", + "team_alias": "Test", + "admins": [], + "members": [], + "models": [], + } + updated = await repo.update_team( + team_id="team-1", + team_alias="Updated Team", + blocked=True, + ) + assert updated.team_alias == "Updated Team" + + @pytest.mark.asyncio + async def test_update_team_all_fields(self, repo): + repo._prisma_client.db.litellm_teamtable._records["team-full"] = { + "team_id": "team-full", + "team_alias": 
"Test", + "admins": [], + "members": [], + "models": [], + } + updated = await repo.update_team( + team_id="team-full", + team_alias="Fully Updated", + organization_id="org-new", + admins=["admin1"], + members=["member1"], + members_with_roles=[{"user_id": "user1", "role": "admin"}], + metadata={"updated": True}, + max_budget=500.0, + soft_budget=400.0, + models=["gpt-4", "claude-3"], + max_parallel_requests=20, + tpm_limit=100000, + rpm_limit=1000, + budget_duration="weekly", + blocked=False, + object_permission_id="perm-new", + ) + assert updated.team_alias == "Fully Updated" + + @pytest.mark.asyncio + async def test_add_member(self, repo): + repo._prisma_client.db.litellm_teamtable._records["team-1"] = { + "team_id": "team-1", + "team_alias": "Test", + "admins": [], + "members": ["user1"], + "models": [], + } + + team = await repo.add_member("team-1", "user2") + assert "user2" in team.members + + @pytest.mark.asyncio + async def test_add_member_nonexistent_team(self, repo): + result = await repo.add_member("nonexistent", "user1") + assert result is None + + @pytest.mark.asyncio + async def test_remove_member(self, repo): + repo._prisma_client.db.litellm_teamtable._records["team-1"] = { + "team_id": "team-1", + "team_alias": "Test", + "admins": [], + "members": ["user1", "user2"], + "models": [], + } + + team = await repo.remove_member("team-1", "user2") + assert "user2" not in team.members + + @pytest.mark.asyncio + async def test_add_admin(self, repo): + repo._prisma_client.db.litellm_teamtable._records["team-1"] = { + "team_id": "team-1", + "team_alias": "Test", + "admins": [], + "members": [], + "models": [], + } + team = await repo.add_admin("team-1", "admin1") + assert "admin1" in team.admins + + @pytest.mark.asyncio + async def test_remove_admin(self, repo): + repo._prisma_client.db.litellm_teamtable._records["team-1"] = { + "team_id": "team-1", + "team_alias": "Test", + "admins": ["admin1", "admin2"], + "members": [], + "models": [], + } + team = await 
repo.remove_admin("team-1", "admin2") + assert "admin2" not in team.admins + + @pytest.mark.asyncio + async def test_add_models(self, repo): + repo._prisma_client.db.litellm_teamtable._records["team-1"] = { + "team_id": "team-1", + "team_alias": "Test", + "admins": [], + "members": [], + "models": ["gpt-3.5-turbo"], + } + team = await repo.add_models("team-1", ["gpt-4"]) + assert "gpt-4" in team.models + + @pytest.mark.asyncio + async def test_remove_models(self, repo): + repo._prisma_client.db.litellm_teamtable._records["team-1"] = { + "team_id": "team-1", + "team_alias": "Test", + "admins": [], + "members": [], + "models": ["gpt-3.5-turbo", "gpt-4"], + } + team = await repo.remove_models("team-1", ["gpt-4"]) + assert "gpt-4" not in team.models + + @pytest.mark.asyncio + async def test_update_spend(self, repo): + repo._prisma_client.db.litellm_teamtable._records["team-1"] = { + "team_id": "team-1", + "team_alias": "Test", + "admins": [], + "members": [], + "models": [], + "spend": 0.0, + } + team = await repo.update_spend("team-1", 50.0) + assert team.spend == 50.0 + + @pytest.mark.asyncio + async def test_find_by_alias(self, repo): + repo._prisma_client.db.litellm_teamtable._records["team-1"] = { + "team_id": "team-1", + "team_alias": "Engineering", + "admins": [], + "members": [], + "models": [], + } + team = await repo.find_by_alias("Engineering") + assert team is not None + assert team.team_id == "team-1" + + @pytest.mark.asyncio + async def test_find_by_organization_id(self, repo): + repo._prisma_client.db.litellm_teamtable._records["team-1"] = { + "team_id": "team-1", + "organization_id": "org-1", + "admins": [], + "members": [], + "models": [], + } + teams = await repo.find_by_organization_id("org-1") + assert len(teams) == 1 + + @pytest.mark.asyncio + async def test_find_by_member(self, repo): + repo._prisma_client.db.litellm_teamtable._records["team-1"] = { + "team_id": "team-1", + "admins": [], + "members": ["user1"], + "models": [], + } + teams = await 
repo.find_by_member("user1") + assert len(teams) == 1 + + @pytest.mark.asyncio + async def test_find_by_admin(self, repo): + repo._prisma_client.db.litellm_teamtable._records["team-1"] = { + "team_id": "team-1", + "admins": ["admin1"], + "members": [], + "models": [], + } + teams = await repo.find_by_admin("admin1") + assert len(teams) == 1 + + +class TestUserRepository: + @pytest.fixture + def repo(self): + client = MockPrismaClient() + return UserRepository(client) + + @pytest.mark.asyncio + async def test_create_user(self, repo): + user = await repo.create_user( + user_id="user-123", + user_email="test@example.com", + teams=["team1"], + ) + assert user.user_id == "user-123" + + @pytest.mark.asyncio + async def test_create_user_all_fields(self, repo): + user = await repo.create_user( + user_id="user-123", + user_alias="testuser", + team_id="team-1", + sso_user_id="sso-123", + organization_id="org-1", + password="hashed_password", + teams=["team1", "team2"], + user_role="admin", + max_budget=500.0, + user_email="test@example.com", + models=["gpt-4"], + metadata={"department": "engineering"}, + max_parallel_requests=5, + tpm_limit=10000, + rpm_limit=100, + budget_duration="monthly", + allowed_cache_controls=["no-cache"], + policies=["policy-1"], + object_permission_id="perm-1", + ) + assert user.user_id == "user-123" + assert user.user_alias == "testuser" + + @pytest.mark.asyncio + async def test_update_user(self, repo): + repo._prisma_client.db.litellm_usertable._records["user-1"] = { + "user_id": "user-1", + "teams": [], + "models": [], + } + updated = await repo.update_user( + user_id="user-1", + user_email="updated@example.com", + ) + assert updated.user_email == "updated@example.com" + + @pytest.mark.asyncio + async def test_delete_user(self, repo): + repo._prisma_client.db.litellm_usertable._records["user-1"] = { + "user_id": "user-1", + "teams": [], + "models": [], + } + deleted = await repo.delete_user("user-1") + assert deleted is not None + + 
@pytest.mark.asyncio + async def test_add_to_team(self, repo): + repo._prisma_client.db.litellm_usertable._records["user-1"] = { + "user_id": "user-1", + "teams": ["team1"], + "models": [], + } + + user = await repo.add_to_team("user-1", "team2") + assert "team2" in user.teams + + @pytest.mark.asyncio + async def test_add_to_team_nonexistent_user(self, repo): + result = await repo.add_to_team("nonexistent", "team1") + assert result is None + + @pytest.mark.asyncio + async def test_remove_from_team(self, repo): + repo._prisma_client.db.litellm_usertable._records["user-1"] = { + "user_id": "user-1", + "teams": ["team1", "team2"], + "models": [], + } + user = await repo.remove_from_team("user-1", "team2") + assert "team2" not in user.teams + + @pytest.mark.asyncio + async def test_update_spend(self, repo): + repo._prisma_client.db.litellm_usertable._records["user-1"] = { + "user_id": "user-1", + "teams": [], + "models": [], + "spend": 0.0, + } + user = await repo.update_spend("user-1", 25.0) + assert user.spend == 25.0 + + @pytest.mark.asyncio + async def test_find_by_email(self, repo): + repo._prisma_client.db.litellm_usertable._records["user-1"] = { + "user_id": "user-1", + "user_email": "test@example.com", + "teams": [], + "models": [], + } + user = await repo.find_by_email("test@example.com") + assert user is not None + + @pytest.mark.asyncio + async def test_find_by_sso_id(self, repo): + repo._prisma_client.db.litellm_usertable._records["sso-123"] = { + "user_id": "user-1", + "sso_user_id": "sso-123", + "teams": [], + "models": [], + } + user = await repo.find_by_sso_id("sso-123") + assert user is not None + + @pytest.mark.asyncio + async def test_find_by_organization_id(self, repo): + repo._prisma_client.db.litellm_usertable._records["user-1"] = { + "user_id": "user-1", + "organization_id": "org-1", + "teams": [], + "models": [], + } + users = await repo.find_by_organization_id("org-1") + assert len(users) == 1 + + @pytest.mark.asyncio + async def 
test_find_by_team_id(self, repo): + repo._prisma_client.db.litellm_usertable._records["user-1"] = { + "user_id": "user-1", + "teams": ["team-1"], + "models": [], + } + users = await repo.find_by_team_id("team-1") + assert len(users) == 1 + + +class TestVerificationTokenRepository: + @pytest.fixture + def repo(self): + client = MockPrismaClient() + return VerificationTokenRepository(client) + + @pytest.mark.asyncio + async def test_create_token(self, repo): + token = await repo.create_token( + token="sk-test123", + key_name="Test Key", + user_id="user-123", + max_budget=100.0, + ) + assert token.token == "sk-test123" + assert token.key_name == "Test Key" + + @pytest.mark.asyncio + async def test_create_token_all_fields(self, repo): + token = await repo.create_token( + token="sk-test123", + key_name="Test Key", + key_alias="test-alias", + max_budget=100.0, + expires=datetime(2025, 12, 31), + models=["gpt-4"], + aliases={"alias1": "value1"}, + config={"setting": "value"}, + user_id="user-123", + team_id="team-1", + agent_id="agent-1", + project_id="project-1", + max_parallel_requests=5, + metadata={"key": "value"}, + tpm_limit=10000, + rpm_limit=100, + budget_duration="monthly", + allowed_cache_controls=["no-cache"], + allowed_routes=["/v1/completions"], + permissions={"read": True}, + org_id="org-1", + created_by="admin", + object_permission_id="perm-1", + access_group_ids=["group-1"], + budget_id="budget-1", + ) + assert token.token == "sk-test123" + + @pytest.mark.asyncio + async def test_update_token(self, repo): + repo._prisma_client.db.litellm_verificationtoken._records["sk-test"] = { + "token": "sk-test", + "blocked": False, + } + updated = await repo.update_token( + token="sk-test", + key_name="Updated Key", + ) + assert updated.key_name == "Updated Key" + + @pytest.mark.asyncio + async def test_block_token(self, repo): + repo._prisma_client.db.litellm_verificationtoken._records["sk-test"] = { + "token": "sk-test", + "blocked": False, + } + + token = await 
repo.block_token("sk-test", updated_by="admin") + assert token.blocked is True + + @pytest.mark.asyncio + async def test_unblock_token(self, repo): + repo._prisma_client.db.litellm_verificationtoken._records["sk-test"] = { + "token": "sk-test", + "blocked": True, + } + token = await repo.unblock_token("sk-test", updated_by="admin") + assert token.blocked is False + + @pytest.mark.asyncio + async def test_update_spend(self, repo): + repo._prisma_client.db.litellm_verificationtoken._records["sk-test"] = { + "token": "sk-test", + "spend": 0.0, + } + token = await repo.update_spend("sk-test", 15.0) + assert token.spend == 15.0 + + @pytest.mark.asyncio + async def test_update_last_active(self, repo): + repo._prisma_client.db.litellm_verificationtoken._records["sk-test"] = { + "token": "sk-test", + } + token = await repo.update_last_active("sk-test") + assert token.last_active is not None + + @pytest.mark.asyncio + async def test_find_by_alias(self, repo): + repo._prisma_client.db.litellm_verificationtoken._records["sk-test"] = { + "token": "sk-test", + "key_alias": "my-key", + } + token = await repo.find_by_alias("my-key") + assert token is not None + + @pytest.mark.asyncio + async def test_find_by_user_id(self, repo): + repo._prisma_client.db.litellm_verificationtoken._records["sk-test"] = { + "token": "sk-test", + "user_id": "user-1", + } + tokens = await repo.find_by_user_id("user-1") + assert len(tokens) == 1 + + @pytest.mark.asyncio + async def test_find_by_team_id(self, repo): + repo._prisma_client.db.litellm_verificationtoken._records["sk-test"] = { + "token": "sk-test", + "team_id": "team-1", + } + tokens = await repo.find_by_team_id("team-1") + assert len(tokens) == 1 + + @pytest.mark.asyncio + async def test_find_by_project_id(self, repo): + repo._prisma_client.db.litellm_verificationtoken._records["sk-test"] = { + "token": "sk-test", + "project_id": "project-1", + } + tokens = await repo.find_by_project_id("project-1") + assert len(tokens) == 1 + + +class 
TestOrganizationRepository: + @pytest.fixture + def repo(self): + client = MockPrismaClient() + return OrganizationRepository(client) + + @pytest.mark.asyncio + async def test_create_organization(self, repo): + org = await repo.create_organization( + organization_alias="Acme Corp", + budget_id="budget-1", + created_by="admin", + ) + assert org.organization_alias == "Acme Corp" + + @pytest.mark.asyncio + async def test_create_organization_all_fields(self, repo): + org = await repo.create_organization( + organization_alias="Acme Corp", + budget_id="budget-1", + created_by="admin", + organization_id="org-123", + metadata={"industry": "tech"}, + models=["gpt-4"], + object_permission_id="perm-1", + ) + assert org.organization_alias == "Acme Corp" + + @pytest.mark.asyncio + async def test_update_organization(self, repo): + repo._prisma_client.db.litellm_organizationtable._records["org-1"] = { + "organization_id": "org-1", + "organization_alias": "Old Name", + "budget_id": "b1", + "created_by": "admin", + "updated_by": "admin", + } + updated = await repo.update_organization( + organization_id="org-1", + updated_by="admin", + organization_alias="New Name", + ) + assert updated.organization_alias == "New Name" + + @pytest.mark.asyncio + async def test_update_organization_all_fields(self, repo): + repo._prisma_client.db.litellm_organizationtable._records["org-full"] = { + "organization_id": "org-full", + "organization_alias": "Old Name", + "budget_id": "b1", + "created_by": "admin", + "updated_by": "admin", + } + updated = await repo.update_organization( + organization_id="org-full", + updated_by="admin", + organization_alias="Fully Updated", + budget_id="budget-new", + metadata={"updated": True}, + models=["gpt-4", "claude-3"], + object_permission_id="perm-new", + ) + assert updated.organization_alias == "Fully Updated" + + @pytest.mark.asyncio + async def test_delete_organization(self, repo): + repo._prisma_client.db.litellm_organizationtable._records["org-1"] = { + 
"organization_id": "org-1", + "organization_alias": "Acme", + "budget_id": "b1", + "created_by": "admin", + "updated_by": "admin", + } + deleted = await repo.delete_organization("org-1") + assert deleted is not None + + @pytest.mark.asyncio + async def test_update_spend(self, repo): + repo._prisma_client.db.litellm_organizationtable._records["org-1"] = { + "organization_id": "org-1", + "organization_alias": "Acme", + "spend": 0.0, + "budget_id": "b1", + "created_by": "admin", + "updated_by": "admin", + } + org = await repo.update_spend("org-1", 100.0) + assert org.spend == 100.0 + + @pytest.mark.asyncio + async def test_find_by_alias(self, repo): + repo._prisma_client.db.litellm_organizationtable._records["org-1"] = { + "organization_id": "org-1", + "organization_alias": "Acme", + "budget_id": "b1", + "created_by": "admin", + "updated_by": "admin", + } + org = await repo.find_by_alias("Acme") + assert org is not None + + +class TestProjectRepository: + @pytest.fixture + def repo(self): + client = MockPrismaClient() + return ProjectRepository(client) + + @pytest.mark.asyncio + async def test_create_project(self, repo): + project = await repo.create_project( + created_by="admin", + project_alias="My Project", + ) + assert project.project_alias == "My Project" + + @pytest.mark.asyncio + async def test_create_project_all_fields(self, repo): + project = await repo.create_project( + created_by="admin", + project_id="proj-123", + project_alias="My Project", + description="A test project", + team_id="team-1", + budget_id="budget-1", + metadata={"env": "dev"}, + models=["gpt-4"], + model_rpm_limit={"gpt-4": 100}, + model_tpm_limit={"gpt-4": 10000}, + object_permission_id="perm-1", + ) + assert project.project_alias == "My Project" + + @pytest.mark.asyncio + async def test_update_project(self, repo): + repo._prisma_client.db.litellm_projecttable._records["proj-1"] = { + "project_id": "proj-1", + "project_alias": "Old Name", + } + updated = await repo.update_project( + 
project_id="proj-1", + updated_by="admin", + project_alias="New Name", + blocked=True, + ) + assert updated.project_alias == "New Name" + + @pytest.mark.asyncio + async def test_update_project_all_fields(self, repo): + repo._prisma_client.db.litellm_projecttable._records["proj-full"] = { + "project_id": "proj-full", + "project_alias": "Old Name", + } + updated = await repo.update_project( + project_id="proj-full", + updated_by="admin", + project_alias="Fully Updated", + description="New description", + team_id="team-new", + budget_id="budget-new", + metadata={"updated": True}, + models=["gpt-4", "claude-3"], + model_rpm_limit={"gpt-4": 200}, + model_tpm_limit={"gpt-4": 20000}, + blocked=False, + object_permission_id="perm-new", + ) + assert updated.project_alias == "Fully Updated" + + @pytest.mark.asyncio + async def test_delete_project(self, repo): + repo._prisma_client.db.litellm_projecttable._records["proj-1"] = { + "project_id": "proj-1", + } + deleted = await repo.delete_project("proj-1") + assert deleted is not None + + @pytest.mark.asyncio + async def test_update_spend(self, repo): + repo._prisma_client.db.litellm_projecttable._records["proj-1"] = { + "project_id": "proj-1", + "spend": 0.0, + } + project = await repo.update_spend("proj-1", 50.0) + assert project.spend == 50.0 + + @pytest.mark.asyncio + async def test_find_by_alias(self, repo): + repo._prisma_client.db.litellm_projecttable._records["proj-1"] = { + "project_id": "proj-1", + "project_alias": "MyProject", + } + project = await repo.find_by_alias("MyProject") + assert project is not None + + @pytest.mark.asyncio + async def test_find_by_team_id(self, repo): + repo._prisma_client.db.litellm_projecttable._records["proj-1"] = { + "project_id": "proj-1", + "team_id": "team-1", + } + projects = await repo.find_by_team_id("team-1") + assert len(projects) == 1 + + +class TestObjectPermissionRepository: + @pytest.fixture + def repo(self): + client = MockPrismaClient() + return 
ObjectPermissionRepository(client) + + @pytest.mark.asyncio + async def test_create_permission(self, repo): + perm = await repo.create_permission( + mcp_servers=["server1"], + models=["gpt-4"], + ) + assert perm.mcp_servers == ["server1"] + + @pytest.mark.asyncio + async def test_create_permission_all_fields(self, repo): + perm = await repo.create_permission( + mcp_servers=["server1"], + mcp_access_groups=["group1"], + mcp_tool_permissions={"tool1": ["read", "write"]}, + vector_stores=["store1"], + agents=["agent1"], + agent_access_groups=["agent-group1"], + models=["gpt-4"], + blocked_tools=["tool2"], + mcp_toolsets=["toolset1"], + search_tools=["search1"], + ) + assert perm.mcp_servers == ["server1"] + assert perm.agents == ["agent1"] + + @pytest.mark.asyncio + async def test_update_permission(self, repo): + repo._prisma_client.db.litellm_objectpermissiontable._records["perm-1"] = { + "object_permission_id": "perm-1", + "models": ["gpt-3.5-turbo"], + } + updated = await repo.update_permission( + object_permission_id="perm-1", + models=["gpt-4"], + ) + assert updated.models == ["gpt-4"] + + @pytest.mark.asyncio + async def test_update_permission_all_fields(self, repo): + repo._prisma_client.db.litellm_objectpermissiontable._records["perm-full"] = { + "object_permission_id": "perm-full", + "models": [], + } + updated = await repo.update_permission( + object_permission_id="perm-full", + mcp_servers=["server-new"], + mcp_access_groups=["group-new"], + mcp_tool_permissions={"tool": ["exec"]}, + vector_stores=["store-new"], + agents=["agent-new"], + agent_access_groups=["ag-new"], + models=["gpt-4", "claude-3"], + blocked_tools=["blocked-tool"], + mcp_toolsets=["toolset-new"], + search_tools=["search-new"], + ) + assert updated.mcp_servers == ["server-new"] + + @pytest.mark.asyncio + async def test_delete_permission(self, repo): + repo._prisma_client.db.litellm_objectpermissiontable._records["perm-1"] = { + "object_permission_id": "perm-1", + } + deleted = await 
repo.delete_permission("perm-1") + assert deleted is not None + + +class TestCredentialsRepository: + @pytest.fixture + def repo(self): + client = MockPrismaClient() + return CredentialsRepository(client) + + @pytest.mark.asyncio + async def test_create(self, repo): + record = await repo.create( + data={ + "credential_name": "my-api-key", + "credential_values": {"api_key": "encrypted_secret"}, + "credential_info": {"provider": "openai"}, + "created_by": "admin", + "updated_by": "admin", + } + ) + assert record.credential_name == "my-api-key" + cred = repo._to_model(record) + assert cred.credential_name == "my-api-key" + assert cred.credential_info == {"provider": "openai"} + assert cred.credential_values == {"api_key": "encrypted_secret"} + + @pytest.mark.asyncio + async def test_find_by_name_returns_stored_values_without_decryption(self, repo): + repo._prisma_client.db.litellm_credentialstable._records["my-key"] = { + "credential_id": "cred-1", + "credential_name": "my-key", + "credential_values": {"api_key": "encrypted_secret"}, + "credential_info": {"provider": "openai"}, + } + cred = await repo.find_by_name("my-key") + assert isinstance(cred, CredentialItem) + assert cred.credential_values == {"api_key": "encrypted_secret"} + assert cred.credential_info == {"provider": "openai"} + + @pytest.mark.asyncio + async def test_find_by_name_missing(self, repo): + assert await repo.find_by_name("nonexistent") is None + + @pytest.mark.asyncio + async def test_update_by_name(self, repo): + repo._prisma_client.db.litellm_credentialstable._records["my-key"] = { + "credential_id": "cred-1", + "credential_name": "my-key", + "credential_values": {"api_key": "old"}, + "credential_info": {}, + } + await repo.update_by_name( + "my-key", + data={"credential_values": {"api_key": "new"}, "updated_by": "admin"}, + ) + cred = await repo.find_by_name("my-key") + assert cred.credential_values == {"api_key": "new"} + + @pytest.mark.asyncio + async def test_delete_by_name(self, repo): + 
repo._prisma_client.db.litellm_credentialstable._records["my-key"] = { + "credential_id": "cred-1", + "credential_name": "my-key", + "credential_values": {"api_key": "secret"}, + "credential_info": {}, + } + await repo.delete_by_name("my-key") + assert await repo.find_by_name("my-key") is None + + @pytest.mark.asyncio + async def test_find_all(self, repo): + repo._prisma_client.db.litellm_credentialstable._records["k1"] = { + "credential_name": "k1", + "credential_values": {"api_key": "a"}, + "credential_info": {}, + } + repo._prisma_client.db.litellm_credentialstable._records["k2"] = { + "credential_name": "k2", + "credential_values": {"api_key": "b"}, + "credential_info": {}, + } + records = await repo.find_all() + assert len(records) == 2 + + def test_prisma_client_none_raises(self): + repo = CredentialsRepository(None) + with pytest.raises(RuntimeError, match="No DB Connected"): + _ = repo.table + + +class TestConfigRepository: + @pytest.fixture + def repo(self): + client = MockPrismaClient() + return ConfigRepository(client) + + def test_deep_merge_dicts_db_wins(self, repo): + dst = {"a": 1, "b": {"c": 2}} + src = {"a": 10, "b": {"d": 3}} + repo._deep_merge_dicts(dst, src) + assert dst["a"] == 10 + assert dst["b"]["c"] == 2 + assert dst["b"]["d"] == 3 + + def test_deep_merge_dicts_skips_none(self, repo): + dst = {"a": 1} + src = {"a": None, "b": 2} + repo._deep_merge_dicts(dst, src) + assert dst["a"] == 1 + assert dst["b"] == 2 + + def test_deep_merge_dicts_skips_empty_list(self, repo): + dst = {"models": ["gpt-4"]} + src = {"models": []} + repo._deep_merge_dicts(dst, src) + assert dst["models"] == ["gpt-4"] + + @pytest.mark.asyncio + async def test_get_param(self, repo): + repo._prisma_client.db.litellm_config._records["general_settings"] = { + "param_name": "general_settings", + "param_value": '{"master_key": "test"}', + } + param = await repo.get_param("general_settings") + assert param is not None + assert param.param_name == "general_settings" + assert 
param.param_value["master_key"] == "test" + + @pytest.mark.asyncio + async def test_set_param(self, repo): + param = await repo.set_param("test_param", {"key": "value"}) + assert param.param_name == "test_param" + assert param.param_value == {"key": "value"} + + @pytest.mark.asyncio + async def test_delete_param(self, repo): + repo._prisma_client.db.litellm_config._records["test_param"] = { + "param_name": "test_param", + "param_value": "{}", + } + result = await repo.delete_param("test_param") + assert result is True + + @pytest.mark.asyncio + async def test_delete_param_nonexistent(self, repo): + async def mock_delete(where): + raise Exception("Not found") + + repo._prisma_client.db.litellm_config.delete = mock_delete + result = await repo.delete_param("nonexistent") + assert result is False + + @pytest.mark.asyncio + async def test_get_all_params(self, repo): + repo._prisma_client.db.litellm_config._records = { + "param1": {"param_name": "param1", "param_value": '{"a": 1}'}, + "param2": {"param_name": "param2", "param_value": '{"b": 2}'}, + } + params = await repo.get_all_params() + assert len(params) == 2 + + @pytest.mark.asyncio + async def test_reconcile_config_skips_when_store_model_false(self, repo): + yaml_config = {"general_settings": {"key": "value"}} + result = await repo.reconcile_config(yaml_config, store_model_in_db=False) + assert result == yaml_config + + @pytest.mark.asyncio + async def test_prefetch_params(self, repo): + repo._prisma_client.db.litellm_config._records["general_settings"] = { + "param_name": "general_settings", + "param_value": "{}", + } + await repo.prefetch_params(["general_settings"]) + + @pytest.mark.asyncio + async def test_reconcile_config_with_db_values(self, repo): + repo._prisma_client.db.litellm_config._records["general_settings"] = { + "param_name": "general_settings", + "param_value": '{"master_key": "db-key", "db_only": "from_db"}', + } + repo._prisma_client.db.litellm_config._records["router_settings"] = { + 
"param_name": "router_settings", + "param_value": '{"timeout": 60}', + } + yaml_config = { + "general_settings": {"master_key": "yaml-key", "yaml_only": "from_yaml"}, + } + result = await repo.reconcile_config(yaml_config, store_model_in_db=True) + assert result["general_settings"]["master_key"] == "db-key" + assert result["general_settings"]["yaml_only"] == "from_yaml" + assert result["general_settings"]["db_only"] == "from_db" + assert result["router_settings"]["timeout"] == 60 + + @pytest.mark.asyncio + @patch("litellm.repositories.config_repository.decrypt_value_helper") + async def test_reconcile_config_with_environment_variables( + self, mock_decrypt, repo + ): + mock_decrypt.side_effect = lambda value, **kw: f"decrypted_{value}" + repo._prisma_client.db.litellm_config._records["environment_variables"] = { + "param_name": "environment_variables", + "param_value": '{"api_key": "encrypted_key", "secret": "encrypted_secret"}', + } + yaml_config = {} + result = await repo.reconcile_config(yaml_config, store_model_in_db=True) + assert "environment_variables" in result + assert "api_key" in result["environment_variables"] + assert "API_KEY" in result["environment_variables"] + + @pytest.mark.asyncio + async def test_reconcile_config_none_values_preserved(self, repo): + repo._prisma_client.db.litellm_config._records["general_settings"] = { + "param_name": "general_settings", + "param_value": '{"new_key": "value", "null_key": null}', + } + yaml_config = {"general_settings": {"existing": "keep"}} + result = await repo.reconcile_config(yaml_config, store_model_in_db=True) + assert result["general_settings"]["existing"] == "keep" + assert result["general_settings"]["new_key"] == "value" + + def test_update_config_fields_non_dict(self, repo): + config = {"litellm_settings": "old_value"} + result = repo._update_config_fields( + current_config=config, + param_name="litellm_settings", + db_param_value="new_value", + ) + assert result["litellm_settings"] == "new_value" + + 
def test_update_config_fields_new_param(self, repo): + config = {} + result = repo._update_config_fields( + current_config=config, + param_name="router_settings", + db_param_value={"timeout": 30}, + ) + assert result["router_settings"] == {"timeout": 30} + + @patch("litellm.repositories.config_repository.decrypt_value_helper") + def test_decrypt_env_variables_non_string(self, mock_decrypt, repo): + mock_decrypt.side_effect = lambda value, **kw: value + env_vars = {"string_val": "encrypted", "int_val": 123, "bool_val": True} + result = repo._decrypt_env_variables(env_vars) + assert result["int_val"] == "123" + assert result["bool_val"] == "True" + + @patch("litellm.repositories.config_repository.decrypt_value_helper") + def test_decrypt_env_variables_none_value(self, mock_decrypt, repo): + mock_decrypt.return_value = None + env_vars = {"key": "value"} + result = repo._decrypt_env_variables(env_vars) + assert "key" not in result + + +class TestVerificationTokenRepositoryExtended: + @pytest.fixture + def repo(self): + client = MockPrismaClient() + return VerificationTokenRepository(client) + + @pytest.mark.asyncio + async def test_find_active_tokens(self, repo): + repo._prisma_client.db.litellm_verificationtoken._records["sk-active"] = { + "token": "sk-active", + "blocked": False, + "expires": None, + } + tokens = await repo.find_active_tokens() + assert len(tokens) >= 1 + + @pytest.mark.asyncio + async def test_delete_token_with_audit(self, repo): + repo._prisma_client.db.litellm_verificationtoken._records["sk-delete"] = { + "token": "sk-delete", + "key_name": "Delete Me", + "spend": 0.0, + } + + class MockTx: + def __init__(self, client): + self.litellm_deletedverificationtoken = ( + client.db.litellm_deletedverificationtoken + ) + self.litellm_verificationtoken = client.db.litellm_verificationtoken + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + repo._prisma_client.db.tx = lambda: MockTx(repo._prisma_client) + deleted 
= await repo.delete_token( + "sk-delete", + deleted_by="admin", + deleted_by_api_key="sk-admin", + litellm_changed_by="system", + ) + assert deleted is not None + assert deleted.token == "sk-delete" + + @pytest.mark.asyncio + async def test_delete_token_nonexistent(self, repo): + result = await repo.delete_token("nonexistent") + assert result is None + + @pytest.mark.asyncio + async def test_delete_token_archive_serialization(self, repo): + """Archived token must store JSON columns as strings, map org_id onto the + organization_id column, preserve budget_id, and drop relation-only fields + that don't exist on LiteLLM_DeletedVerificationToken.""" + repo._prisma_client.db.litellm_verificationtoken._records["sk-arch"] = { + "token": "sk-arch", + "key_name": "Archive Me", + "aliases": json.dumps({"a": "b"}), + "metadata": json.dumps({"team": "x"}), + "permissions": json.dumps({"read": True}), + "spend": 5.0, + "organization_id": "org-9", + "budget_id": "budget-9", + "budget_limits": [{"model": "gpt-4", "budget": 1.0}], + } + + class MockTx: + def __init__(self, client): + self.litellm_deletedverificationtoken = ( + client.db.litellm_deletedverificationtoken + ) + self.litellm_verificationtoken = client.db.litellm_verificationtoken + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + repo._prisma_client.db.tx = lambda: MockTx(repo._prisma_client) + + await repo.delete_token("sk-arch", deleted_by="admin") + + archived = list( + repo._prisma_client.db.litellm_deletedverificationtoken._records.values() + )[0] + + assert isinstance(archived["aliases"], str) + assert json.loads(archived["aliases"]) == {"a": "b"} + assert isinstance(archived["metadata"], str) + assert isinstance(archived["permissions"], str) + + assert archived["organization_id"] == "org-9" + assert "org_id" not in archived + + assert archived["budget_id"] == "budget-9" + + for relation_field in ( + "object_permission", + "litellm_budget_table", + "budget_limits", + 
): + assert relation_field not in archived + + assert ( + "sk-arch" not in repo._prisma_client.db.litellm_verificationtoken._records + ) + + @pytest.mark.asyncio + async def test_find_by_id_maps_org_and_budget_columns(self, repo): + """Reading a token must surface the organization_id column as org_id and + populate budget_id rather than silently dropping them.""" + repo._prisma_client.db.litellm_verificationtoken._records["sk-read"] = { + "token": "sk-read", + "organization_id": "org-7", + "budget_id": "budget-7", + } + token = await repo.find_by_id("sk-read") + assert token is not None + assert token.org_id == "org-7" + assert token.budget_id == "budget-7" + + @pytest.mark.asyncio + async def test_update_token_all_fields(self, repo): + repo._prisma_client.db.litellm_verificationtoken._records["sk-test"] = { + "token": "sk-test", + } + updated = await repo.update_token( + token="sk-test", + updated_by="admin", + key_name="Updated", + key_alias="new-alias", + max_budget=500.0, + expires=datetime(2025, 12, 31), + models=["gpt-4", "gpt-3.5-turbo"], + aliases={"a": "b"}, + config={"c": "d"}, + max_parallel_requests=10, + metadata={"m": "data"}, + tpm_limit=5000, + rpm_limit=50, + budget_duration="daily", + allowed_cache_controls=["cache"], + allowed_routes=["/v1/chat"], + permissions={"write": True}, + blocked=False, + object_permission_id="perm-2", + access_group_ids=["g1", "g2"], + ) + assert updated.key_name == "Updated" + + @pytest.mark.asyncio + async def test_to_model_with_json_fields(self, repo): + repo._prisma_client.db.litellm_verificationtoken._records["sk-json"] = { + "token": "sk-json", + "aliases": '{"alias1": "value1"}', + "config": '{"setting": "val"}', + "permissions": '{"read": true}', + "metadata": '{"key": "value"}', + "model_spend": '{"gpt-4": 10.0}', + "model_max_budget": '{"gpt-4": 100.0}', + "router_settings": '{"timeout": 30}', + "budget_limits": '[{"limit": 50}]', + "litellm_budget_table": '{"budget_id": "b1"}', + } + token = await 
repo.find_by_id("sk-json") + assert token is not None + assert token.aliases == {"alias1": "value1"} + assert token.config == {"setting": "val"} + + +class TestTeamRepositoryExtended: + @pytest.fixture + def repo(self): + client = MockPrismaClient() + return TeamRepository(client) + + @pytest.mark.asyncio + async def test_delete_team_with_audit(self, repo): + repo._prisma_client.db.litellm_teamtable._records["team-delete"] = { + "team_id": "team-delete", + "team_alias": "Delete Team", + "members": [], + "admins": [], + "models": [], + "spend": 0.0, + } + + class MockTx: + def __init__(self, client): + self.litellm_deletedteamtable = client.db.litellm_deletedteamtable + self.litellm_teamtable = client.db.litellm_teamtable + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + repo._prisma_client.db.tx = lambda: MockTx(repo._prisma_client) + deleted = await repo.delete_team( + "team-delete", + deleted_by="admin", + deleted_by_api_key="sk-admin", + litellm_changed_by="system", + ) + assert deleted is not None + assert deleted.team_id == "team-delete" + + @pytest.mark.asyncio + async def test_delete_team_nonexistent(self, repo): + result = await repo.delete_team("nonexistent") + assert result is None + + @pytest.mark.asyncio + async def test_delete_team_with_full_data(self, repo): + repo._prisma_client.db.litellm_teamtable._records["team-full"] = { + "team_id": "team-full", + "team_alias": "Full Team", + "organization_id": "org-1", + "object_permission_id": "perm-1", + "members": ["m1", "m2"], + "admins": ["a1"], + "members_with_roles": '[{"user_id": "u1", "role": "admin"}]', + "metadata": '{"key": "value"}', + "max_budget": 1000.0, + "soft_budget": 800.0, + "spend": 150.0, + "models": ["gpt-4"], + "max_parallel_requests": 10, + "tpm_limit": 5000, + "rpm_limit": 50, + "budget_duration": "monthly", + "budget_reset_at": "2025-01-01T00:00:00", + "blocked": True, + "model_spend": '{"gpt-4": 100.0}', + "model_max_budget": '{"gpt-4": 
500.0}', + "router_settings": '{"timeout": 30}', + "team_member_permissions": ["read"], + "access_group_ids": ["group-1"], + "policies": ["policy-1"], + "model_id": 42, + "allow_team_guardrail_config": True, + } + + class MockTx: + def __init__(self, client): + self.litellm_deletedteamtable = client.db.litellm_deletedteamtable + self.litellm_teamtable = client.db.litellm_teamtable + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + repo._prisma_client.db.tx = lambda: MockTx(repo._prisma_client) + deleted = await repo.delete_team( + "team-full", + deleted_by="admin", + deleted_by_api_key="sk-admin", + litellm_changed_by="system", + ) + assert deleted is not None + assert deleted.team_id == "team-full" + assert deleted.organization_id == "org-1" + assert deleted.max_budget == 1000.0 + + @pytest.mark.asyncio + async def test_to_model_with_json_fields(self, repo): + repo._prisma_client.db.litellm_teamtable._records["team-json"] = { + "team_id": "team-json", + "metadata": '{"key": "value"}', + "model_spend": '{"gpt-4": 10.0}', + "model_max_budget": '{"gpt-4": 100.0}', + "router_settings": '{"timeout": 30}', + "budget_limits": '[{"budget_duration": "1d", "max_budget": 50.0}]', + "members_with_roles": '[{"user_id": "u1", "role": "admin"}]', + "members": [], + "admins": [], + "models": [], + } + team = await repo.find_by_id("team-json") + assert team is not None + assert team.metadata == {"key": "value"} + assert len(team.members_with_roles) == 1 + assert team.members_with_roles[0].user_id == "u1" + + +class TestUserRepositoryExtended: + @pytest.fixture + def repo(self): + client = MockPrismaClient() + return UserRepository(client) + + @pytest.mark.asyncio + async def test_delete_user_simple(self, repo): + repo._prisma_client.db.litellm_usertable._records["user-delete"] = { + "user_id": "user-delete", + "user_email": "delete@example.com", + "teams": [], + "models": [], + "spend": 0.0, + } + deleted = await 
repo.delete_user("user-delete") + assert deleted is not None + assert deleted.user_id == "user-delete" + + @pytest.mark.asyncio + async def test_delete_user_nonexistent(self, repo): + result = await repo.delete_user("nonexistent") + assert result is None + + @pytest.mark.asyncio + async def test_update_user_all_fields(self, repo): + repo._prisma_client.db.litellm_usertable._records["user-update"] = { + "user_id": "user-update", + "teams": [], + "models": [], + } + updated = await repo.update_user( + user_id="user-update", + user_alias="newalias", + team_id="team-new", + sso_user_id="sso-new", + organization_id="org-1", + password="new-hashed-pw", + teams=["team-1", "team-2"], + user_role="admin", + max_budget=1000.0, + user_email="new@example.com", + models=["gpt-4"], + metadata={"pref": "dark"}, + max_parallel_requests=20, + tpm_limit=10000, + rpm_limit=100, + budget_duration="monthly", + allowed_cache_controls=["no-cache"], + policies=["policy-1"], + object_permission_id="perm-new", + ) + assert updated.user_email == "new@example.com" + + +class TestProjectRepositoryExtended: + @pytest.fixture + def repo(self): + client = MockPrismaClient() + return ProjectRepository(client) + + @pytest.mark.asyncio + async def test_delete_project_simple(self, repo): + repo._prisma_client.db.litellm_projecttable._records["proj-delete"] = { + "project_id": "proj-delete", + "project_alias": "Delete Project", + "spend": 0.0, + } + deleted = await repo.delete_project("proj-delete") + assert deleted is not None + + +class TestBudgetRepositoryExtended: + @pytest.fixture + def repo(self): + client = MockPrismaClient() + return BudgetRepository(client) + + @pytest.mark.asyncio + async def test_update_budget_all_fields(self, repo): + repo._prisma_client.db.litellm_budgettable._records["budget-update"] = { + "budget_id": "budget-update", + "max_budget": 100.0, + } + updated = await repo.update_budget( + budget_id="budget-update", + updated_by="admin", + max_budget=500.0, + 
soft_budget=400.0, + max_parallel_requests=15, + tpm_limit=20000, + rpm_limit=200, + model_max_budget={"gpt-4": 200.0}, + budget_duration="weekly", + allowed_models=["gpt-4", "claude-3"], + ) + assert updated.max_budget == 500.0 + + +class TestModelRepositoryExtended: + @pytest.fixture + def repo(self): + client = MockPrismaClient() + return ModelRepository(client) + + @pytest.mark.asyncio + @patch( + "litellm.repositories.model_repository.decrypt_value_helper", + side_effect=lambda value, **kw: value, + ) + async def test_find_by_team_id(self, mock_decrypt, repo): + repo._prisma_client.db.litellm_proxymodeltable._records["model-1"] = { + "model_id": "model-1", + "model_name": "gpt-4", + "litellm_params": '{"api_key": "sk-test"}', + "model_info": '{"team_id": "team-1"}', + "blocked": False, + } + repo._prisma_client.db.litellm_proxymodeltable._records["model-2"] = { + "model_id": "model-2", + "model_name": "claude-3", + "litellm_params": '{"api_key": "sk-other"}', + "model_info": '{"team_id": "team-2"}', + "blocked": False, + } + models = await repo.find_by_team_id("team-1") + assert len(models) == 1 + assert models[0].model_name == "gpt-4" + + +class TestBaseRepositoryExtended: + @pytest.fixture + def repo(self): + client = MockPrismaClient() + return BudgetRepository(client) + + @pytest.mark.asyncio + async def test_find_many_with_pagination(self, repo): + repo._prisma_client.db.litellm_budgettable._records = { + "b1": {"budget_id": "b1", "max_budget": 100.0}, + "b2": {"budget_id": "b2", "max_budget": 200.0}, + "b3": {"budget_id": "b3", "max_budget": 300.0}, + } + budgets = await repo.find_many(skip=0, take=2, order={"budget_id": "asc"}) + assert len(budgets) >= 2 + + @pytest.mark.asyncio + async def test_find_many_with_where(self, repo): + repo._prisma_client.db.litellm_budgettable._records = { + "b1": {"budget_id": "b1", "max_budget": 100.0}, + } + budgets = await repo.find_many(where={"budget_id": "b1"}) + assert len(budgets) >= 1 + + @pytest.mark.asyncio + 
async def test_to_model_list_with_none(self, repo): + result = repo._to_model_list([None, None]) + assert result == [] + + +class _SampleDomainModel(DomainModel): + budget_id: Optional[str] = None + max_budget: Optional[float] = None + + +class TestDomainModelExtended: + def test_from_db_record_none_raises(self): + with pytest.raises(ValueError, match="Cannot create domain model from None"): + DomainModel.from_db_record(None) + + def test_from_db_record_dict(self): + model = _SampleDomainModel.from_db_record( + {"budget_id": "b1", "max_budget": 100.0} + ) + assert model.budget_id == "b1" + + def test_from_db_record_model_dump(self): + class MockRecordWithModelDump: + def model_dump(self): + return {"budget_id": "b2", "max_budget": 200.0} + + model = _SampleDomainModel.from_db_record(MockRecordWithModelDump()) + assert model.budget_id == "b2" + + def test_to_db_dict(self): + model = _SampleDomainModel(budget_id="b3", max_budget=300.0) + data = model.to_db_dict() + assert data["budget_id"] == "b3" + assert data["max_budget"] == 300.0 + + +class TestTeamRepositoryArchiveData: + @pytest.fixture + def repo(self): + client = MockPrismaClient() + return TeamRepository(client) + + def test_build_archive_data_minimal_fields(self, repo): + + team = LiteLLM_TeamTable(team_id="team-minimal") + archive_data = repo._build_archive_data(team) + assert archive_data["team_id"] == "team-minimal" + assert archive_data["admins"] == [] + assert archive_data["members"] == [] + assert archive_data["models"] == [] + assert archive_data["spend"] == 0.0 + assert archive_data["blocked"] is False + assert "team_alias" not in archive_data + assert "organization_id" not in archive_data + assert "object_permission_id" not in archive_data + assert "members_with_roles" not in archive_data + assert "metadata" not in archive_data + assert "max_budget" not in archive_data + assert "soft_budget" not in archive_data + assert "max_parallel_requests" not in archive_data + assert "tpm_limit" not in 
archive_data + assert "rpm_limit" not in archive_data + assert "budget_duration" not in archive_data + assert "budget_reset_at" not in archive_data + assert "model_spend" not in archive_data + assert "model_max_budget" not in archive_data + assert "router_settings" not in archive_data + assert "model_id" not in archive_data + + def test_build_archive_data_excludes_invalid_columns(self, repo): + + team = LiteLLM_TeamTable( + team_id="team-1", + team_alias="My Team", + admins=["admin1"], + members=["member1"], + models=["gpt-4"], + default_team_member_models=["gpt-3.5-turbo"], + ) + archive_data = repo._build_archive_data(team) + assert "default_team_member_models" not in archive_data + assert "budget_limits" not in archive_data + assert archive_data["team_id"] == "team-1" + assert archive_data["team_alias"] == "My Team" + assert archive_data["admins"] == ["admin1"] + assert archive_data["members"] == ["member1"] + assert archive_data["models"] == ["gpt-4"] + + def test_build_archive_data_with_all_valid_fields(self, repo): + from datetime import datetime + + from litellm.models.team import Member + + team = LiteLLM_TeamTable( + team_id="team-full", + team_alias="Full Team", + organization_id="org-1", + object_permission_id="perm-1", + admins=["admin1", "admin2"], + members=["m1", "m2"], + members_with_roles=[Member(user_id="u1", role="admin")], + metadata={"key": "value"}, + max_budget=1000.0, + soft_budget=800.0, + spend=150.0, + models=["gpt-4", "claude-3"], + max_parallel_requests=10, + tpm_limit=5000, + rpm_limit=50, + budget_duration="monthly", + budget_reset_at=datetime(2025, 1, 1), + blocked=True, + model_spend={"gpt-4": 100.0}, + model_max_budget={"gpt-4": 500.0}, + router_settings={"timeout": 30}, + team_member_permissions=["read"], + access_group_ids=["group-1"], + policies=["policy-1"], + model_id=42, + allow_team_guardrail_config=True, + ) + archive_data = repo._build_archive_data(team) + assert archive_data["team_id"] == "team-full" + assert 
archive_data["organization_id"] == "org-1" + assert archive_data["object_permission_id"] == "perm-1" + assert archive_data["max_budget"] == 1000.0 + assert archive_data["soft_budget"] == 800.0 + assert archive_data["spend"] == 150.0 + assert archive_data["blocked"] is True + assert archive_data["model_id"] == 42 + assert archive_data["allow_team_guardrail_config"] is True + assert "members_with_roles" in archive_data + assert "metadata" in archive_data + assert "model_spend" in archive_data + assert "model_max_budget" in archive_data + assert "router_settings" in archive_data + + +class TestConfigRepositoryDeepCopy: + @pytest.fixture + def repo(self): + client = MockPrismaClient() + return ConfigRepository(client) + + @pytest.mark.asyncio + async def test_reconcile_config_does_not_mutate_original(self, repo): + import copy + + repo._prisma_client.db.litellm_config._records["general_settings"] = { + "param_name": "general_settings", + "param_value": '{"db_key": "db_value", "nested": {"db_nested": "from_db"}}', + } + original_config = { + "general_settings": { + "yaml_key": "yaml_value", + "nested": {"yaml_nested": "from_yaml"}, + } + } + original_copy = copy.deepcopy(original_config) + result = await repo.reconcile_config(original_config, store_model_in_db=True) + assert original_config == original_copy + assert result["general_settings"]["db_key"] == "db_value" + assert result["general_settings"]["yaml_key"] == "yaml_value" + assert result["general_settings"]["nested"]["db_nested"] == "from_db" + assert result["general_settings"]["nested"]["yaml_nested"] == "from_yaml" + + @pytest.mark.asyncio + async def test_reconcile_config_repeated_calls_independent(self, repo): + repo._prisma_client.db.litellm_config._records["general_settings"] = { + "param_name": "general_settings", + "param_value": '{"db_key": "db_value"}', + } + yaml_config = {"general_settings": {"yaml_key": "yaml_value"}} + result1 = await repo.reconcile_config(yaml_config, store_model_in_db=True) + 
result1["general_settings"]["modified"] = "in_result1" + result2 = await repo.reconcile_config(yaml_config, store_model_in_db=True) + assert "modified" not in yaml_config.get("general_settings", {}) + assert "modified" not in result2.get("general_settings", {}) + + +class TestPrismaTableRepository: + def test_table_property_returns_named_delegate(self): + from litellm.repositories.table_repositories import ( + AgentsRepository, + PolicyRepository, + ) + + prisma_client = MagicMock() + agents = AgentsRepository(prisma_client) + policy = PolicyRepository(prisma_client) + + assert agents.table is prisma_client.db.litellm_agentstable + assert policy.table is prisma_client.db.litellm_policytable + assert agents.table is not policy.table + + def test_table_access_raises_without_db(self): + from litellm.repositories.table_repositories import SpendLogsRepository + + repo = SpendLogsRepository(None) + with pytest.raises(RuntimeError, match="No DB Connected"): + _ = repo.table + + def test_each_repository_binds_its_own_table_name(self): + import litellm.repositories.table_repositories as tr + + prisma_client = MagicMock() + repos = [ + obj + for name, obj in vars(tr).items() + if isinstance(obj, type) + and issubclass(obj, tr.PrismaTableRepository) + and obj is not tr.PrismaTableRepository + ] + assert len(repos) >= 40 + seen = set() + for repo_cls in repos: + name = repo_cls.table_name + assert name.startswith("litellm_") + assert name not in seen, f"duplicate table_name {name}" + seen.add(name) + assert repo_cls(prisma_client).table is getattr(prisma_client.db, name) diff --git a/ui/litellm-dashboard/src/lib/http/schema.d.ts b/ui/litellm-dashboard/src/lib/http/schema.d.ts index f952bbfbba..403e00f953 100644 --- a/ui/litellm-dashboard/src/lib/http/schema.d.ts +++ b/ui/litellm-dashboard/src/lib/http/schema.d.ts @@ -21364,8 +21364,7 @@ export interface components { }; /** * LiteLLM_DeletedTeamTable - * @description Recording of deleted teams for audit purposes. 
Mirrors LiteLLM_TeamTable - * plus metadata captured at deletion time. + * @description Audit record for deleted teams; mirrors the team plus deletion metadata. */ LiteLLM_DeletedTeamTable: { /** Access Group Ids */ @@ -21375,6 +21374,11 @@ export interface components { * @default [] */ admins: unknown[]; + /** + * Allow Team Guardrail Config + * @default false + */ + allow_team_guardrail_config: boolean | null; /** * Blocked * @default false @@ -21421,6 +21425,20 @@ export interface components { } | null; /** Model Id */ model_id?: number | null; + /** + * Model Max Budget + * @default {} + */ + model_max_budget: { + [key: string]: unknown; + } | null; + /** + * Model Spend + * @default {} + */ + model_spend: { + [key: string]: unknown; + } | null; /** * Models * @default [] @@ -21431,6 +21449,8 @@ export interface components { object_permission_id?: string | null; /** Organization Id */ organization_id?: string | null; + /** Policies */ + policies?: string[] | null; /** Router Settings */ router_settings?: { [key: string]: unknown; @@ -21454,8 +21474,7 @@ export interface components { }; /** * LiteLLM_DeletedVerificationToken - * @description Recording of deleted keys for audit purposes. Mirrors LiteLLM_VerificationToken - * plus metadata captured at deletion time. + * @description Audit record for deleted keys; mirrors the token plus deletion metadata. 
*/ LiteLLM_DeletedVerificationToken: { /** Access Group Ids */ @@ -21488,6 +21507,8 @@ export interface components { blocked?: boolean | null; /** Budget Duration */ budget_duration?: string | null; + /** Budget Id */ + budget_id?: string | null; /** Budget Limits */ budget_limits?: { [key: string]: unknown; @@ -21866,9 +21887,9 @@ export interface components { /** Id */ id?: number | null; /** Model Aliases */ - model_aliases?: { + model_aliases?: string | { [key: string]: unknown; - } | string | null; + } | null; team?: components["schemas"]["LiteLLM_TeamTable"] | null; /** Updated By */ updated_by: string; @@ -21934,6 +21955,11 @@ export interface components { } | null; /** Mcp Toolsets */ mcp_toolsets?: string[] | null; + /** + * Models + * @default [] + */ + models: string[] | null; /** Object Permission Id */ object_permission_id: string; /** @@ -21949,7 +21975,7 @@ export interface components { }; /** * LiteLLM_OrganizationMembershipTable - * @description This is the table that track what organizations a user belongs to and users spend within the organization + * @description Tracks which organizations a user belongs to and their spend within it. 
*/ LiteLLM_OrganizationMembershipTable: { /** Budget Id */ @@ -22005,7 +22031,17 @@ export interface components { metadata?: { [key: string]: unknown; } | null; - /** Models */ + /** + * Model Spend + * @default {} + */ + model_spend: { + [key: string]: unknown; + } | null; + /** + * Models + * @default [] + */ models: string[]; object_permission?: components["schemas"]["LiteLLM_ObjectPermissionTable"] | null; /** Object Permission Id */ @@ -22231,7 +22267,7 @@ export interface components { [key: string]: unknown; } | null; /** Stream Timeout */ - stream_timeout?: string | number | null; + stream_timeout?: number | string | null; /** Tag Regex */ tag_regex?: string[] | null; /** Tags */ @@ -22286,7 +22322,7 @@ export interface components { /** Created At */ created_at?: string | null; /** Created By */ - created_by: string; + created_by?: string | null; /** Description */ description?: string | null; litellm_budget_table?: components["schemas"]["LiteLLM_BudgetTable"] | null; @@ -22328,7 +22364,7 @@ export interface components { /** Updated At */ updated_at?: string | null; /** Updated By */ - updated_by: string; + updated_by?: string | null; }; /** LiteLLM_SpendLogs */ LiteLLM_SpendLogs: { @@ -22432,6 +22468,11 @@ export interface components { * @default [] */ admins: unknown[]; + /** + * Allow Team Guardrail Config + * @default false + */ + allow_team_guardrail_config: boolean | null; /** * Blocked * @default false @@ -22468,6 +22509,20 @@ export interface components { } | null; /** Model Id */ model_id?: number | null; + /** + * Model Max Budget + * @default {} + */ + model_max_budget: { + [key: string]: unknown; + } | null; + /** + * Model Spend + * @default {} + */ + model_spend: { + [key: string]: unknown; + } | null; /** * Models * @default [] @@ -22478,6 +22533,8 @@ export interface components { object_permission_id?: string | null; /** Organization Id */ organization_id?: string | null; + /** Policies */ + policies?: string[] | null; /** Router Settings */ 
router_settings?: { [key: string]: unknown; @@ -22549,6 +22606,11 @@ export interface components { }; /** LiteLLM_UserTable */ LiteLLM_UserTable: { + /** + * Allowed Cache Controls + * @default [] + */ + allowed_cache_controls: string[]; /** Budget Duration */ budget_duration?: string | null; /** Budget Reset At */ @@ -22557,6 +22619,8 @@ export interface components { created_at?: string | null; /** Max Budget */ max_budget?: number | null; + /** Max Parallel Requests */ + max_parallel_requests?: number | null; /** Metadata */ metadata?: { [key: string]: unknown; @@ -22581,8 +22645,17 @@ export interface components { */ models: unknown[]; object_permission?: components["schemas"]["LiteLLM_ObjectPermissionTable"] | null; + /** Object Permission Id */ + object_permission_id?: string | null; + /** Organization Id */ + organization_id?: string | null; /** Organization Memberships */ organization_memberships?: components["schemas"]["LiteLLM_OrganizationMembershipTable"][] | null; + /** + * Policies + * @default [] + */ + policies: string[]; /** Rpm Limit */ rpm_limit?: number | null; /** @@ -22592,6 +22665,8 @@ export interface components { spend: number; /** Sso User Id */ sso_user_id?: string | null; + /** Team Id */ + team_id?: string | null; /** * Teams * @default [] @@ -22612,6 +22687,11 @@ export interface components { }; /** LiteLLM_UserTableWithKeyCount */ LiteLLM_UserTableWithKeyCount: { + /** + * Allowed Cache Controls + * @default [] + */ + allowed_cache_controls: string[]; /** Budget Duration */ budget_duration?: string | null; /** Budget Reset At */ @@ -22625,6 +22705,8 @@ export interface components { key_count: number; /** Max Budget */ max_budget?: number | null; + /** Max Parallel Requests */ + max_parallel_requests?: number | null; /** Metadata */ metadata?: { [key: string]: unknown; @@ -22649,8 +22731,17 @@ export interface components { */ models: unknown[]; object_permission?: components["schemas"]["LiteLLM_ObjectPermissionTable"] | null; + /** 
Object Permission Id */ + object_permission_id?: string | null; + /** Organization Id */ + organization_id?: string | null; /** Organization Memberships */ organization_memberships?: components["schemas"]["LiteLLM_OrganizationMembershipTable"][] | null; + /** + * Policies + * @default [] + */ + policies: string[]; /** Rpm Limit */ rpm_limit?: number | null; /** @@ -22660,6 +22751,8 @@ export interface components { spend: number; /** Sso User Id */ sso_user_id?: string | null; + /** Team Id */ + team_id?: string | null; /** * Teams * @default [] @@ -22710,6 +22803,8 @@ export interface components { blocked?: boolean | null; /** Budget Duration */ budget_duration?: string | null; + /** Budget Id */ + budget_id?: string | null; /** Budget Limits */ budget_limits?: { [key: string]: unknown; @@ -24234,7 +24329,17 @@ export interface components { metadata?: { [key: string]: unknown; } | null; - /** Models */ + /** + * Model Spend + * @default {} + */ + model_spend: { + [key: string]: unknown; + } | null; + /** + * Models + * @default [] + */ models: string[]; object_permission?: components["schemas"]["LiteLLM_ObjectPermissionTable"] | null; /** Object Permission Id */ @@ -24339,7 +24444,7 @@ export interface components { */ created_at: string; /** Created By */ - created_by: string; + created_by?: string | null; /** Description */ description?: string | null; litellm_budget_table?: components["schemas"]["LiteLLM_BudgetTable"] | null; @@ -24384,7 +24489,7 @@ export interface components { */ updated_at: string; /** Updated By */ - updated_by: string; + updated_by?: string | null; }; /** NewTeamRequest */ NewTeamRequest: { @@ -27272,6 +27377,11 @@ export interface components { * @default [] */ admins: unknown[]; + /** + * Allow Team Guardrail Config + * @default false + */ + allow_team_guardrail_config: boolean | null; /** * Blocked * @default false @@ -27308,6 +27418,20 @@ export interface components { } | null; /** Model Id */ model_id?: number | null; + /** + * Model Max 
Budget + * @default {} + */ + model_max_budget: { + [key: string]: unknown; + } | null; + /** + * Model Spend + * @default {} + */ + model_spend: { + [key: string]: unknown; + } | null; /** * Models * @default [] @@ -27318,6 +27442,8 @@ export interface components { object_permission_id?: string | null; /** Organization Id */ organization_id?: string | null; + /** Policies */ + policies?: string[] | null; /** Router Settings */ router_settings?: { [key: string]: unknown; @@ -27361,6 +27487,11 @@ export interface components { * @default [] */ admins: unknown[]; + /** + * Allow Team Guardrail Config + * @default false + */ + allow_team_guardrail_config: boolean | null; /** * Blocked * @default false @@ -27407,6 +27538,20 @@ export interface components { } | null; /** Model Id */ model_id?: number | null; + /** + * Model Max Budget + * @default {} + */ + model_max_budget: { + [key: string]: unknown; + } | null; + /** + * Model Spend + * @default {} + */ + model_spend: { + [key: string]: unknown; + } | null; /** * Models * @default [] @@ -27417,6 +27562,8 @@ export interface components { object_permission_id?: string | null; /** Organization Id */ organization_id?: string | null; + /** Policies */ + policies?: string[] | null; /** Router Settings */ router_settings?: { [key: string]: unknown; @@ -28907,6 +29054,8 @@ export interface components { blocked?: boolean | null; /** Budget Duration */ budget_duration?: string | null; + /** Budget Id */ + budget_id?: string | null; /** Budget Limits */ budget_limits?: { [key: string]: unknown; @@ -29653,7 +29802,7 @@ export interface components { [key: string]: unknown; } | null; /** Stream Timeout */ - stream_timeout?: string | number | null; + stream_timeout?: number | string | null; /** Tag Regex */ tag_regex?: string[] | null; /** Tags */