mirror of
https://github.com/tiennm99/litellm.git
synced 2026-07-03 17:08:43 +00:00
Merge pull request #14077 from TomeHirata/codex/add-support-for-anthropic-citation-api
Add support for anthropic citation api in Databricks
This commit is contained in:
@@ -282,6 +282,11 @@ ModelResponse(
|
||||
)
|
||||
```
|
||||
|
||||
### Citations
|
||||
|
||||
Anthropic models served through Databricks can return citation metadata. LiteLLM
|
||||
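exposes these via `response.choices[0].message.provider_specific_fields["citations"]`.
When streaming, a chunk's citation is instead exposed via `chunk.choices[0].delta.provider_specific_fields["citation"]`.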
|
||||
|
||||
### Pass `thinking` to Anthropic models
|
||||
|
||||
You can also pass the `thinking` parameter to Anthropic models.
|
||||
|
||||
@@ -26,7 +26,6 @@ from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response impo
|
||||
_should_convert_tool_call_to_json_mode,
|
||||
)
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
handle_messages_with_content_list_to_str_conversion,
|
||||
strip_name_from_messages,
|
||||
)
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
@@ -301,7 +300,6 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
|
||||
) -> Union[List[AllMessageValues], Coroutine[Any, Any, List[AllMessageValues]]]:
|
||||
"""
|
||||
Databricks does not support:
|
||||
- content in list format.
|
||||
- 'name' in user message.
|
||||
"""
|
||||
new_messages = []
|
||||
@@ -311,7 +309,6 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
|
||||
else:
|
||||
_message = message
|
||||
new_messages.append(_message)
|
||||
new_messages = handle_messages_with_content_list_to_str_conversion(new_messages)
|
||||
new_messages = strip_name_from_messages(new_messages)
|
||||
|
||||
if is_async:
|
||||
@@ -379,6 +376,25 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
|
||||
thinking_blocks.append(thinking_block)
|
||||
return reasoning_content, thinking_blocks
|
||||
|
||||
@staticmethod
def extract_citations(
    content: Optional[AllDatabricksContentValues],
) -> Optional[List[Any]]:
    """Collect citation metadata from a Databricks message ``content`` value.

    Args:
        content: The message content. Only list-form content is inspected;
            ``None`` or any non-list value (e.g. a plain string) yields no
            citations.

    Returns:
        A list with one inner list per content item that carries a non-empty
        ``"citations"`` entry. Every citation in an inner list is a copy of
        the original dict with an extra ``"supported_text"`` key set to that
        item's ``"text"`` value (``None`` when the item has no ``"text"``).
        Returns ``None`` when nothing was collected.
    """
    if not isinstance(content, list):
        return None

    citations: List[Any] = []
    for item in content:
        # Content items are expected to be dicts; skip anything else (such as
        # a bare string part) instead of failing on ``.get``.
        if not isinstance(item, dict):
            continue
        item_citations = item.get("citations")
        if not item_citations:
            continue
        text = item.get("text")
        # Attach the item's text so a caller can tell which span each
        # citation supports.
        citations.append(
            [{**citation, "supported_text": text} for citation in item_citations]
        )
    return citations or None
|
||||
|
||||
def _transform_dbrx_choices(
|
||||
self, choices: List[DatabricksChoice], json_mode: Optional[bool] = None
|
||||
) -> List[Choices]:
|
||||
@@ -427,12 +443,19 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
|
||||
choice["message"].get("content")
|
||||
)
|
||||
|
||||
citations = DatabricksConfig.extract_citations(
|
||||
choice["message"].get("content")
|
||||
)
|
||||
|
||||
translated_message = Message(
|
||||
role="assistant",
|
||||
content=content_str,
|
||||
reasoning_content=reasoning_content,
|
||||
thinking_blocks=thinking_blocks,
|
||||
tool_calls=choice["message"].get("tool_calls"),
|
||||
provider_specific_fields={"citations": citations}
|
||||
if citations is not None
|
||||
else None,
|
||||
)
|
||||
|
||||
if finish_reason is None:
|
||||
@@ -561,6 +584,17 @@ class DatabricksChatResponseIterator(BaseModelResponseIterator):
|
||||
for _tc in tool_calls:
|
||||
if _tc.get("function", {}).get("arguments") == "{}":
|
||||
_tc["function"]["arguments"] = "" # avoid invalid json
|
||||
if isinstance(choice["delta"]["content"], list) and (
|
||||
content := choice["delta"]["content"]
|
||||
):
|
||||
if citations := content[0].get("citations"):
|
||||
# TODO: Databricks delta does not include supported text or chunk type.
|
||||
# Add either here once Databricks supports it to enable citation linkage.
|
||||
choice["delta"].setdefault("provider_specific_fields", {})[
|
||||
"citation"
|
||||
] = citations[
|
||||
0
|
||||
] # Databricks Content item always has citation as a list of list
|
||||
# extract the content str
|
||||
content_str = DatabricksConfig.extract_content_str(
|
||||
choice["delta"].get("content")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, List, Literal, Optional, TypedDict, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import (
|
||||
@@ -24,9 +24,10 @@ class GenericStreamingChunk(TypedDict, total=False):
|
||||
usage: Optional[BaseModel]
|
||||
|
||||
|
||||
class DatabricksTextContent(TypedDict):
|
||||
class DatabricksTextContent(TypedDict, total=False):
|
||||
type: Literal["text"]
|
||||
text: Required[str]
|
||||
citations: Optional[List[Dict[str, Any]]]
|
||||
|
||||
|
||||
class DatabricksReasoningSummary(TypedDict):
|
||||
@@ -35,9 +36,10 @@ class DatabricksReasoningSummary(TypedDict):
|
||||
signature: str
|
||||
|
||||
|
||||
class DatabricksReasoningContent(TypedDict):
|
||||
class DatabricksReasoningContent(TypedDict, total=False):
|
||||
type: Literal["reasoning"]
|
||||
summary: List[DatabricksReasoningSummary]
|
||||
summary: Required[List[DatabricksReasoningSummary]]
|
||||
citations: Optional[List[Dict[str, Any]]]
|
||||
|
||||
|
||||
AllDatabricksContentListValues = Union[
|
||||
|
||||
@@ -10,7 +10,10 @@ sys.path.insert(
|
||||
) # Adds the parent directory to the system path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from litellm.llms.databricks.chat.transformation import DatabricksConfig
|
||||
from litellm.llms.databricks.chat.transformation import (
|
||||
DatabricksChatResponseIterator,
|
||||
DatabricksConfig,
|
||||
)
|
||||
|
||||
|
||||
def test_transform_choices():
|
||||
@@ -85,8 +88,101 @@ def test_transform_choices_without_signature():
|
||||
assert choices[0].message.reasoning_content == "i'm thinking without signature."
|
||||
assert choices[0].message.thinking_blocks is not None
|
||||
assert len(choices[0].message.thinking_blocks) == 1
|
||||
|
||||
|
||||
# Verify the thinking block was created successfully without signature
|
||||
thinking_block = choices[0].message.thinking_blocks[0]
|
||||
assert thinking_block["type"] == "thinking"
|
||||
assert thinking_block["thinking"] == "i'm thinking without signature."
|
||||
|
||||
|
||||
def test_transform_choices_with_citations():
    """Citations on a text content item surface in provider_specific_fields.

    The transformed message is expected to expose the citation under the
    ``"citations"`` key, nested one list per content item, with the item's
    text added as ``"supported_text"``.
    """
    source_citation = {
        "type": "char_location",
        "cited_text": "The sky is blue.",
        "document_index": 0,
        "document_title": "My Document",
        "start_char_index": 0,
        "end_char_index": 50,
    }
    text_item = {
        "type": "text",
        "text": "Blue",
        # Pass a copy so the expected value below cannot be affected by the
        # code under test.
        "citations": [dict(source_citation)],
    }
    raw_choice = {
        "message": {"role": "assistant", "content": [text_item]},
        "index": 0,
        "finish_reason": "stop",
    }

    transformed = DatabricksConfig()._transform_dbrx_choices(choices=[raw_choice])

    expected_citation = dict(source_citation, supported_text="Blue")
    assert transformed[0].message.provider_specific_fields == {
        "citations": [[expected_citation]]
    }
|
||||
|
||||
|
||||
def test_chunk_parser_with_citation():
    """A streamed delta's first citation surfaces under the ``"citation"`` key."""
    expected_citation = {
        "type": "char_location",
        "cited_text": "The sky is blue.",
        "document_index": 0,
        "document_title": "My Document",
        "start_char_index": 0,
        "end_char_index": 50,
    }
    delta = {
        "content": [
            {
                "type": "text",
                "text": "",
                # Pass a copy so the expected value stays independent of the
                # parser's handling of the chunk.
                "citations": [dict(expected_citation)],
            }
        ],
    }
    chunk = {
        "id": "1",
        "object": "chat.completion.chunk",
        "created": 0,
        "model": "test",
        "choices": [{"delta": delta, "index": 0, "finish_reason": None}],
    }

    iterator = DatabricksChatResponseIterator(None, sync_stream=True)
    parsed = iterator.chunk_parser(chunk)

    assert parsed.choices[0].delta.provider_specific_fields == {
        "citation": expected_citation
    }
|
||||
|
||||
Reference in New Issue
Block a user