Merge pull request #14077 from TomeHirata/codex/add-support-for-anthropic-citation-api

Add support for anthropic citation api in Databricks
This commit is contained in:
Krish Dholakia
2025-09-04 20:46:11 -07:00
committed by GitHub
4 changed files with 146 additions and 9 deletions
@@ -282,6 +282,11 @@ ModelResponse(
)
```
### Citations
Anthropic models served through Databricks can return citation metadata. LiteLLM
exposes these via `response.choices[0].message.provider_specific_fields["citations"]`.
### Pass `thinking` to Anthropic models
You can also pass the `thinking` parameter to Anthropic models.
+37 -3
View File
@@ -26,7 +26,6 @@ from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response impo
_should_convert_tool_call_to_json_mode,
)
from litellm.litellm_core_utils.prompt_templates.common_utils import (
handle_messages_with_content_list_to_str_conversion,
strip_name_from_messages,
)
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
@@ -301,7 +300,6 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
) -> Union[List[AllMessageValues], Coroutine[Any, Any, List[AllMessageValues]]]:
"""
Databricks does not support:
- content in list format.
- 'name' in user message.
"""
new_messages = []
@@ -311,7 +309,6 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
else:
_message = message
new_messages.append(_message)
new_messages = handle_messages_with_content_list_to_str_conversion(new_messages)
new_messages = strip_name_from_messages(new_messages)
if is_async:
@@ -379,6 +376,25 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
thinking_blocks.append(thinking_block)
return reasoning_content, thinking_blocks
@staticmethod
def extract_citations(
content: Optional[AllDatabricksContentValues],
) -> Optional[List[Any]]:
if content is None:
return None
citations = []
if isinstance(content, list):
for item in content:
text = item.get("text", None)
if citations_item := item.get("citations"):
citations.append(
[
{**citation, "supported_text": text}
for citation in citations_item
]
)
return citations or None
def _transform_dbrx_choices(
self, choices: List[DatabricksChoice], json_mode: Optional[bool] = None
) -> List[Choices]:
@@ -427,12 +443,19 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
choice["message"].get("content")
)
citations = DatabricksConfig.extract_citations(
choice["message"].get("content")
)
translated_message = Message(
role="assistant",
content=content_str,
reasoning_content=reasoning_content,
thinking_blocks=thinking_blocks,
tool_calls=choice["message"].get("tool_calls"),
provider_specific_fields={"citations": citations}
if citations is not None
else None,
)
if finish_reason is None:
@@ -561,6 +584,17 @@ class DatabricksChatResponseIterator(BaseModelResponseIterator):
for _tc in tool_calls:
if _tc.get("function", {}).get("arguments") == "{}":
_tc["function"]["arguments"] = "" # avoid invalid json
if isinstance(choice["delta"]["content"], list) and (
content := choice["delta"]["content"]
):
if citations := content[0].get("citations"):
# TODO: Databricks delta does not include supported text or chunk type.
# Add either here once Databricks supports it to enable citation linkage.
choice["delta"].setdefault("provider_specific_fields", {})[
"citation"
] = citations[
0
] # Databricks Content item always has citation as a list of list
# extract the content str
content_str = DatabricksConfig.extract_content_str(
choice["delta"].get("content")
+6 -4
View File
@@ -1,5 +1,5 @@
import json
from typing import Any, List, Literal, Optional, TypedDict, Union
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
from pydantic import BaseModel
from typing_extensions import (
@@ -24,9 +24,10 @@ class GenericStreamingChunk(TypedDict, total=False):
usage: Optional[BaseModel]
class DatabricksTextContent(TypedDict):
class DatabricksTextContent(TypedDict, total=False):
type: Literal["text"]
text: Required[str]
citations: Optional[List[Dict[str, Any]]]
class DatabricksReasoningSummary(TypedDict):
@@ -35,9 +36,10 @@ class DatabricksReasoningSummary(TypedDict):
signature: str
class DatabricksReasoningContent(TypedDict):
class DatabricksReasoningContent(TypedDict, total=False):
type: Literal["reasoning"]
summary: List[DatabricksReasoningSummary]
summary: Required[List[DatabricksReasoningSummary]]
citations: Optional[List[Dict[str, Any]]]
AllDatabricksContentListValues = Union[
@@ -10,7 +10,10 @@ sys.path.insert(
) # Adds the parent directory to the system path
from unittest.mock import MagicMock, patch
from litellm.llms.databricks.chat.transformation import DatabricksConfig
from litellm.llms.databricks.chat.transformation import (
DatabricksChatResponseIterator,
DatabricksConfig,
)
def test_transform_choices():
@@ -85,8 +88,101 @@ def test_transform_choices_without_signature():
assert choices[0].message.reasoning_content == "i'm thinking without signature."
assert choices[0].message.thinking_blocks is not None
assert len(choices[0].message.thinking_blocks) == 1
# Verify the thinking block was created successfully without signature
thinking_block = choices[0].message.thinking_blocks[0]
assert thinking_block["type"] == "thinking"
assert thinking_block["thinking"] == "i'm thinking without signature."
def test_transform_choices_with_citations():
config = DatabricksConfig()
databricks_choices = [
{
"message": {
"role": "assistant",
"content": [
{
"type": "text",
"text": "Blue",
"citations": [
{
"type": "char_location",
"cited_text": "The sky is blue.",
"document_index": 0,
"document_title": "My Document",
"start_char_index": 0,
"end_char_index": 50,
}
],
}
],
},
"index": 0,
"finish_reason": "stop",
}
]
choices = config._transform_dbrx_choices(choices=databricks_choices)
assert choices[0].message.provider_specific_fields == {
"citations": [
[
{
"type": "char_location",
"cited_text": "The sky is blue.",
"document_index": 0,
"document_title": "My Document",
"start_char_index": 0,
"end_char_index": 50,
"supported_text": "Blue",
}
]
]
}
def test_chunk_parser_with_citation():
iterator = DatabricksChatResponseIterator(None, sync_stream=True)
chunk = {
"id": "1",
"object": "chat.completion.chunk",
"created": 0,
"model": "test",
"choices": [
{
"delta": {
"content": [
{
"type": "text",
"text": "",
"citations": [
{
"type": "char_location",
"cited_text": "The sky is blue.",
"document_index": 0,
"document_title": "My Document",
"start_char_index": 0,
"end_char_index": 50,
}
],
}
],
},
"index": 0,
"finish_reason": None,
}
],
}
parsed = iterator.chunk_parser(chunk)
assert parsed.choices[0].delta.provider_specific_fields == {
"citation": {
"type": "char_location",
"cited_text": "The sky is blue.",
"document_index": 0,
"document_title": "My Document",
"start_char_index": 0,
"end_char_index": 50,
}
}