mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 14:48:44 +00:00
fix(proxy): scope passthrough post-call guardrail buffering to the request
Buffering the Bedrock event stream into a single non-streaming response was gated on whether any post_call guardrail existed globally, so every converse-stream request lost streaming once any post_call guardrail was registered, even for keys that did not reference it. Mirror the gate used by post_call_success_hook (should_run_guardrail against the request's merged guardrails) so only requests whose key/team actually trigger a post_call guardrail are buffered.
This commit is contained in:
@@ -1751,20 +1751,30 @@ class ProxyBaseLLMRequestProcessing:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _has_post_call_guardrails_for_passthrough() -> bool:
|
||||
def _has_post_call_guardrails_for_passthrough(self) -> bool:
|
||||
"""
|
||||
True when any guardrail runs at post_call for passthrough responses.
|
||||
True when a post_call guardrail will actually run for THIS request.
|
||||
|
||||
Unlike _has_post_call_guardrails, an event_hook=None guardrail counts:
|
||||
should_run_guardrail treats it as matching every hook (post_call
|
||||
included), so skipping the passthrough buffer here would forward the
|
||||
raw upstream body and bypass that guardrail's output processing.
|
||||
Mirrors the gate in ProxyLogging.post_call_success_hook
|
||||
(should_run_guardrail against the request's merged guardrails) so that a
|
||||
guardrail registered globally but not configured for this key/team does
|
||||
not force the passthrough stream to be buffered into a single
|
||||
non-streaming response. An event_hook=None guardrail still counts here
|
||||
because should_run_guardrail treats it as matching every hook.
|
||||
"""
|
||||
from litellm.proxy.proxy_server import llm_router
|
||||
from litellm.proxy.utils import _check_and_merge_model_level_guardrails
|
||||
|
||||
guardrail_data = _check_and_merge_model_level_guardrails(
|
||||
data=self.data, llm_router=llm_router
|
||||
)
|
||||
for cb in litellm.callbacks:
|
||||
if not isinstance(cb, CustomGuardrail):
|
||||
continue
|
||||
if cb._event_hook_is_event_type(GuardrailEventHooks.post_call):
|
||||
if cb.should_run_guardrail(
|
||||
data=guardrail_data,
|
||||
event_type=GuardrailEventHooks.post_call,
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@@ -159,43 +159,52 @@ class TestHasPostCallGuardrailsForPassthrough:
|
||||
|
||||
Those guardrails run at post_call (should_run_guardrail treats None as
|
||||
matching every hook); skipping the buffer would forward the raw upstream
|
||||
body and bypass output processing.
|
||||
body and bypass output processing. The check is scoped to the request via
|
||||
should_run_guardrail so a guardrail that exists globally but is not
|
||||
configured for this key/team does not turn the stream non-streaming.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _has(data: dict) -> bool:
|
||||
return ProxyBaseLLMRequestProcessing(
|
||||
data=data
|
||||
)._has_post_call_guardrails_for_passthrough()
|
||||
|
||||
def test_returns_true_for_event_hook_none(self):
|
||||
with patch("litellm.callbacks", [AllEventsGuardrail()]):
|
||||
assert (
|
||||
ProxyBaseLLMRequestProcessing._has_post_call_guardrails_for_passthrough()
|
||||
is True
|
||||
)
|
||||
assert self._has({}) is True
|
||||
|
||||
def test_returns_true_for_post_call_guardrail(self):
|
||||
with patch("litellm.callbacks", [PostCallGuardrail()]):
|
||||
assert (
|
||||
ProxyBaseLLMRequestProcessing._has_post_call_guardrails_for_passthrough()
|
||||
is True
|
||||
)
|
||||
assert self._has({}) is True
|
||||
|
||||
def test_returns_false_for_pre_call_only(self):
|
||||
with patch("litellm.callbacks", [PreCallGuardrail()]):
|
||||
assert (
|
||||
ProxyBaseLLMRequestProcessing._has_post_call_guardrails_for_passthrough()
|
||||
is False
|
||||
)
|
||||
assert self._has({}) is False
|
||||
|
||||
def test_returns_false_for_no_callbacks(self):
|
||||
with patch("litellm.callbacks", []):
|
||||
assert (
|
||||
ProxyBaseLLMRequestProcessing._has_post_call_guardrails_for_passthrough()
|
||||
is False
|
||||
)
|
||||
assert self._has({}) is False
|
||||
|
||||
def test_ignores_non_guardrail_callbacks(self):
|
||||
with patch("litellm.callbacks", ["langfuse", CustomLogger()]):
|
||||
assert (
|
||||
ProxyBaseLLMRequestProcessing._has_post_call_guardrails_for_passthrough()
|
||||
is False
|
||||
)
|
||||
assert self._has({}) is False
|
||||
|
||||
def test_request_scoped_guardrail_not_configured_for_key(self):
|
||||
"""A non-default-on post_call guardrail must not force buffering for a
|
||||
request whose key/team does not reference it."""
|
||||
|
||||
class OptInPostCall(CustomGuardrail):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
guardrail_name="opt-in-post",
|
||||
default_on=False,
|
||||
event_hook=GuardrailEventHooks.post_call,
|
||||
)
|
||||
|
||||
with patch("litellm.callbacks", [OptInPostCall()]):
|
||||
assert self._has({"metadata": {"guardrails": []}}) is False
|
||||
assert self._has({"metadata": {"guardrails": ["opt-in-post"]}}) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user