mirror of
https://github.com/tiennm99/litellm.git
synced 2026-07-03 09:10:47 +00:00
encrypt callback_vars in key/team metadata at rest (#27141)
Co-authored-by: Michael Riad Zaky <michaelr@Michaels-MacBook-Air.local> Co-authored-by: Yuneng Jiang <yuneng@berri.ai>
This commit is contained in:
@@ -1,16 +1,31 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
|
||||
from litellm.proxy._types import CommonProxyErrors, LiteLLMPromptInjectionParams
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||
decrypt_value_helper,
|
||||
encrypt_value_helper,
|
||||
)
|
||||
from litellm.proxy.types_utils.utils import get_instance_fn
|
||||
from litellm.types.utils import (
|
||||
StandardLoggingGuardrailInformation,
|
||||
StandardLoggingPayload,
|
||||
)
|
||||
|
||||
_CALLBACK_VAR_MASKER = SensitiveDataMasker()
|
||||
# Compound names that are credential-bearing but don't contain any of the
|
||||
# default sensitive segments (so SensitiveDataMasker won't flag them).
|
||||
_EXTRA_SENSITIVE_CALLBACK_KEYS = {"gcs_path_service_account"}
|
||||
# Sentinel prefix on encrypted callback_var values. Lets us detect
|
||||
# already-encrypted input cheaply (no decrypt-attempt round trip) and
|
||||
# avoid double-encrypting if `LITELLM_SALT_KEY` is rotated between writes.
|
||||
_CALLBACK_VAR_ENCRYPTED_PREFIX = "litellm_enc::"
|
||||
|
||||
blue_color_code = "\033[94m"
|
||||
reset_color_code = "\033[0m"
|
||||
|
||||
@@ -550,3 +565,85 @@ def normalize_callback_names(callbacks: Iterable[Any]) -> List[Any]:
|
||||
if callbacks is None:
|
||||
return []
|
||||
return [c.lower() if isinstance(c, str) else c for c in callbacks]
|
||||
|
||||
|
||||
def encrypt_callback_vars(metadata: Any) -> Any:
|
||||
"""Return a deep copy of metadata with callback_vars values encrypted at rest.
|
||||
|
||||
Idempotent: a value that already decrypts cleanly is left unchanged so
|
||||
round-trips through edit forms don't double-encrypt.
|
||||
"""
|
||||
return _transform_callback_vars(metadata, _encrypt_if_plaintext)
|
||||
|
||||
|
||||
def decrypt_callback_vars(metadata: Any) -> Any:
|
||||
"""Return a deep copy of metadata with callback_vars values decrypted.
|
||||
|
||||
Legacy plaintext rows pass through unchanged (decrypt failure → original).
|
||||
"""
|
||||
return _transform_callback_vars(metadata, _decrypt_or_passthrough)
|
||||
|
||||
|
||||
def _transform_callback_vars(
|
||||
metadata: Any, transform: Callable[[str, Any], Any]
|
||||
) -> Any:
|
||||
if not isinstance(metadata, dict):
|
||||
return metadata
|
||||
out = copy.deepcopy(metadata)
|
||||
logging_entries = out.get("logging")
|
||||
if isinstance(logging_entries, list):
|
||||
for entry in logging_entries:
|
||||
if isinstance(entry, dict) and isinstance(entry.get("callback_vars"), dict):
|
||||
entry["callback_vars"] = {
|
||||
k: transform(k, v) for k, v in entry["callback_vars"].items()
|
||||
}
|
||||
callback_settings = out.get("callback_settings")
|
||||
if isinstance(callback_settings, dict) and isinstance(
|
||||
callback_settings.get("callback_vars"), dict
|
||||
):
|
||||
callback_settings["callback_vars"] = {
|
||||
k: transform(k, v) for k, v in callback_settings["callback_vars"].items()
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
def _is_sensitive_callback_var(key: str) -> bool:
|
||||
"""Match codebase precedent: only credential-bearing fields get encrypted;
|
||||
routing/identifier fields (host, base_url, project, region) stay plain."""
|
||||
if key in _EXTRA_SENSITIVE_CALLBACK_KEYS:
|
||||
return True
|
||||
return _CALLBACK_VAR_MASKER.is_sensitive_key(key)
|
||||
|
||||
|
||||
def _encrypt_if_plaintext(key: str, value: Any) -> Any:
|
||||
if not isinstance(value, str) or not value:
|
||||
return value
|
||||
if not _is_sensitive_callback_var(key):
|
||||
return value
|
||||
if value.startswith(_CALLBACK_VAR_ENCRYPTED_PREFIX):
|
||||
# Already encrypted — round-tripping ciphertext (e.g. UI Edit Settings
|
||||
# save without changing the field) must not double-encrypt. Cheap
|
||||
# prefix check is robust under salt-key rotation; a decrypt-based
|
||||
# idempotency check would mis-classify K1-encrypted blobs as
|
||||
# plaintext under K2 and wrap them a second time.
|
||||
return value
|
||||
try:
|
||||
return _CALLBACK_VAR_ENCRYPTED_PREFIX + encrypt_value_helper(value)
|
||||
except Exception:
|
||||
# No salt key / master key configured — leave the value as-is rather
|
||||
# than crash the write. Dev environments without LITELLM_SALT_KEY hit
|
||||
# this path; production always has a master key so encryption proceeds.
|
||||
return value
|
||||
|
||||
|
||||
def _decrypt_or_passthrough(key: str, value: Any) -> Any:
|
||||
if not isinstance(value, str) or not value:
|
||||
return value
|
||||
if not value.startswith(_CALLBACK_VAR_ENCRYPTED_PREFIX):
|
||||
# Legacy plaintext rows or non-credential fields — return as-is.
|
||||
return value
|
||||
inner = value[len(_CALLBACK_VAR_ENCRYPTED_PREFIX) :]
|
||||
decrypted = decrypt_value_helper(
|
||||
value=inner, key=key, exception_type="debug", return_original_value=False
|
||||
)
|
||||
return decrypted if decrypted is not None else value
|
||||
|
||||
@@ -26,6 +26,7 @@ from litellm.proxy._types import (
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
decrypt_callback_vars,
|
||||
get_metadata_variable_name_from_kwargs,
|
||||
)
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _safe_get_request_headers
|
||||
@@ -477,7 +478,7 @@ class KeyAndTeamLoggingSettings:
|
||||
user_api_key_dict.metadata is not None
|
||||
and "logging" in user_api_key_dict.metadata
|
||||
):
|
||||
return user_api_key_dict.metadata["logging"]
|
||||
return decrypt_callback_vars(user_api_key_dict.metadata).get("logging")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
@@ -486,7 +487,7 @@ class KeyAndTeamLoggingSettings:
|
||||
user_api_key_dict.team_metadata is not None
|
||||
and "logging" in user_api_key_dict.team_metadata
|
||||
):
|
||||
return user_api_key_dict.team_metadata["logging"]
|
||||
return decrypt_callback_vars(user_api_key_dict.team_metadata).get("logging")
|
||||
return None
|
||||
|
||||
|
||||
@@ -540,7 +541,7 @@ def _get_dynamic_logging_metadata(
|
||||
}
|
||||
}
|
||||
"""
|
||||
team_metadata = user_api_key_dict.team_metadata
|
||||
team_metadata = decrypt_callback_vars(user_api_key_dict.team_metadata)
|
||||
callback_settings = team_metadata.get("callback_settings", None) or {}
|
||||
callback_settings_obj = TeamCallbackMetadata(**callback_settings)
|
||||
verbose_proxy_logger.debug(
|
||||
|
||||
@@ -51,6 +51,10 @@ from litellm.proxy.auth.auth_checks import (
|
||||
)
|
||||
from litellm.proxy.auth.auth_utils import abbreviate_api_key
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
decrypt_callback_vars,
|
||||
encrypt_callback_vars,
|
||||
)
|
||||
from litellm.proxy.common_utils.rbac_utils import check_org_admin_can_generate_keys
|
||||
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
|
||||
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
|
||||
@@ -1752,7 +1756,7 @@ def prepare_metadata_fields(
|
||||
)
|
||||
)
|
||||
|
||||
non_default_values["metadata"] = casted_metadata
|
||||
non_default_values["metadata"] = encrypt_callback_vars(casted_metadata)
|
||||
return non_default_values
|
||||
|
||||
|
||||
@@ -3459,6 +3463,7 @@ async def generate_key_helper_fn( # noqa: PLR0915
|
||||
metadata = metadata or {}
|
||||
metadata["prompts"] = prompts
|
||||
|
||||
metadata = encrypt_callback_vars(metadata)
|
||||
metadata_json = json.dumps(metadata)
|
||||
validate_model_max_budget(model_max_budget)
|
||||
model_max_budget_json = json.dumps(model_max_budget)
|
||||
@@ -5942,7 +5947,7 @@ async def key_health(
|
||||
logging_statuses = await test_key_logging(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request=request,
|
||||
key_logging=key_metadata["logging"],
|
||||
key_logging=decrypt_callback_vars(key_metadata)["logging"],
|
||||
)
|
||||
health_status["logging_callbacks"] = logging_statuses
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from litellm.proxy._types import (
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.callback_utils import encrypt_callback_vars
|
||||
from litellm.proxy.management_endpoints.team_endpoints import _verify_team_access
|
||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||
|
||||
@@ -245,6 +246,7 @@ async def add_team_callbacks(
|
||||
team_callback_settings.append(data.model_dump())
|
||||
|
||||
team_metadata["logging"] = team_callback_settings
|
||||
team_metadata = encrypt_callback_vars(team_metadata)
|
||||
team_metadata_json = json.dumps(team_metadata) # update team_metadata
|
||||
|
||||
new_team_row = await prisma_client.db.litellm_teamtable.update(
|
||||
@@ -347,6 +349,7 @@ async def disable_team_logging(
|
||||
|
||||
# Update metadata
|
||||
team_metadata["callback_settings"] = team_callback_settings_obj.model_dump()
|
||||
team_metadata = encrypt_callback_vars(team_metadata)
|
||||
team_metadata_json = json.dumps(team_metadata)
|
||||
|
||||
# Update team in database
|
||||
|
||||
@@ -72,6 +72,7 @@ from litellm.proxy.auth.auth_checks import (
|
||||
get_user_object,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.callback_utils import encrypt_callback_vars
|
||||
from litellm.proxy.management_endpoints.common_utils import (
|
||||
_check_passthrough_routes_caller_permission,
|
||||
_is_user_org_admin_for_team,
|
||||
@@ -1156,6 +1157,11 @@ async def new_team( # noqa: PLR0915
|
||||
)
|
||||
complete_team_data_dict["router_settings"] = router_settings_json
|
||||
|
||||
if complete_team_data_dict.get("metadata") is not None:
|
||||
complete_team_data_dict["metadata"] = encrypt_callback_vars(
|
||||
complete_team_data_dict["metadata"]
|
||||
)
|
||||
|
||||
complete_team_data_dict = prisma_client.jsonify_team_object(
|
||||
db_data=complete_team_data_dict
|
||||
)
|
||||
@@ -1828,6 +1834,9 @@ async def update_team( # noqa: PLR0915
|
||||
# update team metadata fields
|
||||
_update_metadata_fields(updated_kv=updated_kv)
|
||||
|
||||
if updated_kv.get("metadata") is not None:
|
||||
updated_kv["metadata"] = encrypt_callback_vars(updated_kv["metadata"])
|
||||
|
||||
if "model_aliases" in updated_kv:
|
||||
updated_kv.pop("model_aliases")
|
||||
_model_id = await _update_model_table(
|
||||
|
||||
@@ -416,9 +416,10 @@ def test_dynamic_turn_off_message_logging(callback_vars):
|
||||
)
|
||||
|
||||
assert callbacks is not None
|
||||
assert (
|
||||
callbacks.callback_vars["turn_off_message_logging"]
|
||||
== callback_vars["turn_off_message_logging"]
|
||||
# AddTeamCallback's validator stringifies callback_var values, so compare
|
||||
# against the str() of the input rather than the input bool directly.
|
||||
assert callbacks.callback_vars["turn_off_message_logging"] == str(
|
||||
callback_vars["turn_off_message_logging"]
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import sys
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
@@ -7,6 +8,8 @@ sys.path.insert(
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
decrypt_callback_vars,
|
||||
encrypt_callback_vars,
|
||||
initialize_callbacks_on_proxy,
|
||||
get_remaining_tokens_and_requests_from_request_data,
|
||||
normalize_callback_names,
|
||||
@@ -119,3 +122,143 @@ def test_initialize_callbacks_on_proxy_instantiates_compression_interception(
|
||||
assert "compression_interception" not in litellm.callbacks
|
||||
finally:
|
||||
litellm.callbacks = original_callbacks
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# encrypt_callback_vars / decrypt_callback_vars
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sample_metadata():
|
||||
return {
|
||||
"logging": [
|
||||
{
|
||||
"callback_name": "langfuse",
|
||||
"callback_type": "success_and_failure",
|
||||
"callback_vars": {
|
||||
"langfuse_public_key": "pk-lf-public",
|
||||
"langfuse_secret_key": "sk-lf-secret",
|
||||
"langfuse_host": "https://cloud.langfuse.com",
|
||||
},
|
||||
}
|
||||
],
|
||||
"callback_settings": {
|
||||
"callback_vars": {"langsmith_api_key": "ls-api-key"},
|
||||
},
|
||||
"tags": ["unrelated"],
|
||||
}
|
||||
|
||||
|
||||
def _set_salt_key(monkeypatch):
|
||||
monkeypatch.setenv("LITELLM_SALT_KEY", "test-salt-32-bytes-aaaaaaaaaaaaaa")
|
||||
|
||||
|
||||
def test_encrypt_callback_vars_round_trip(monkeypatch):
|
||||
_set_salt_key(monkeypatch)
|
||||
original = _sample_metadata()
|
||||
encrypted = encrypt_callback_vars(original)
|
||||
|
||||
enc_vars = encrypted["logging"][0]["callback_vars"]
|
||||
assert enc_vars["langfuse_secret_key"] != "sk-lf-secret"
|
||||
assert enc_vars["langfuse_public_key"] != "pk-lf-public"
|
||||
assert (
|
||||
encrypted["callback_settings"]["callback_vars"]["langsmith_api_key"]
|
||||
!= "ls-api-key"
|
||||
)
|
||||
|
||||
decrypted = decrypt_callback_vars(encrypted)
|
||||
assert (
|
||||
decrypted["logging"][0]["callback_vars"]
|
||||
== original["logging"][0]["callback_vars"]
|
||||
)
|
||||
assert (
|
||||
decrypted["callback_settings"]["callback_vars"]
|
||||
== original["callback_settings"]["callback_vars"]
|
||||
)
|
||||
|
||||
|
||||
def test_encrypt_callback_vars_is_idempotent(monkeypatch):
|
||||
_set_salt_key(monkeypatch)
|
||||
once = encrypt_callback_vars(_sample_metadata())
|
||||
twice = encrypt_callback_vars(once)
|
||||
assert once == twice
|
||||
|
||||
|
||||
def test_encrypt_callback_vars_does_not_mutate_input(monkeypatch):
|
||||
_set_salt_key(monkeypatch)
|
||||
original = _sample_metadata()
|
||||
snapshot = copy.deepcopy(original)
|
||||
encrypt_callback_vars(original)
|
||||
assert original == snapshot
|
||||
|
||||
|
||||
def test_decrypt_callback_vars_passes_through_legacy_plaintext(monkeypatch):
|
||||
_set_salt_key(monkeypatch)
|
||||
plaintext = _sample_metadata()
|
||||
decrypted = decrypt_callback_vars(plaintext)
|
||||
# legacy rows decrypt-fail and fall through unchanged
|
||||
assert (
|
||||
decrypted["logging"][0]["callback_vars"]["langfuse_secret_key"]
|
||||
== "sk-lf-secret"
|
||||
)
|
||||
|
||||
|
||||
def test_callback_vars_helpers_handle_edge_shapes(monkeypatch):
|
||||
_set_salt_key(monkeypatch)
|
||||
assert encrypt_callback_vars(None) is None
|
||||
assert encrypt_callback_vars({}) == {}
|
||||
assert decrypt_callback_vars(None) is None
|
||||
assert decrypt_callback_vars({}) == {}
|
||||
|
||||
# logging not a list / callback_vars not a dict — leave alone
|
||||
weird = {"logging": "not-a-list", "callback_settings": {"callback_vars": None}}
|
||||
assert encrypt_callback_vars(weird) == weird
|
||||
|
||||
# empty/None callback_vars values stay as-is
|
||||
has_blanks = {
|
||||
"logging": [
|
||||
{
|
||||
"callback_vars": {
|
||||
"langfuse_public_key": "",
|
||||
"langfuse_secret_key": None,
|
||||
"langfuse_host": "https://cloud.langfuse.com",
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
out = encrypt_callback_vars(has_blanks)
|
||||
cv = out["logging"][0]["callback_vars"]
|
||||
assert cv["langfuse_public_key"] == ""
|
||||
assert cv["langfuse_secret_key"] is None
|
||||
# langfuse_host is a routing field, not a credential — stays plain.
|
||||
assert cv["langfuse_host"] == "https://cloud.langfuse.com"
|
||||
|
||||
|
||||
def test_encrypt_callback_vars_only_encrypts_credential_fields(monkeypatch):
|
||||
"""Routing/identifier fields stay plaintext; credential fields encrypt."""
|
||||
_set_salt_key(monkeypatch)
|
||||
metadata = {
|
||||
"logging": [
|
||||
{
|
||||
"callback_vars": {
|
||||
"langfuse_secret_key": "sk-real",
|
||||
"langfuse_public_key": "pk-real",
|
||||
"langfuse_host": "https://cloud.langfuse.com",
|
||||
"langsmith_project": "my-proj",
|
||||
"langsmith_base_url": "https://smith.example",
|
||||
"gcs_path_service_account": "{json contents}",
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
cv = encrypt_callback_vars(metadata)["logging"][0]["callback_vars"]
|
||||
|
||||
# Sensitive (key-name segments match SensitiveDataMasker patterns):
|
||||
assert cv["langfuse_secret_key"] != "sk-real"
|
||||
assert cv["langfuse_public_key"] != "pk-real"
|
||||
# Sensitive via the explicit gcs override:
|
||||
assert cv["gcs_path_service_account"] != "{json contents}"
|
||||
# Routing / identifiers stay plaintext:
|
||||
assert cv["langfuse_host"] == "https://cloud.langfuse.com"
|
||||
assert cv["langsmith_project"] == "my-proj"
|
||||
assert cv["langsmith_base_url"] == "https://smith.example"
|
||||
|
||||
@@ -1336,6 +1336,39 @@ async def test_update_without_metadata_still_preserves_existing():
|
||||
assert result["metadata"]["other"] == "kept"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_key_update_data_encrypts_callback_vars(monkeypatch):
|
||||
"""/key/update must encrypt callback_vars values before they reach the DB."""
|
||||
from litellm.proxy.common_utils.callback_utils import decrypt_callback_vars
|
||||
|
||||
monkeypatch.setenv("LITELLM_SALT_KEY", "test-salt-32-bytes-aaaaaaaaaaaaaa")
|
||||
data = UpdateKeyRequest(
|
||||
key="sk-1",
|
||||
metadata={
|
||||
"logging": [
|
||||
{
|
||||
"callback_name": "langfuse",
|
||||
"callback_type": "success",
|
||||
"callback_vars": {
|
||||
"langfuse_public_key": "pk-real",
|
||||
"langfuse_secret_key": "sk-real",
|
||||
},
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
existing_key = LiteLLM_VerificationToken(token="hashed")
|
||||
|
||||
result = await prepare_key_update_data(data=data, existing_key_row=existing_key)
|
||||
|
||||
cv = result["metadata"]["logging"][0]["callback_vars"]
|
||||
assert cv["langfuse_secret_key"] != "sk-real"
|
||||
assert cv["langfuse_public_key"] != "pk-real"
|
||||
recovered = decrypt_callback_vars(result["metadata"])["logging"][0]["callback_vars"]
|
||||
assert recovered["langfuse_secret_key"] == "sk-real"
|
||||
assert recovered["langfuse_public_key"] == "pk-real"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_key_update_data_duration_never_expires():
|
||||
"""Test that duration="-1" sets expires to None (never expires)."""
|
||||
|
||||
@@ -420,3 +420,42 @@ async def test_add_team_callbacks_no_audit_when_disabled(monkeypatch):
|
||||
)
|
||||
|
||||
assert audit_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_team_callbacks_writes_encrypted_callback_vars(monkeypatch):
|
||||
"""add_team_callbacks must encrypt callback_vars values before the DB write."""
|
||||
from litellm.proxy.common_utils.callback_utils import decrypt_callback_vars
|
||||
|
||||
monkeypatch.setenv("LITELLM_SALT_KEY", "test-salt-32-bytes-aaaaaaaaaaaaaa")
|
||||
mock_prisma = _patch_prisma(_team_row(team_id="team-1", metadata={"logging": []}))
|
||||
|
||||
with (
|
||||
patch("litellm.proxy.proxy_server.prisma_client", mock_prisma),
|
||||
patch("litellm.proxy.proxy_server.litellm_proxy_admin_name", "admin"),
|
||||
patch("litellm.proxy.proxy_server.master_key", None),
|
||||
):
|
||||
await add_team_callbacks(
|
||||
data=AddTeamCallback(
|
||||
callback_name="langfuse",
|
||||
callback_type="success",
|
||||
callback_vars={
|
||||
"langfuse_public_key": "pk-lf-real-public",
|
||||
"langfuse_secret_key": "sk-lf-real-secret",
|
||||
},
|
||||
),
|
||||
http_request=MagicMock(spec=Request),
|
||||
team_id="team-1",
|
||||
user_api_key_dict=_admin_auth(),
|
||||
litellm_changed_by=None,
|
||||
)
|
||||
|
||||
written = json.loads(
|
||||
mock_prisma.db.litellm_teamtable.update.await_args.kwargs["data"]["metadata"]
|
||||
)
|
||||
cv = written["logging"][0]["callback_vars"]
|
||||
assert cv["langfuse_secret_key"] != "sk-lf-real-secret"
|
||||
assert cv["langfuse_public_key"] != "pk-lf-real-public"
|
||||
recovered = decrypt_callback_vars(written)["logging"][0]["callback_vars"]
|
||||
assert recovered["langfuse_secret_key"] == "sk-lf-real-secret"
|
||||
assert recovered["langfuse_public_key"] == "pk-lf-real-public"
|
||||
|
||||
@@ -7945,6 +7945,71 @@ async def test_team_member_me_returns_404_for_unknown_team(mock_db_client):
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_team_encrypts_callback_vars(
|
||||
mock_db_client, mock_admin_auth, monkeypatch
|
||||
):
|
||||
"""/team/new must encrypt callback_vars values before they reach the DB."""
|
||||
from fastapi import Request
|
||||
|
||||
from litellm.proxy._types import NewTeamRequest
|
||||
from litellm.proxy.common_utils.callback_utils import decrypt_callback_vars
|
||||
from litellm.proxy.management_endpoints.team_endpoints import new_team
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
monkeypatch.setenv("LITELLM_SALT_KEY", "test-salt-32-bytes-aaaaaaaaaaaaaa")
|
||||
|
||||
# Use the real jsonify helpers so the encrypted dict goes through the
|
||||
# actual JSON serialization production uses (catches non-serializable
|
||||
# ciphertext, missing fields, etc.).
|
||||
mock_db_client.jsonify_object = PrismaClient.jsonify_object.__get__(mock_db_client)
|
||||
mock_db_client.jsonify_team_object = PrismaClient.jsonify_team_object.__get__(
|
||||
mock_db_client
|
||||
)
|
||||
mock_db_client.get_data = AsyncMock(return_value=None)
|
||||
mock_db_client.db = MagicMock()
|
||||
mock_db_client.db.litellm_teamtable = MagicMock()
|
||||
team_create_result = MagicMock(team_id="team-456", object_permission_id=None)
|
||||
team_create_result.model_dump.return_value = {"team_id": "team-456"}
|
||||
mock_team_create = AsyncMock(return_value=team_create_result)
|
||||
mock_db_client.db.litellm_teamtable.create = mock_team_create
|
||||
mock_db_client.db.litellm_teamtable.count = AsyncMock(return_value=0)
|
||||
mock_db_client.db.litellm_teamtable.update = AsyncMock(
|
||||
return_value=team_create_result
|
||||
)
|
||||
mock_db_client.db.litellm_usertable = MagicMock()
|
||||
mock_db_client.db.litellm_usertable.update = AsyncMock(return_value=MagicMock())
|
||||
|
||||
team_request = NewTeamRequest(
|
||||
team_alias="my-team",
|
||||
metadata={
|
||||
"logging": [
|
||||
{
|
||||
"callback_name": "langfuse",
|
||||
"callback_type": "success",
|
||||
"callback_vars": {
|
||||
"langfuse_public_key": "pk-real",
|
||||
"langfuse_secret_key": "sk-real",
|
||||
},
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
await new_team(
|
||||
data=team_request,
|
||||
http_request=MagicMock(spec=Request),
|
||||
user_api_key_dict=mock_admin_auth,
|
||||
)
|
||||
|
||||
written = mock_team_create.call_args.kwargs["data"]
|
||||
# jsonify_team_object serializes the metadata dict to a JSON string before
|
||||
# the DB write, so we round-trip through json.loads to inspect it.
|
||||
metadata = json.loads(written["metadata"])
|
||||
cv = metadata["logging"][0]["callback_vars"]
|
||||
assert cv["langfuse_secret_key"] != "sk-real"
|
||||
recovered = decrypt_callback_vars(metadata)["logging"][0]["callback_vars"]
|
||||
assert recovered["langfuse_secret_key"] == "sk-real"
|
||||
def _non_admin_auth():
|
||||
return UserAPIKeyAuth(
|
||||
user_id="u-team-admin", user_role=LitellmUserRoles.INTERNAL_USER
|
||||
|
||||
@@ -1549,6 +1549,57 @@ def test_team_dynamic_logging_settings():
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_key_dynamic_logging_settings_decrypts_callback_vars(monkeypatch):
|
||||
"""Encrypted callback_vars on the key are decrypted before downstream use."""
|
||||
from litellm.proxy.common_utils.callback_utils import encrypt_callback_vars
|
||||
|
||||
monkeypatch.setenv("LITELLM_SALT_KEY", "test-salt-32-bytes-aaaaaaaaaaaaaa")
|
||||
encrypted_metadata = encrypt_callback_vars(
|
||||
{
|
||||
"logging": [
|
||||
{
|
||||
"callback_name": "langfuse",
|
||||
"callback_type": "success",
|
||||
"callback_vars": {
|
||||
"langfuse_public_key": "pk-real",
|
||||
"langfuse_secret_key": "sk-real",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
cv_on_disk = encrypted_metadata["logging"][0]["callback_vars"]
|
||||
assert cv_on_disk["langfuse_secret_key"] != "sk-real" # sanity: stored encrypted
|
||||
|
||||
key = UserAPIKeyAuth(api_key="t", metadata=encrypted_metadata, team_metadata={})
|
||||
result = KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(key)
|
||||
cv = result[0]["callback_vars"]
|
||||
assert cv["langfuse_secret_key"] == "sk-real"
|
||||
assert cv["langfuse_public_key"] == "pk-real"
|
||||
|
||||
|
||||
def test_team_dynamic_logging_settings_decrypts_callback_vars(monkeypatch):
|
||||
"""Encrypted callback_vars on the team are decrypted before downstream use."""
|
||||
from litellm.proxy.common_utils.callback_utils import encrypt_callback_vars
|
||||
|
||||
monkeypatch.setenv("LITELLM_SALT_KEY", "test-salt-32-bytes-aaaaaaaaaaaaaa")
|
||||
encrypted_team = encrypt_callback_vars(
|
||||
{
|
||||
"logging": [
|
||||
{
|
||||
"callback_name": "langfuse",
|
||||
"callback_type": "failure",
|
||||
"callback_vars": {"langfuse_secret_key": "team-sk-real"},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
key = UserAPIKeyAuth(api_key="t", metadata={}, team_metadata=encrypted_team)
|
||||
result = KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(key)
|
||||
assert result[0]["callback_vars"]["langfuse_secret_key"] == "team-sk-real"
|
||||
|
||||
|
||||
def test_get_dynamic_logging_metadata_with_arize_team_logging():
|
||||
"""
|
||||
Test _get_dynamic_logging_metadata function with arize team logging and dynamic parameters
|
||||
|
||||
Reference in New Issue
Block a user