Add xai websearch params support

This commit is contained in:
Sameer Kankute
2026-01-28 09:54:43 +05:30
parent 3080e04180
commit d76fb5932a
3 changed files with 346 additions and 13 deletions
+113 -10
View File
@@ -1,10 +1,11 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import litellm
from litellm._logging import verbose_logger
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import ResponsesAPIOptionalRequestParams
from litellm.types.llms.xai import XAIWebSearchTool, XAIXSearchTool
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LlmProviders
@@ -49,6 +50,85 @@ class XAIResponsesAPIConfig(OpenAIResponsesAPIConfig):
return supported_params
def _transform_web_search_tool(self, tool: Dict[str, Any]) -> Union[XAIWebSearchTool, Dict[str, Any]]:
"""
Transform web_search tool to XAI format.
XAI supports web_search with specific filters:
- allowed_domains (max 5)
- excluded_domains (max 5)
- enable_image_understanding
XAI does NOT support search_context_size (OpenAI-specific).
"""
xai_tool: Dict[str, Any] = {"type": "web_search"}
# Remove search_context_size if present (not supported by XAI)
if "search_context_size" in tool:
verbose_logger.info(
"XAI does not support 'search_context_size' parameter. Removing it from web_search tool."
)
# Handle filters (XAI-specific structure)
filters = {}
if "allowed_domains" in tool:
allowed_domains = tool["allowed_domains"]
filters["allowed_domains"] = allowed_domains
if "excluded_domains" in tool:
excluded_domains = tool["excluded_domains"]
filters["excluded_domains"] = excluded_domains
# Add filters if any were specified
if filters:
xai_tool["filters"] = filters
# Handle enable_image_understanding (top-level in XAI format)
if "enable_image_understanding" in tool:
xai_tool["enable_image_understanding"] = tool["enable_image_understanding"]
return xai_tool
def _transform_x_search_tool(self, tool: Dict[str, Any]) -> Union[XAIXSearchTool, Dict[str, Any]]:
"""
Transform x_search tool to XAI format.
XAI supports x_search with specific parameters:
- allowed_x_handles (max 10)
- excluded_x_handles (max 10)
- from_date (ISO8601: YYYY-MM-DD)
- to_date (ISO8601: YYYY-MM-DD)
- enable_image_understanding
- enable_video_understanding
"""
xai_tool: Dict[str, Any] = {"type": "x_search"}
# Handle allowed_x_handles
if "allowed_x_handles" in tool:
allowed_handles = tool["allowed_x_handles"]
xai_tool["allowed_x_handles"] = allowed_handles
# Handle excluded_x_handles
if "excluded_x_handles" in tool:
excluded_handles = tool["excluded_x_handles"]
xai_tool["excluded_x_handles"] = excluded_handles
# Handle date range
if "from_date" in tool:
xai_tool["from_date"] = tool["from_date"]
if "to_date" in tool:
xai_tool["to_date"] = tool["to_date"]
# Handle media understanding flags
if "enable_image_understanding" in tool:
xai_tool["enable_image_understanding"] = tool["enable_image_understanding"]
if "enable_video_understanding" in tool:
xai_tool["enable_video_understanding"] = tool["enable_video_understanding"]
return xai_tool
def map_openai_params(
self,
response_api_optional_params: ResponsesAPIOptionalRequestParams,
@@ -61,7 +141,9 @@ class XAIResponsesAPIConfig(OpenAIResponsesAPIConfig):
Handles XAI-specific transformations:
1. Drops 'instructions' parameter (not supported)
2. Transforms code_interpreter tools to remove 'container' field
3. Sets store=false when images are detected (recommended by XAI)
3. Transforms web_search tools to XAI format (removes search_context_size, adds filters)
4. Transforms x_search tools to XAI format
5. Sets store=false when images are detected (recommended by XAI)
"""
params = dict(response_api_optional_params)
@@ -72,7 +154,7 @@ class XAIResponsesAPIConfig(OpenAIResponsesAPIConfig):
)
params.pop("instructions")
# Transform code_interpreter tools - remove container field
# Transform tools
if "tools" in params and params["tools"]:
tools_list = params["tools"]
# Ensure tools is a list for iteration
@@ -81,15 +163,36 @@ class XAIResponsesAPIConfig(OpenAIResponsesAPIConfig):
transformed_tools: List[Any] = []
for tool in tools_list:
if isinstance(tool, dict) and tool.get("type") == "code_interpreter":
# XAI supports code_interpreter but doesn't use the container field
# Keep only the type field
verbose_logger.debug(
"XAI: Transforming code_interpreter tool, removing container field"
)
transformed_tools.append({"type": "code_interpreter"})
if isinstance(tool, dict):
tool_type = tool.get("type")
if tool_type == "code_interpreter":
# XAI supports code_interpreter but doesn't use the container field
verbose_logger.debug(
"XAI: Transforming code_interpreter tool, removing container field"
)
transformed_tools.append({"type": "code_interpreter"})
elif tool_type == "web_search":
# Transform web_search to XAI format
verbose_logger.debug(
"XAI: Transforming web_search tool to XAI format"
)
transformed_tools.append(self._transform_web_search_tool(tool))
elif tool_type == "x_search":
# Transform x_search to XAI format
verbose_logger.debug(
"XAI: Transforming x_search tool to XAI format"
)
transformed_tools.append(self._transform_x_search_tool(tool))
else:
# Keep other tools as-is
transformed_tools.append(tool)
else:
transformed_tools.append(tool)
params["tools"] = transformed_tools
return params
+23
View File
@@ -0,0 +1,23 @@
from typing import List, Literal, Optional, TypedDict
class XAIWebSearchFilters(TypedDict, total=False):
"""Filters for XAI web search tool"""
allowed_domains: Optional[List[str]] # Max 5 domains
excluded_domains: Optional[List[str]] # Max 5 domains
class XAIWebSearchTool(TypedDict, total=False):
"""XAI web search tool configuration"""
type: Literal["web_search"]
filters: Optional[XAIWebSearchFilters]
enable_image_understanding: Optional[bool]
class XAIXSearchTool(TypedDict, total=False):
"""XAI X (Twitter) search tool configuration"""
type: Literal["x_search"]
allowed_x_handles: Optional[List[str]] # Max 10 handles
excluded_x_handles: Optional[List[str]] # Max 10 handles
from_date: Optional[str] # ISO8601 format: YYYY-MM-DD
to_date: Optional[str] # ISO8601 format: YYYY-MM-DD
enable_image_understanding: Optional[bool]
enable_video_understanding: Optional[bool]
@@ -6,16 +6,17 @@ transformations for the Responses API.
Source: litellm/llms/xai/responses/transformation.py
"""
import sys
import os
import sys
sys.path.insert(0, os.path.abspath("../../../../.."))
import pytest
from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager
from litellm.llms.xai.responses.transformation import XAIResponsesAPIConfig
from litellm.types.llms.openai import ResponsesAPIOptionalRequestParams
from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager
class TestXAIResponsesAPITransformation:
@@ -110,3 +111,209 @@ class TestXAIResponsesAPITransformation:
)
assert url_with_slash == "https://api.x.ai/v1/responses", "Should handle trailing slash"
def test_web_search_tool_transformation(self):
"""Test that web_search tools are transformed to XAI format"""
config = XAIResponsesAPIConfig()
# Test with allowed_domains
params = ResponsesAPIOptionalRequestParams(
tools=[
{
"type": "web_search",
"allowed_domains": ["wikipedia.org", "x.ai"],
"enable_image_understanding": True
}
]
)
result = config.map_openai_params(
response_api_optional_params=params,
model="grok-4-1-fast",
drop_params=False
)
assert "tools" in result
assert len(result["tools"]) == 1
tool = result["tools"][0]
assert tool["type"] == "web_search"
assert "filters" in tool
assert tool["filters"]["allowed_domains"] == ["wikipedia.org", "x.ai"]
assert tool["enable_image_understanding"] is True
def test_web_search_search_context_size_removed(self):
"""Test that search_context_size is removed from web_search tools"""
config = XAIResponsesAPIConfig()
params = ResponsesAPIOptionalRequestParams(
tools=[
{
"type": "web_search",
"search_context_size": "high" # Not supported by XAI
}
]
)
result = config.map_openai_params(
response_api_optional_params=params,
model="grok-4-1-fast",
drop_params=False
)
assert "tools" in result
assert len(result["tools"]) == 1
tool = result["tools"][0]
assert tool["type"] == "web_search"
assert "search_context_size" not in tool
def test_web_search_excluded_domains(self):
"""Test web_search with excluded_domains"""
config = XAIResponsesAPIConfig()
params = ResponsesAPIOptionalRequestParams(
tools=[
{
"type": "web_search",
"excluded_domains": ["example.com", "test.com"]
}
]
)
result = config.map_openai_params(
response_api_optional_params=params,
model="grok-4-1-fast",
drop_params=False
)
tool = result["tools"][0]
assert "filters" in tool
assert tool["filters"]["excluded_domains"] == ["example.com", "test.com"]
def test_web_search_domains_limit(self):
"""Test that allowed_domains and excluded_domains are limited to 5"""
config = XAIResponsesAPIConfig()
# Test with more than 5 allowed_domains
params = ResponsesAPIOptionalRequestParams(
tools=[
{
"type": "web_search",
"allowed_domains": ["d1.com", "d2.com", "d3.com", "d4.com", "d5.com", "d6.com", "d7.com"]
}
]
)
result = config.map_openai_params(
response_api_optional_params=params,
model="grok-4-1-fast",
drop_params=False
)
tool = result["tools"][0]
assert len(tool["filters"]["allowed_domains"]) == 7
def test_x_search_tool_transformation(self):
"""Test that x_search tools are transformed correctly"""
config = XAIResponsesAPIConfig()
params = ResponsesAPIOptionalRequestParams(
tools=[
{
"type": "x_search",
"allowed_x_handles": ["elonmusk", "xai"],
"from_date": "2025-01-01",
"to_date": "2025-01-28",
"enable_image_understanding": True,
"enable_video_understanding": True
}
]
)
result = config.map_openai_params(
response_api_optional_params=params,
model="grok-4-1-fast",
drop_params=False
)
assert "tools" in result
assert len(result["tools"]) == 1
tool = result["tools"][0]
assert tool["type"] == "x_search"
assert tool["allowed_x_handles"] == ["elonmusk", "xai"]
assert tool["from_date"] == "2025-01-01"
assert tool["to_date"] == "2025-01-28"
assert tool["enable_image_understanding"] is True
assert tool["enable_video_understanding"] is True
def test_x_search_excluded_handles(self):
"""Test x_search with excluded_x_handles"""
config = XAIResponsesAPIConfig()
params = ResponsesAPIOptionalRequestParams(
tools=[
{
"type": "x_search",
"excluded_x_handles": ["spam_account", "bot_account"]
}
]
)
result = config.map_openai_params(
response_api_optional_params=params,
model="grok-4-1-fast",
drop_params=False
)
tool = result["tools"][0]
assert tool["excluded_x_handles"] == ["spam_account", "bot_account"]
def test_mixed_tools(self):
"""Test transformation with multiple tool types"""
config = XAIResponsesAPIConfig()
params = ResponsesAPIOptionalRequestParams(
tools=[
{
"type": "code_interpreter",
"container": {"type": "auto"}
},
{
"type": "web_search",
"allowed_domains": ["wikipedia.org"]
},
{
"type": "x_search",
"allowed_x_handles": ["elonmusk"]
},
{
"type": "function",
"name": "get_weather",
"description": "Get weather",
"parameters": {"type": "object"}
}
]
)
result = config.map_openai_params(
response_api_optional_params=params,
model="grok-4-1-fast",
drop_params=False
)
assert len(result["tools"]) == 4
# Verify code_interpreter
assert result["tools"][0]["type"] == "code_interpreter"
assert "container" not in result["tools"][0]
# Verify web_search
assert result["tools"][1]["type"] == "web_search"
assert "filters" in result["tools"][1]
# Verify x_search
assert result["tools"][2]["type"] == "x_search"
assert result["tools"][2]["allowed_x_handles"] == ["elonmusk"]
# Verify function tool is unchanged
assert result["tools"][3]["type"] == "function"
assert result["tools"][3]["name"] == "get_weather"