diff --git a/docs/my-website/docs/completion/input.md b/docs/my-website/docs/completion/input.md index e55c160e01..67738a7f1c 100644 --- a/docs/my-website/docs/completion/input.md +++ b/docs/my-website/docs/completion/input.md @@ -191,6 +191,10 @@ def completion( - `top_logprobs`: *int (optional)* - An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to true if this parameter is used. +- `headers`: *dict (optional)* - A dictionary of headers to be sent with the request. + +- `extra_headers`: *dict (optional)* - Alternative to `headers`, used to send extra headers in LLM API request. + #### Deprecated Params - `functions`: *array* - A list of functions that the model may use to generate JSON inputs. Each function should have the following properties: diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py index b1f6ef3da9..9536f38547 100644 --- a/litellm/llms/bedrock/chat/converse_transformation.py +++ b/litellm/llms/bedrock/chat/converse_transformation.py @@ -378,6 +378,12 @@ class AmazonConverseConfig: for key in additional_request_keys: inference_params.pop(key, None) + if 'topK' in inference_params: + additional_request_params["inferenceConfig"] = {'topK': inference_params.pop("topK")} + elif 'top_k' in inference_params: + additional_request_params["inferenceConfig"] = {'topK': inference_params.pop("top_k")} + + bedrock_tools: List[ToolBlock] = _bedrock_tools_pt( inference_params.pop("tools", []) ) diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py index 6a59d813b9..d466a13553 100644 --- a/tests/llm_translation/test_bedrock_completion.py +++ b/tests/llm_translation/test_bedrock_completion.py @@ -2215,8 +2215,22 @@ def test_bedrock_nova_topk(top_k_param): "messages": [{"role": "user", "content": "Hello, world!"}], top_k_param: 10, } - litellm.completion(**data) + original_transform = litellm.AmazonConverseConfig()._transform_request + captured_data = None + def mock_transform(*args, **kwargs): + nonlocal captured_data + result = original_transform(*args, **kwargs) + captured_data = result + return result + + with patch('litellm.AmazonConverseConfig._transform_request', side_effect=mock_transform): + litellm.completion(**data) + + # Assert that additionalRequestParameters exists and contains topK + assert 'additionalModelRequestFields' in captured_data + assert 'inferenceConfig' in captured_data['additionalModelRequestFields'] + assert captured_data['additionalModelRequestFields']['inferenceConfig']['topK'] == 10 def test_bedrock_process_empty_text_blocks(): from litellm.litellm_core_utils.prompt_templates.factory import (