diff --git a/docs/my-website/docs/providers/databricks.md b/docs/my-website/docs/providers/databricks.md index 8631cbfdad..921b06a17b 100644 --- a/docs/my-website/docs/providers/databricks.md +++ b/docs/my-website/docs/providers/databricks.md @@ -282,6 +282,11 @@ ModelResponse( ) ``` +### Citations + +Anthropic models served through Databricks can return citation metadata. LiteLLM +exposes these via `response.choices[0].message.provider_specific_fields["citations"]`. + ### Pass `thinking` to Anthropic models You can also pass the `thinking` parameter to Anthropic models. diff --git a/litellm/llms/databricks/chat/transformation.py b/litellm/llms/databricks/chat/transformation.py index 908419f719..d3df5bbf36 100644 --- a/litellm/llms/databricks/chat/transformation.py +++ b/litellm/llms/databricks/chat/transformation.py @@ -26,7 +26,6 @@ from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response impo _should_convert_tool_call_to_json_mode, ) from litellm.litellm_core_utils.prompt_templates.common_utils import ( - handle_messages_with_content_list_to_str_conversion, strip_name_from_messages, ) from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator @@ -301,7 +300,6 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig): ) -> Union[List[AllMessageValues], Coroutine[Any, Any, List[AllMessageValues]]]: """ Databricks does not support: - - content in list format. - 'name' in user message. """ new_messages = [] @@ -311,7 +309,6 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig): else: _message = message new_messages.append(_message) - new_messages = handle_messages_with_content_list_to_str_conversion(new_messages) new_messages = strip_name_from_messages(new_messages) if is_async: @@ -379,6 +376,25 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig): thinking_blocks.append(thinking_block) return reasoning_content, thinking_blocks + @staticmethod + def extract_citations( + content: Optional[AllDatabricksContentValues], + ) -> Optional[List[Any]]: + if content is None: + return None + citations = [] + if isinstance(content, list): + for item in content: + text = item.get("text", None) + if citations_item := item.get("citations"): + citations.append( + [ + {**citation, "supported_text": text} + for citation in citations_item + ] + ) + return citations or None + def _transform_dbrx_choices( self, choices: List[DatabricksChoice], json_mode: Optional[bool] = None ) -> List[Choices]: @@ -427,12 +443,19 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig): choice["message"].get("content") ) + citations = DatabricksConfig.extract_citations( + choice["message"].get("content") + ) + translated_message = Message( role="assistant", content=content_str, reasoning_content=reasoning_content, thinking_blocks=thinking_blocks, tool_calls=choice["message"].get("tool_calls"), + provider_specific_fields={"citations": citations} + if citations is not None + else None, ) if finish_reason is None: @@ -561,6 +584,17 @@ class DatabricksChatResponseIterator(BaseModelResponseIterator): for _tc in tool_calls: if _tc.get("function", {}).get("arguments") == "{}": _tc["function"]["arguments"] = "" # avoid invalid json + if isinstance(choice["delta"]["content"], list) and ( + content := choice["delta"]["content"] + ): + if citations := content[0].get("citations"): + # TODO: Databricks delta does not include supported text or chunk type. + # Add either here once Databricks supports it to enable citation linkage. + choice["delta"].setdefault("provider_specific_fields", {})[ + "citation" + ] = citations[ + 0 + ] # Databricks Content item always has citation as a list of list # extract the content str content_str = DatabricksConfig.extract_content_str( choice["delta"].get("content") diff --git a/litellm/types/llms/databricks.py b/litellm/types/llms/databricks.py index bb59b692ef..112427c6b5 100644 --- a/litellm/types/llms/databricks.py +++ b/litellm/types/llms/databricks.py @@ -1,5 +1,5 @@ import json -from typing import Any, List, Literal, Optional, TypedDict, Union +from typing import Any, Dict, List, Literal, Optional, TypedDict, Union from pydantic import BaseModel from typing_extensions import ( @@ -24,9 +24,10 @@ class GenericStreamingChunk(TypedDict, total=False): usage: Optional[BaseModel] -class DatabricksTextContent(TypedDict): +class DatabricksTextContent(TypedDict, total=False): type: Literal["text"] text: Required[str] + citations: Optional[List[Dict[str, Any]]] class DatabricksReasoningSummary(TypedDict): @@ -35,9 +36,10 @@ class DatabricksReasoningSummary(TypedDict): signature: str -class DatabricksReasoningContent(TypedDict): +class DatabricksReasoningContent(TypedDict, total=False): type: Literal["reasoning"] - summary: List[DatabricksReasoningSummary] + summary: Required[List[DatabricksReasoningSummary]] + citations: Optional[List[Dict[str, Any]]] AllDatabricksContentListValues = Union[ diff --git a/tests/test_litellm/llms/databricks/chat/test_databricks_chat_transformation.py b/tests/test_litellm/llms/databricks/chat/test_databricks_chat_transformation.py index fc44d44aba..51a2e971c0 100644 --- a/tests/test_litellm/llms/databricks/chat/test_databricks_chat_transformation.py +++ b/tests/test_litellm/llms/databricks/chat/test_databricks_chat_transformation.py @@ -10,7 +10,10 @@ sys.path.insert( ) # Adds the parent directory to the system path from unittest.mock import MagicMock, patch -from litellm.llms.databricks.chat.transformation import DatabricksConfig +from litellm.llms.databricks.chat.transformation import ( + DatabricksChatResponseIterator, + DatabricksConfig, +) def test_transform_choices(): @@ -85,8 +88,101 @@ def test_transform_choices_without_signature(): assert choices[0].message.reasoning_content == "i'm thinking without signature." assert choices[0].message.thinking_blocks is not None assert len(choices[0].message.thinking_blocks) == 1 - + # Verify the thinking block was created successfully without signature thinking_block = choices[0].message.thinking_blocks[0] assert thinking_block["type"] == "thinking" assert thinking_block["thinking"] == "i'm thinking without signature." + + +def test_transform_choices_with_citations(): + config = DatabricksConfig() + databricks_choices = [ + { + "message": { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "Blue", + "citations": [ + { + "type": "char_location", + "cited_text": "The sky is blue.", + "document_index": 0, + "document_title": "My Document", + "start_char_index": 0, + "end_char_index": 50, + } + ], + } + ], + }, + "index": 0, + "finish_reason": "stop", + } + ] + + choices = config._transform_dbrx_choices(choices=databricks_choices) + + assert choices[0].message.provider_specific_fields == { + "citations": [ + [ + { + "type": "char_location", + "cited_text": "The sky is blue.", + "document_index": 0, + "document_title": "My Document", + "start_char_index": 0, + "end_char_index": 50, + "supported_text": "Blue", + } + ] + ] + } + + +def test_chunk_parser_with_citation(): + iterator = DatabricksChatResponseIterator(None, sync_stream=True) + chunk = { + "id": "1", + "object": "chat.completion.chunk", + "created": 0, + "model": "test", + "choices": [ + { + "delta": { + "content": [ + { + "type": "text", + "text": "", + "citations": [ + { + "type": "char_location", + "cited_text": "The sky is blue.", + "document_index": 0, + "document_title": "My Document", + "start_char_index": 0, + "end_char_index": 50, + } + ], + } + ], + }, + "index": 0, + "finish_reason": None, + } + ], + } + + parsed = iterator.chunk_parser(chunk) + assert parsed.choices[0].delta.provider_specific_fields == { + "citation": { + "type": "char_location", + "cited_text": "The sky is blue.", + "document_index": 0, + "document_title": "My Document", + "start_char_index": 0, + "end_char_index": 50, + } + }