fix(file_search): promote DB helper, suppress sub-call billing, add queries-plural test

- Promote _fetch_managed_vector_stores_by_uuids from @staticmethod to a module-level
  async helper get_managed_vector_store_rows_by_uuids, following the same standalone
  helper pattern as get_team_object / get_key_object so the hot-path DB read is a
  named importable function rather than an inline prisma_client.db.* call
- Pass no-log=True to both inner _call_aresponses sub-calls so they do not fire
  independent billing/monitoring callbacks; cost is accumulated in the synthesized
  response's _hidden_params for the outer responses() call
- Add test_H11b covering the primary queries (plural array) function-tool schema,
  complementing H11 which exercises only the backward-compat singular query path

Made-with: Cursor
This commit is contained in:
Sameer Kankute
2026-03-18 11:38:49 +05:30
parent 76176f2a64
commit 7660f39fdb
3 changed files with 77 additions and 22 deletions
@@ -66,6 +66,23 @@ else:
PrismaClient = Any
async def get_managed_vector_store_rows_by_uuids(
    uuids: List[str],
    prisma_client: Any,
) -> List[Any]:
    """
    Fetch managed vector store rows by their internal UUIDs.

    Standalone helper following the same pattern as get_team_object /
    get_key_object so that callers on the hot request path use a named,
    importable function rather than an inline prisma_client.db.* call.

    Args:
        uuids: Values matched against the ``vector_store_id`` column.
        prisma_client: Client object exposing
            ``db.litellm_managedvectorstorestable.find_many``.

    Returns:
        The rows returned by ``find_many`` (at most ``len(uuids)`` of them).
        An empty list is returned immediately when ``uuids`` is empty.
    """
    # Nothing to look up: skip the DB round trip instead of issuing a
    # query with an empty ``in`` filter and ``take=0``.
    if not uuids:
        return []

    return await prisma_client.db.litellm_managedvectorstorestable.find_many(
        where={"vector_store_id": {"in": uuids}},
        # Never ask for more rows than UUIDs were requested.
        take=len(uuids),
    )
class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
# Class variables or attributes
def __init__(
@@ -741,23 +758,6 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
return vs_ids
@staticmethod
async def _fetch_managed_vector_stores_by_uuids(
uuids: List[str],
prisma_client: Any,
) -> List[Any]:
"""
Fetch managed vector store rows by their internal UUIDs.
Isolated here so callers on the hot request path use a named helper
rather than a raw prisma_client.db.* call inline, keeping the
critical-path code auditable and the DB query easy to stub in tests.

Args:
uuids: Values matched against the vector_store_id column.
prisma_client: Client object exposing
db.litellm_managedvectorstorestable.find_many.

Returns:
The rows returned by find_many, limited to len(uuids) rows.
"""
return await prisma_client.db.litellm_managedvectorstorestable.find_many(
where={"vector_store_id": {"in": uuids}},
# Never ask for more rows than UUIDs were requested.
take=len(uuids),
)
async def check_vector_store_ids_access(
self,
vector_store_ids: List[str],
@@ -788,7 +788,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
if not uuid_to_unified:
return
rows = await self._fetch_managed_vector_stores_by_uuids(
rows = await get_managed_vector_store_rows_by_uuids(
uuids=list(uuid_to_unified.keys()),
prisma_client=prisma_client,
)
@@ -405,14 +405,15 @@ async def aresponses_with_emulated_file_search(
# 1. Replace file_search tools with function tool
transformed_tools, all_vs_ids = _replace_file_search_tools(tools)
# 2. First provider call — provider will call the file_search function
# 2. First provider call — provider will call the file_search function.
# Pass no-log=True so this internal sub-call does not fire its own billing/
# monitoring callbacks.
first_response: ResponsesAPIResponse = cast(
ResponsesAPIResponse,
await _call_aresponses(
input=input,
model=model,
tools=transformed_tools or None,
**kwargs,
**{**kwargs, "no-log": True},
),
)
@@ -514,14 +515,15 @@ async def aresponses_with_emulated_file_search(
+ tool_results
)
# 6. Follow-up call — provider writes the final answer given search results
# 6. Follow-up call — provider writes the final answer given search results.
# Suppress callbacks here too; cost is accumulated into the synthesized
# response's _hidden_params for the outer responses() call.
final_response: ResponsesAPIResponse = cast(
ResponsesAPIResponse,
await _call_aresponses(
input=follow_up_input,
model=model,
tools=None, # no tools needed for the answer step
**kwargs,
**{**kwargs, "no-log": True},
),
)
@@ -729,6 +729,59 @@ class TestEmulatedFileSearchHandler:
annotations = _get(content0, "annotations")
assert any(_get(a, "file_id") == "file-xyz" for a in annotations)
@pytest.mark.asyncio
async def test_H11b_emulated_full_flow_primary_queries_schema(self):
    """Primary path: provider returns queries (plural array) as defined in the tool schema.

    Also verifies that the handler makes exactly two provider sub-calls
    (tool-call step and answer step) and that each one is passed
    ``"no-log": True`` so it does not fire its own logging callbacks.
    """
    from litellm.responses.file_search.emulated_handler import (
        aresponses_with_emulated_file_search,
    )

    # Use the primary schema: queries (plural, list) instead of the backward-compat query (singular)
    first_resp_plural = MagicMock()
    first_resp_plural.output = [
        {
            "type": "function_call",
            "name": "litellm_file_search",
            "call_id": "call_plural",
            "arguments": '{"queries": ["what is deep research?", "multi-step reasoning"], "vector_store_id": "vs_001"}',
        }
    ]
    first_resp_plural.id = "resp_plural"
    first_resp_plural.created_at = 1700000000
    first_resp_plural.model = "claude-3-5-sonnet"
    first_resp_plural.usage = None

    final_resp = self._make_mock_responses_api_response(text="Deep research uses multiple queries.")

    search_result = MagicMock()
    search_result.file_id = "file-multi"
    search_result.filename = "multi.pdf"
    search_result.score = 0.9
    search_result.content = [{"type": "text", "text": "multi-query context"}]
    mock_search_response = MagicMock()
    mock_search_response.data = [search_result]

    # Keep a handle on the provider-call mock so the sub-call arguments
    # can be inspected after the flow completes.
    call_aresponses_mock = AsyncMock(side_effect=[first_resp_plural, final_resp])

    with patch(
        "litellm.responses.file_search.emulated_handler._call_aresponses",
        new=call_aresponses_mock,
    ), patch(
        "litellm.vector_stores.main.asearch",
        new=AsyncMock(return_value=mock_search_response),
    ):
        result = await aresponses_with_emulated_file_search(
            input="What is deep research?",
            model="anthropic/claude-3-5-sonnet",
            tools=[{"type": "file_search", "vector_store_ids": ["vs_001"]}],
        )

    def _get(item, key):
        return item[key] if isinstance(item, dict) else getattr(item, key, None)

    assert _get(result.output[0], "type") == "file_search_call"
    # Two queries were issued, both should appear in the output
    assert len(_get(result.output[0], "queries")) == 2
    assert _get(result.output[1], "type") == "message"

    # Both inner provider calls must be made, and both must suppress logging.
    assert call_aresponses_mock.await_count == 2
    for sub_call in call_aresponses_mock.await_args_list:
        assert sub_call.kwargs.get("no-log") is True
@pytest.mark.asyncio
async def test_H12_emulated_flow_provider_answers_without_tool_call(self):
"""If provider answers directly (no tool call), still return OpenAI format."""