mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-28 13:11:20 +00:00
fix(file_search): promote DB helper, suppress sub-call billing, add queries-plural test
- Promote _fetch_managed_vector_stores_by_uuids from @staticmethod to a module-level async helper get_managed_vector_store_rows_by_uuids, following the same standalone helper pattern as get_team_object / get_key_object so the hot-path DB read is a named importable function rather than an inline prisma_client.db.* call - Pass no-log=True to both inner _call_aresponses sub-calls so they do not fire independent billing/monitoring callbacks; cost is accumulated in the synthesized response's _hidden_params for the outer responses() call - Add test_H11b covering the primary queries (plural array) function-tool schema, complementing H11 which exercises only the backward-compat singular query path Made-with: Cursor
This commit is contained in:
@@ -66,6 +66,23 @@ else:
|
||||
PrismaClient = Any
|
||||
|
||||
|
||||
async def get_managed_vector_store_rows_by_uuids(
|
||||
uuids: List[str],
|
||||
prisma_client: Any,
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Fetch managed vector store rows by their internal UUIDs.
|
||||
|
||||
Standalone helper following the same pattern as get_team_object /
|
||||
get_key_object so that callers on the hot request path use a named,
|
||||
importable function rather than an inline prisma_client.db.* call.
|
||||
"""
|
||||
return await prisma_client.db.litellm_managedvectorstorestable.find_many(
|
||||
where={"vector_store_id": {"in": uuids}},
|
||||
take=len(uuids),
|
||||
)
|
||||
|
||||
|
||||
class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
|
||||
# Class variables or attributes
|
||||
def __init__(
|
||||
@@ -741,23 +758,6 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
|
||||
|
||||
return vs_ids
|
||||
|
||||
@staticmethod
|
||||
async def _fetch_managed_vector_stores_by_uuids(
|
||||
uuids: List[str],
|
||||
prisma_client: Any,
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Fetch managed vector store rows by their internal UUIDs.
|
||||
|
||||
Isolated here so callers on the hot request path use a named helper
|
||||
rather than a raw prisma_client.db.* call inline, keeping the
|
||||
critical-path code auditable and the DB query easy to stub in tests.
|
||||
"""
|
||||
return await prisma_client.db.litellm_managedvectorstorestable.find_many(
|
||||
where={"vector_store_id": {"in": uuids}},
|
||||
take=len(uuids),
|
||||
)
|
||||
|
||||
async def check_vector_store_ids_access(
|
||||
self,
|
||||
vector_store_ids: List[str],
|
||||
@@ -788,7 +788,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
|
||||
if not uuid_to_unified:
|
||||
return
|
||||
|
||||
rows = await self._fetch_managed_vector_stores_by_uuids(
|
||||
rows = await get_managed_vector_store_rows_by_uuids(
|
||||
uuids=list(uuid_to_unified.keys()),
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
@@ -405,14 +405,15 @@ async def aresponses_with_emulated_file_search(
|
||||
# 1. Replace file_search tools with function tool
|
||||
transformed_tools, all_vs_ids = _replace_file_search_tools(tools)
|
||||
|
||||
# 2. First provider call — provider will call the file_search function
|
||||
# 2. First provider call — provider will call the file_search function.
|
||||
# Pass no-log=True so this internal sub-call does not fire its own billing/
|
||||
first_response: ResponsesAPIResponse = cast(
|
||||
ResponsesAPIResponse,
|
||||
await _call_aresponses(
|
||||
input=input,
|
||||
model=model,
|
||||
tools=transformed_tools or None,
|
||||
**kwargs,
|
||||
**{**kwargs, "no-log": True},
|
||||
),
|
||||
)
|
||||
|
||||
@@ -514,14 +515,15 @@ async def aresponses_with_emulated_file_search(
|
||||
+ tool_results
|
||||
)
|
||||
|
||||
# 6. Follow-up call — provider writes the final answer given search results
|
||||
# 6. Follow-up call — provider writes the final answer given search results.
|
||||
# Suppress callbacks here too; cost is accumulated into the synthesized
|
||||
final_response: ResponsesAPIResponse = cast(
|
||||
ResponsesAPIResponse,
|
||||
await _call_aresponses(
|
||||
input=follow_up_input,
|
||||
model=model,
|
||||
tools=None, # no tools needed for the answer step
|
||||
**kwargs,
|
||||
**{**kwargs, "no-log": True},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -729,6 +729,59 @@ class TestEmulatedFileSearchHandler:
|
||||
annotations = _get(content0, "annotations")
|
||||
assert any(_get(a, "file_id") == "file-xyz" for a in annotations)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_H11b_emulated_full_flow_primary_queries_schema(self):
|
||||
"""Primary path: provider returns queries (plural array) as defined in the tool schema."""
|
||||
from litellm.responses.file_search.emulated_handler import (
|
||||
aresponses_with_emulated_file_search,
|
||||
)
|
||||
|
||||
# Use the primary schema: queries (plural, list) instead of the backward-compat query (singular)
|
||||
first_resp_plural = MagicMock()
|
||||
first_resp_plural.output = [
|
||||
{
|
||||
"type": "function_call",
|
||||
"name": "litellm_file_search",
|
||||
"call_id": "call_plural",
|
||||
"arguments": '{"queries": ["what is deep research?", "multi-step reasoning"], "vector_store_id": "vs_001"}',
|
||||
}
|
||||
]
|
||||
first_resp_plural.id = "resp_plural"
|
||||
first_resp_plural.created_at = 1700000000
|
||||
first_resp_plural.model = "claude-3-5-sonnet"
|
||||
first_resp_plural.usage = None
|
||||
|
||||
final_resp = self._make_mock_responses_api_response(text="Deep research uses multiple queries.")
|
||||
|
||||
search_result = MagicMock()
|
||||
search_result.file_id = "file-multi"
|
||||
search_result.filename = "multi.pdf"
|
||||
search_result.score = 0.9
|
||||
search_result.content = [{"type": "text", "text": "multi-query context"}]
|
||||
mock_search_response = MagicMock()
|
||||
mock_search_response.data = [search_result]
|
||||
|
||||
with patch(
|
||||
"litellm.responses.file_search.emulated_handler._call_aresponses",
|
||||
new=AsyncMock(side_effect=[first_resp_plural, final_resp]),
|
||||
), patch(
|
||||
"litellm.vector_stores.main.asearch",
|
||||
new=AsyncMock(return_value=mock_search_response),
|
||||
):
|
||||
result = await aresponses_with_emulated_file_search(
|
||||
input="What is deep research?",
|
||||
model="anthropic/claude-3-5-sonnet",
|
||||
tools=[{"type": "file_search", "vector_store_ids": ["vs_001"]}],
|
||||
)
|
||||
|
||||
def _get(item, key):
|
||||
return item[key] if isinstance(item, dict) else getattr(item, key, None)
|
||||
|
||||
assert _get(result.output[0], "type") == "file_search_call"
|
||||
# Two queries were issued, both should appear in the output
|
||||
assert len(_get(result.output[0], "queries")) == 2
|
||||
assert _get(result.output[1], "type") == "message"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_H12_emulated_flow_provider_answers_without_tool_call(self):
|
||||
"""If provider answers directly (no tool call), still return OpenAI format."""
|
||||
|
||||
Reference in New Issue
Block a user