encrypt callback_vars in key/team metadata at rest (#27141)

Co-authored-by: Michael Riad Zaky <michaelr@Michaels-MacBook-Air.local>
Co-authored-by: Yuneng Jiang <yuneng@berri.ai>
This commit is contained in:
Michael-RZ-Berri
2026-05-23 12:15:44 -07:00
committed by GitHub
parent 492891cad8
commit 3b2ce201d8
11 changed files with 456 additions and 9 deletions
+98 -1
View File
@@ -1,16 +1,31 @@
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional
import copy
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional
import litellm
from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
from litellm.proxy._types import CommonProxyErrors, LiteLLMPromptInjectionParams
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
decrypt_value_helper,
encrypt_value_helper,
)
from litellm.proxy.types_utils.utils import get_instance_fn
from litellm.types.utils import (
StandardLoggingGuardrailInformation,
StandardLoggingPayload,
)
_CALLBACK_VAR_MASKER = SensitiveDataMasker()
# Compound names that are credential-bearing but don't contain any of the
# default sensitive segments (so SensitiveDataMasker won't flag them).
_EXTRA_SENSITIVE_CALLBACK_KEYS = {"gcs_path_service_account"}
# Sentinel prefix on encrypted callback_var values. Lets us detect
# already-encrypted input cheaply (no decrypt-attempt round trip) and
# avoid double-encrypting if `LITELLM_SALT_KEY` is rotated between writes.
_CALLBACK_VAR_ENCRYPTED_PREFIX = "litellm_enc::"
blue_color_code = "\033[94m"
reset_color_code = "\033[0m"
@@ -550,3 +565,85 @@ def normalize_callback_names(callbacks: Iterable[Any]) -> List[Any]:
if callbacks is None:
return []
return [c.lower() if isinstance(c, str) else c for c in callbacks]
def encrypt_callback_vars(metadata: Any) -> Any:
"""Return a deep copy of metadata with callback_vars values encrypted at rest.
Idempotent: a value that already decrypts cleanly is left unchanged so
round-trips through edit forms don't double-encrypt.
"""
return _transform_callback_vars(metadata, _encrypt_if_plaintext)
def decrypt_callback_vars(metadata: Any) -> Any:
"""Return a deep copy of metadata with callback_vars values decrypted.
Legacy plaintext rows pass through unchanged (decrypt failure → original).
"""
return _transform_callback_vars(metadata, _decrypt_or_passthrough)
def _transform_callback_vars(
metadata: Any, transform: Callable[[str, Any], Any]
) -> Any:
if not isinstance(metadata, dict):
return metadata
out = copy.deepcopy(metadata)
logging_entries = out.get("logging")
if isinstance(logging_entries, list):
for entry in logging_entries:
if isinstance(entry, dict) and isinstance(entry.get("callback_vars"), dict):
entry["callback_vars"] = {
k: transform(k, v) for k, v in entry["callback_vars"].items()
}
callback_settings = out.get("callback_settings")
if isinstance(callback_settings, dict) and isinstance(
callback_settings.get("callback_vars"), dict
):
callback_settings["callback_vars"] = {
k: transform(k, v) for k, v in callback_settings["callback_vars"].items()
}
return out
def _is_sensitive_callback_var(key: str) -> bool:
"""Match codebase precedent: only credential-bearing fields get encrypted;
routing/identifier fields (host, base_url, project, region) stay plain."""
if key in _EXTRA_SENSITIVE_CALLBACK_KEYS:
return True
return _CALLBACK_VAR_MASKER.is_sensitive_key(key)
def _encrypt_if_plaintext(key: str, value: Any) -> Any:
if not isinstance(value, str) or not value:
return value
if not _is_sensitive_callback_var(key):
return value
if value.startswith(_CALLBACK_VAR_ENCRYPTED_PREFIX):
# Already encrypted — round-tripping ciphertext (e.g. UI Edit Settings
# save without changing the field) must not double-encrypt. Cheap
# prefix check is robust under salt-key rotation; a decrypt-based
# idempotency check would mis-classify K1-encrypted blobs as
# plaintext under K2 and wrap them a second time.
return value
try:
return _CALLBACK_VAR_ENCRYPTED_PREFIX + encrypt_value_helper(value)
except Exception:
# No salt key / master key configured — leave the value as-is rather
# than crash the write. Dev environments without LITELLM_SALT_KEY hit
# this path; production always has a master key so encryption proceeds.
return value
def _decrypt_or_passthrough(key: str, value: Any) -> Any:
if not isinstance(value, str) or not value:
return value
if not value.startswith(_CALLBACK_VAR_ENCRYPTED_PREFIX):
# Legacy plaintext rows or non-credential fields — return as-is.
return value
inner = value[len(_CALLBACK_VAR_ENCRYPTED_PREFIX) :]
decrypted = decrypt_value_helper(
value=inner, key=key, exception_type="debug", return_original_value=False
)
return decrypted if decrypted is not None else value
+4 -3
View File
@@ -26,6 +26,7 @@ from litellm.proxy._types import (
UserAPIKeyAuth,
)
from litellm.proxy.common_utils.callback_utils import (
decrypt_callback_vars,
get_metadata_variable_name_from_kwargs,
)
from litellm.proxy.common_utils.http_parsing_utils import _safe_get_request_headers
@@ -477,7 +478,7 @@ class KeyAndTeamLoggingSettings:
user_api_key_dict.metadata is not None
and "logging" in user_api_key_dict.metadata
):
return user_api_key_dict.metadata["logging"]
return decrypt_callback_vars(user_api_key_dict.metadata).get("logging")
return None
@staticmethod
@@ -486,7 +487,7 @@ class KeyAndTeamLoggingSettings:
user_api_key_dict.team_metadata is not None
and "logging" in user_api_key_dict.team_metadata
):
return user_api_key_dict.team_metadata["logging"]
return decrypt_callback_vars(user_api_key_dict.team_metadata).get("logging")
return None
@@ -540,7 +541,7 @@ def _get_dynamic_logging_metadata(
}
}
"""
team_metadata = user_api_key_dict.team_metadata
team_metadata = decrypt_callback_vars(user_api_key_dict.team_metadata)
callback_settings = team_metadata.get("callback_settings", None) or {}
callback_settings_obj = TeamCallbackMetadata(**callback_settings)
verbose_proxy_logger.debug(
@@ -51,6 +51,10 @@ from litellm.proxy.auth.auth_checks import (
)
from litellm.proxy.auth.auth_utils import abbreviate_api_key
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.callback_utils import (
decrypt_callback_vars,
encrypt_callback_vars,
)
from litellm.proxy.common_utils.rbac_utils import check_org_admin_can_generate_keys
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
@@ -1752,7 +1756,7 @@ def prepare_metadata_fields(
)
)
non_default_values["metadata"] = casted_metadata
non_default_values["metadata"] = encrypt_callback_vars(casted_metadata)
return non_default_values
@@ -3459,6 +3463,7 @@ async def generate_key_helper_fn( # noqa: PLR0915
metadata = metadata or {}
metadata["prompts"] = prompts
metadata = encrypt_callback_vars(metadata)
metadata_json = json.dumps(metadata)
validate_model_max_budget(model_max_budget)
model_max_budget_json = json.dumps(model_max_budget)
@@ -5942,7 +5947,7 @@ async def key_health(
logging_statuses = await test_key_logging(
user_api_key_dict=user_api_key_dict,
request=request,
key_logging=key_metadata["logging"],
key_logging=decrypt_callback_vars(key_metadata)["logging"],
)
health_status["logging_callbacks"] = logging_statuses
@@ -27,6 +27,7 @@ from litellm.proxy._types import (
UserAPIKeyAuth,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.callback_utils import encrypt_callback_vars
from litellm.proxy.management_endpoints.team_endpoints import _verify_team_access
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
@@ -245,6 +246,7 @@ async def add_team_callbacks(
team_callback_settings.append(data.model_dump())
team_metadata["logging"] = team_callback_settings
team_metadata = encrypt_callback_vars(team_metadata)
team_metadata_json = json.dumps(team_metadata) # update team_metadata
new_team_row = await prisma_client.db.litellm_teamtable.update(
@@ -347,6 +349,7 @@ async def disable_team_logging(
# Update metadata
team_metadata["callback_settings"] = team_callback_settings_obj.model_dump()
team_metadata = encrypt_callback_vars(team_metadata)
team_metadata_json = json.dumps(team_metadata)
# Update team in database
@@ -72,6 +72,7 @@ from litellm.proxy.auth.auth_checks import (
get_user_object,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.callback_utils import encrypt_callback_vars
from litellm.proxy.management_endpoints.common_utils import (
_check_passthrough_routes_caller_permission,
_is_user_org_admin_for_team,
@@ -1156,6 +1157,11 @@ async def new_team( # noqa: PLR0915
)
complete_team_data_dict["router_settings"] = router_settings_json
if complete_team_data_dict.get("metadata") is not None:
complete_team_data_dict["metadata"] = encrypt_callback_vars(
complete_team_data_dict["metadata"]
)
complete_team_data_dict = prisma_client.jsonify_team_object(
db_data=complete_team_data_dict
)
@@ -1828,6 +1834,9 @@ async def update_team( # noqa: PLR0915
# update team metadata fields
_update_metadata_fields(updated_kv=updated_kv)
if updated_kv.get("metadata") is not None:
updated_kv["metadata"] = encrypt_callback_vars(updated_kv["metadata"])
if "model_aliases" in updated_kv:
updated_kv.pop("model_aliases")
_model_id = await _update_model_table(
+4 -3
View File
@@ -416,9 +416,10 @@ def test_dynamic_turn_off_message_logging(callback_vars):
)
assert callbacks is not None
assert (
callbacks.callback_vars["turn_off_message_logging"]
== callback_vars["turn_off_message_logging"]
# AddTeamCallback's validator stringifies callback_var values, so compare
# against the str() of the input rather than the input bool directly.
assert callbacks.callback_vars["turn_off_message_logging"] == str(
callback_vars["turn_off_message_logging"]
)
@@ -1,3 +1,4 @@
import copy
import sys
import os
from types import SimpleNamespace
@@ -7,6 +8,8 @@ sys.path.insert(
) # Adds the parent directory to the system path
from litellm.proxy.common_utils.callback_utils import (
decrypt_callback_vars,
encrypt_callback_vars,
initialize_callbacks_on_proxy,
get_remaining_tokens_and_requests_from_request_data,
normalize_callback_names,
@@ -119,3 +122,143 @@ def test_initialize_callbacks_on_proxy_instantiates_compression_interception(
assert "compression_interception" not in litellm.callbacks
finally:
litellm.callbacks = original_callbacks
# ---------------------------------------------------------------------------
# encrypt_callback_vars / decrypt_callback_vars
# ---------------------------------------------------------------------------
def _sample_metadata():
return {
"logging": [
{
"callback_name": "langfuse",
"callback_type": "success_and_failure",
"callback_vars": {
"langfuse_public_key": "pk-lf-public",
"langfuse_secret_key": "sk-lf-secret",
"langfuse_host": "https://cloud.langfuse.com",
},
}
],
"callback_settings": {
"callback_vars": {"langsmith_api_key": "ls-api-key"},
},
"tags": ["unrelated"],
}
def _set_salt_key(monkeypatch):
monkeypatch.setenv("LITELLM_SALT_KEY", "test-salt-32-bytes-aaaaaaaaaaaaaa")
def test_encrypt_callback_vars_round_trip(monkeypatch):
_set_salt_key(monkeypatch)
original = _sample_metadata()
encrypted = encrypt_callback_vars(original)
enc_vars = encrypted["logging"][0]["callback_vars"]
assert enc_vars["langfuse_secret_key"] != "sk-lf-secret"
assert enc_vars["langfuse_public_key"] != "pk-lf-public"
assert (
encrypted["callback_settings"]["callback_vars"]["langsmith_api_key"]
!= "ls-api-key"
)
decrypted = decrypt_callback_vars(encrypted)
assert (
decrypted["logging"][0]["callback_vars"]
== original["logging"][0]["callback_vars"]
)
assert (
decrypted["callback_settings"]["callback_vars"]
== original["callback_settings"]["callback_vars"]
)
def test_encrypt_callback_vars_is_idempotent(monkeypatch):
_set_salt_key(monkeypatch)
once = encrypt_callback_vars(_sample_metadata())
twice = encrypt_callback_vars(once)
assert once == twice
def test_encrypt_callback_vars_does_not_mutate_input(monkeypatch):
_set_salt_key(monkeypatch)
original = _sample_metadata()
snapshot = copy.deepcopy(original)
encrypt_callback_vars(original)
assert original == snapshot
def test_decrypt_callback_vars_passes_through_legacy_plaintext(monkeypatch):
_set_salt_key(monkeypatch)
plaintext = _sample_metadata()
decrypted = decrypt_callback_vars(plaintext)
# legacy rows decrypt-fail and fall through unchanged
assert (
decrypted["logging"][0]["callback_vars"]["langfuse_secret_key"]
== "sk-lf-secret"
)
def test_callback_vars_helpers_handle_edge_shapes(monkeypatch):
_set_salt_key(monkeypatch)
assert encrypt_callback_vars(None) is None
assert encrypt_callback_vars({}) == {}
assert decrypt_callback_vars(None) is None
assert decrypt_callback_vars({}) == {}
# logging not a list / callback_vars not a dict — leave alone
weird = {"logging": "not-a-list", "callback_settings": {"callback_vars": None}}
assert encrypt_callback_vars(weird) == weird
# empty/None callback_vars values stay as-is
has_blanks = {
"logging": [
{
"callback_vars": {
"langfuse_public_key": "",
"langfuse_secret_key": None,
"langfuse_host": "https://cloud.langfuse.com",
}
}
]
}
out = encrypt_callback_vars(has_blanks)
cv = out["logging"][0]["callback_vars"]
assert cv["langfuse_public_key"] == ""
assert cv["langfuse_secret_key"] is None
# langfuse_host is a routing field, not a credential — stays plain.
assert cv["langfuse_host"] == "https://cloud.langfuse.com"
def test_encrypt_callback_vars_only_encrypts_credential_fields(monkeypatch):
"""Routing/identifier fields stay plaintext; credential fields encrypt."""
_set_salt_key(monkeypatch)
metadata = {
"logging": [
{
"callback_vars": {
"langfuse_secret_key": "sk-real",
"langfuse_public_key": "pk-real",
"langfuse_host": "https://cloud.langfuse.com",
"langsmith_project": "my-proj",
"langsmith_base_url": "https://smith.example",
"gcs_path_service_account": "{json contents}",
}
}
]
}
cv = encrypt_callback_vars(metadata)["logging"][0]["callback_vars"]
# Sensitive (key-name segments match SensitiveDataMasker patterns):
assert cv["langfuse_secret_key"] != "sk-real"
assert cv["langfuse_public_key"] != "pk-real"
# Sensitive via the explicit gcs override:
assert cv["gcs_path_service_account"] != "{json contents}"
# Routing / identifiers stay plaintext:
assert cv["langfuse_host"] == "https://cloud.langfuse.com"
assert cv["langsmith_project"] == "my-proj"
assert cv["langsmith_base_url"] == "https://smith.example"
@@ -1336,6 +1336,39 @@ async def test_update_without_metadata_still_preserves_existing():
assert result["metadata"]["other"] == "kept"
@pytest.mark.asyncio
async def test_prepare_key_update_data_encrypts_callback_vars(monkeypatch):
"""/key/update must encrypt callback_vars values before they reach the DB."""
from litellm.proxy.common_utils.callback_utils import decrypt_callback_vars
monkeypatch.setenv("LITELLM_SALT_KEY", "test-salt-32-bytes-aaaaaaaaaaaaaa")
data = UpdateKeyRequest(
key="sk-1",
metadata={
"logging": [
{
"callback_name": "langfuse",
"callback_type": "success",
"callback_vars": {
"langfuse_public_key": "pk-real",
"langfuse_secret_key": "sk-real",
},
}
]
},
)
existing_key = LiteLLM_VerificationToken(token="hashed")
result = await prepare_key_update_data(data=data, existing_key_row=existing_key)
cv = result["metadata"]["logging"][0]["callback_vars"]
assert cv["langfuse_secret_key"] != "sk-real"
assert cv["langfuse_public_key"] != "pk-real"
recovered = decrypt_callback_vars(result["metadata"])["logging"][0]["callback_vars"]
assert recovered["langfuse_secret_key"] == "sk-real"
assert recovered["langfuse_public_key"] == "pk-real"
@pytest.mark.asyncio
async def test_prepare_key_update_data_duration_never_expires():
"""Test that duration="-1" sets expires to None (never expires)."""
@@ -420,3 +420,42 @@ async def test_add_team_callbacks_no_audit_when_disabled(monkeypatch):
)
assert audit_calls == []
@pytest.mark.asyncio
async def test_add_team_callbacks_writes_encrypted_callback_vars(monkeypatch):
"""add_team_callbacks must encrypt callback_vars values before the DB write."""
from litellm.proxy.common_utils.callback_utils import decrypt_callback_vars
monkeypatch.setenv("LITELLM_SALT_KEY", "test-salt-32-bytes-aaaaaaaaaaaaaa")
mock_prisma = _patch_prisma(_team_row(team_id="team-1", metadata={"logging": []}))
with (
patch("litellm.proxy.proxy_server.prisma_client", mock_prisma),
patch("litellm.proxy.proxy_server.litellm_proxy_admin_name", "admin"),
patch("litellm.proxy.proxy_server.master_key", None),
):
await add_team_callbacks(
data=AddTeamCallback(
callback_name="langfuse",
callback_type="success",
callback_vars={
"langfuse_public_key": "pk-lf-real-public",
"langfuse_secret_key": "sk-lf-real-secret",
},
),
http_request=MagicMock(spec=Request),
team_id="team-1",
user_api_key_dict=_admin_auth(),
litellm_changed_by=None,
)
written = json.loads(
mock_prisma.db.litellm_teamtable.update.await_args.kwargs["data"]["metadata"]
)
cv = written["logging"][0]["callback_vars"]
assert cv["langfuse_secret_key"] != "sk-lf-real-secret"
assert cv["langfuse_public_key"] != "pk-lf-real-public"
recovered = decrypt_callback_vars(written)["logging"][0]["callback_vars"]
assert recovered["langfuse_secret_key"] == "sk-lf-real-secret"
assert recovered["langfuse_public_key"] == "pk-lf-real-public"
@@ -7945,6 +7945,71 @@ async def test_team_member_me_returns_404_for_unknown_team(mock_db_client):
assert exc_info.value.status_code == 404
@pytest.mark.asyncio
async def test_new_team_encrypts_callback_vars(
mock_db_client, mock_admin_auth, monkeypatch
):
"""/team/new must encrypt callback_vars values before they reach the DB."""
from fastapi import Request
from litellm.proxy._types import NewTeamRequest
from litellm.proxy.common_utils.callback_utils import decrypt_callback_vars
from litellm.proxy.management_endpoints.team_endpoints import new_team
from litellm.proxy.utils import PrismaClient
monkeypatch.setenv("LITELLM_SALT_KEY", "test-salt-32-bytes-aaaaaaaaaaaaaa")
# Use the real jsonify helpers so the encrypted dict goes through the
# actual JSON serialization production uses (catches non-serializable
# ciphertext, missing fields, etc.).
mock_db_client.jsonify_object = PrismaClient.jsonify_object.__get__(mock_db_client)
mock_db_client.jsonify_team_object = PrismaClient.jsonify_team_object.__get__(
mock_db_client
)
mock_db_client.get_data = AsyncMock(return_value=None)
mock_db_client.db = MagicMock()
mock_db_client.db.litellm_teamtable = MagicMock()
team_create_result = MagicMock(team_id="team-456", object_permission_id=None)
team_create_result.model_dump.return_value = {"team_id": "team-456"}
mock_team_create = AsyncMock(return_value=team_create_result)
mock_db_client.db.litellm_teamtable.create = mock_team_create
mock_db_client.db.litellm_teamtable.count = AsyncMock(return_value=0)
mock_db_client.db.litellm_teamtable.update = AsyncMock(
return_value=team_create_result
)
mock_db_client.db.litellm_usertable = MagicMock()
mock_db_client.db.litellm_usertable.update = AsyncMock(return_value=MagicMock())
team_request = NewTeamRequest(
team_alias="my-team",
metadata={
"logging": [
{
"callback_name": "langfuse",
"callback_type": "success",
"callback_vars": {
"langfuse_public_key": "pk-real",
"langfuse_secret_key": "sk-real",
},
}
]
},
)
await new_team(
data=team_request,
http_request=MagicMock(spec=Request),
user_api_key_dict=mock_admin_auth,
)
written = mock_team_create.call_args.kwargs["data"]
# jsonify_team_object serializes the metadata dict to a JSON string before
# the DB write, so we round-trip through json.loads to inspect it.
metadata = json.loads(written["metadata"])
cv = metadata["logging"][0]["callback_vars"]
assert cv["langfuse_secret_key"] != "sk-real"
recovered = decrypt_callback_vars(metadata)["logging"][0]["callback_vars"]
assert recovered["langfuse_secret_key"] == "sk-real"
def _non_admin_auth():
return UserAPIKeyAuth(
user_id="u-team-admin", user_role=LitellmUserRoles.INTERNAL_USER
@@ -1549,6 +1549,57 @@ def test_team_dynamic_logging_settings():
assert result is None
def test_key_dynamic_logging_settings_decrypts_callback_vars(monkeypatch):
"""Encrypted callback_vars on the key are decrypted before downstream use."""
from litellm.proxy.common_utils.callback_utils import encrypt_callback_vars
monkeypatch.setenv("LITELLM_SALT_KEY", "test-salt-32-bytes-aaaaaaaaaaaaaa")
encrypted_metadata = encrypt_callback_vars(
{
"logging": [
{
"callback_name": "langfuse",
"callback_type": "success",
"callback_vars": {
"langfuse_public_key": "pk-real",
"langfuse_secret_key": "sk-real",
},
}
]
}
)
cv_on_disk = encrypted_metadata["logging"][0]["callback_vars"]
assert cv_on_disk["langfuse_secret_key"] != "sk-real" # sanity: stored encrypted
key = UserAPIKeyAuth(api_key="t", metadata=encrypted_metadata, team_metadata={})
result = KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(key)
cv = result[0]["callback_vars"]
assert cv["langfuse_secret_key"] == "sk-real"
assert cv["langfuse_public_key"] == "pk-real"
def test_team_dynamic_logging_settings_decrypts_callback_vars(monkeypatch):
"""Encrypted callback_vars on the team are decrypted before downstream use."""
from litellm.proxy.common_utils.callback_utils import encrypt_callback_vars
monkeypatch.setenv("LITELLM_SALT_KEY", "test-salt-32-bytes-aaaaaaaaaaaaaa")
encrypted_team = encrypt_callback_vars(
{
"logging": [
{
"callback_name": "langfuse",
"callback_type": "failure",
"callback_vars": {"langfuse_secret_key": "team-sk-real"},
}
]
}
)
key = UserAPIKeyAuth(api_key="t", metadata={}, team_metadata=encrypted_team)
result = KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(key)
assert result[0]["callback_vars"]["langfuse_secret_key"] == "team-sk-real"
def test_get_dynamic_logging_metadata_with_arize_team_logging():
"""
Test _get_dynamic_logging_metadata function with arize team logging and dynamic parameters