From 7660f39fdbccfa0353b3cdfdcfb02cdc08931c24 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 11:38:49 +0530 Subject: [PATCH] fix(file_search): promote DB helper, suppress sub-call billing, add queries-plural test - Promote _fetch_managed_vector_stores_by_uuids from @staticmethod to a module-level async helper get_managed_vector_store_rows_by_uuids, following the same standalone helper pattern as get_team_object / get_key_object so the hot-path DB read is a named importable function rather than an inline prisma_client.db.* call - Pass no-log=True to both inner _call_aresponses sub-calls so they do not fire independent billing/monitoring callbacks; cost is accumulated in the synthesized response's _hidden_params for the outer responses() call - Add test_H11b covering the primary queries (plural array) function-tool schema, complementing H11 which exercises only the backward-compat singular query path Made-with: Cursor --- .../proxy/hooks/managed_files.py | 36 ++++++------- .../responses/file_search/emulated_handler.py | 10 ++-- .../llms/test_file_search_responses.py | 53 +++++++++++++++++++ 3 files changed, 77 insertions(+), 22 deletions(-) diff --git a/enterprise/litellm_enterprise/proxy/hooks/managed_files.py b/enterprise/litellm_enterprise/proxy/hooks/managed_files.py index eecaddade7..12e36fccde 100644 --- a/enterprise/litellm_enterprise/proxy/hooks/managed_files.py +++ b/enterprise/litellm_enterprise/proxy/hooks/managed_files.py @@ -66,6 +66,23 @@ else: PrismaClient = Any +async def get_managed_vector_store_rows_by_uuids( + uuids: List[str], + prisma_client: Any, +) -> List[Any]: + """ + Fetch managed vector store rows by their internal UUIDs. + + Standalone helper following the same pattern as get_team_object / + get_key_object so that callers on the hot request path use a named, + importable function rather than an inline prisma_client.db.* call. 
+ """ + return await prisma_client.db.litellm_managedvectorstorestable.find_many( + where={"vector_store_id": {"in": uuids}}, + take=len(uuids), + ) + + class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints): # Class variables or attributes def __init__( @@ -741,23 +758,6 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints): return vs_ids - @staticmethod - async def _fetch_managed_vector_stores_by_uuids( - uuids: List[str], - prisma_client: Any, - ) -> List[Any]: - """ - Fetch managed vector store rows by their internal UUIDs. - - Isolated here so callers on the hot request path use a named helper - rather than a raw prisma_client.db.* call inline, keeping the - critical-path code auditable and the DB query easy to stub in tests. - """ - return await prisma_client.db.litellm_managedvectorstorestable.find_many( - where={"vector_store_id": {"in": uuids}}, - take=len(uuids), - ) - async def check_vector_store_ids_access( self, vector_store_ids: List[str], @@ -788,7 +788,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints): if not uuid_to_unified: return - rows = await self._fetch_managed_vector_stores_by_uuids( + rows = await get_managed_vector_store_rows_by_uuids( uuids=list(uuid_to_unified.keys()), prisma_client=prisma_client, ) diff --git a/litellm/responses/file_search/emulated_handler.py b/litellm/responses/file_search/emulated_handler.py index 62452539d1..70f38e63a6 100644 --- a/litellm/responses/file_search/emulated_handler.py +++ b/litellm/responses/file_search/emulated_handler.py @@ -405,14 +405,15 @@ async def aresponses_with_emulated_file_search( # 1. Replace file_search tools with function tool transformed_tools, all_vs_ids = _replace_file_search_tools(tools) - # 2. First provider call — provider will call the file_search function + # 2. First provider call — provider will call the file_search function. 
+ # Pass no-log=True so this internal sub-call does not fire its own billing/monitoring callbacks. first_response: ResponsesAPIResponse = cast( ResponsesAPIResponse, await _call_aresponses( input=input, model=model, tools=transformed_tools or None, - **kwargs, + **{**kwargs, "no-log": True}, ), ) @@ -514,14 +515,15 @@ async def aresponses_with_emulated_file_search( + tool_results ) - # 6. Follow-up call — provider writes the final answer given search results + # 6. Follow-up call — provider writes the final answer given search results. + # Suppress callbacks here too; cost is accumulated into the synthesized response's _hidden_params. final_response: ResponsesAPIResponse = cast( ResponsesAPIResponse, await _call_aresponses( input=follow_up_input, model=model, tools=None, # no tools needed for the answer step - **kwargs, + **{**kwargs, "no-log": True}, ), ) diff --git a/tests/test_litellm/llms/test_file_search_responses.py b/tests/test_litellm/llms/test_file_search_responses.py index 3a8fa95be5..ebe1466fa6 100644 --- a/tests/test_litellm/llms/test_file_search_responses.py +++ b/tests/test_litellm/llms/test_file_search_responses.py @@ -729,6 +729,59 @@ class TestEmulatedFileSearchHandler: annotations = _get(content0, "annotations") assert any(_get(a, "file_id") == "file-xyz" for a in annotations) + @pytest.mark.asyncio + async def test_H11b_emulated_full_flow_primary_queries_schema(self): + """Primary path: provider returns queries (plural array) as defined in the tool schema.""" + from litellm.responses.file_search.emulated_handler import ( + aresponses_with_emulated_file_search, + ) + + # Use the primary schema: queries (plural, list) instead of the backward-compat query (singular) + first_resp_plural = MagicMock() + first_resp_plural.output = [ + { + "type": "function_call", + "name": "litellm_file_search", + "call_id": "call_plural", + "arguments": '{"queries": ["what is deep research?", "multi-step reasoning"], "vector_store_id": "vs_001"}', + } + ] + first_resp_plural.id = "resp_plural"
first_resp_plural.created_at = 1700000000 + first_resp_plural.model = "claude-3-5-sonnet" + first_resp_plural.usage = None + + final_resp = self._make_mock_responses_api_response(text="Deep research uses multiple queries.") + + search_result = MagicMock() + search_result.file_id = "file-multi" + search_result.filename = "multi.pdf" + search_result.score = 0.9 + search_result.content = [{"type": "text", "text": "multi-query context"}] + mock_search_response = MagicMock() + mock_search_response.data = [search_result] + + with patch( + "litellm.responses.file_search.emulated_handler._call_aresponses", + new=AsyncMock(side_effect=[first_resp_plural, final_resp]), + ), patch( + "litellm.vector_stores.main.asearch", + new=AsyncMock(return_value=mock_search_response), + ): + result = await aresponses_with_emulated_file_search( + input="What is deep research?", + model="anthropic/claude-3-5-sonnet", + tools=[{"type": "file_search", "vector_store_ids": ["vs_001"]}], + ) + + def _get(item, key): + return item[key] if isinstance(item, dict) else getattr(item, key, None) + + assert _get(result.output[0], "type") == "file_search_call" + # Two queries were issued, both should appear in the output + assert len(_get(result.output[0], "queries")) == 2 + assert _get(result.output[1], "type") == "message" + @pytest.mark.asyncio async def test_H12_emulated_flow_provider_answers_without_tool_call(self): """If provider answers directly (no tool call), still return OpenAI format."""