mirror of
https://github.com/tiennm99/litellm.git
synced 2026-07-05 23:06:35 +00:00
(Refactor) Code Quality improvement - Use Common base handler for cloudflare/ provider (#7127)
* add get_complete_url to base config * cloudflare - refactor to following existing pattern * migrate cloudflare chat completions to base llm http handler * fix unused import * fix fake stream in cloudflare * fix cloudflare transformation * fix naming for BaseModelResponseIterator * add async cloudflare streaming test * test cloudflare * add handler.py * add handler.py in cohere handler.py
This commit is contained in:
+1
-1
@@ -1067,10 +1067,10 @@ from .llms.predibase import PredibaseConfig
|
||||
from .llms.replicate import ReplicateConfig
|
||||
from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
|
||||
from .llms.clarifai.chat.transformation import ClarifaiConfig
|
||||
from .llms.cloudflare.chat.transformation import CloudflareChatConfig
|
||||
from .llms.ai21.completion import AI21Config
|
||||
from .llms.ai21.chat import AI21ChatConfig
|
||||
from .llms.together_ai.chat import TogetherAIConfig
|
||||
from .llms.cloudflare import CloudflareConfig
|
||||
from .llms.palm import PalmConfig
|
||||
from .llms.gemini import GeminiConfig
|
||||
from .llms.nlp_cloud import NLPCloudConfig
|
||||
|
||||
@@ -195,7 +195,7 @@ def get_supported_openai_params( # noqa: PLR0915
|
||||
"stop",
|
||||
]
|
||||
elif custom_llm_provider == "cloudflare":
|
||||
return ["max_tokens", "stream"]
|
||||
return litellm.CloudflareChatConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "nlp_cloud":
|
||||
return [
|
||||
"max_tokens",
|
||||
|
||||
@@ -630,36 +630,6 @@ class CustomStreamWrapper:
|
||||
)
|
||||
return ""
|
||||
|
||||
def handle_cloudlfare_stream(self, chunk):
    """
    Parse one raw Cloudflare streaming line into a text / finish-state dict.

    `chunk` is decoded as UTF-8. A line containing `[DONE]` ends the stream
    with finish reason "stop"; a line starting with `data:` has the JSON
    after that prefix parsed and its "response" field returned as text;
    any other line produces an empty, unfinished result.
    """
    try:
        print_verbose(f"\nRaw OpenAI Chunk\n{chunk}\n")
        decoded_line = chunk.decode("utf-8")

        # The end-of-stream sentinel wins over any payload on the same line.
        if "[DONE]" in decoded_line:
            return {"text": "", "is_finished": True, "finish_reason": "stop"}

        parsed = {"text": "", "is_finished": False, "finish_reason": None}
        if decoded_line.startswith("data:"):
            payload = json.loads(decoded_line[5:])  # strip the "data:" prefix
            print_verbose(f"delta content: {payload}")
            parsed["text"] = payload["response"]
        return parsed
    except Exception:
        raise
|
||||
|
||||
def handle_ollama_stream(self, chunk):
|
||||
try:
|
||||
if isinstance(chunk, dict):
|
||||
@@ -1226,12 +1196,6 @@ class CustomStreamWrapper:
|
||||
print_verbose(f"completion obj content: {completion_obj['content']}")
|
||||
if response_obj["is_finished"]:
|
||||
self.received_finish_reason = response_obj["finish_reason"]
|
||||
elif self.custom_llm_provider == "cloudflare":
|
||||
response_obj = self.handle_cloudlfare_stream(chunk)
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
print_verbose(f"completion obj content: {completion_obj['content']}")
|
||||
if response_obj["is_finished"]:
|
||||
self.received_finish_reason = response_obj["finish_reason"]
|
||||
elif self.custom_llm_provider == "watsonx":
|
||||
response_obj = self.handle_watsonx_stream(chunk)
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
@@ -1722,6 +1686,7 @@ class CustomStreamWrapper:
|
||||
or self.custom_llm_provider == "bedrock"
|
||||
or self.custom_llm_provider == "triton"
|
||||
or self.custom_llm_provider == "watsonx"
|
||||
or self.custom_llm_provider == "cloudflare"
|
||||
or self.custom_llm_provider in litellm.openai_compatible_providers
|
||||
or self.custom_llm_provider in litellm._custom_providers
|
||||
):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import litellm
|
||||
@@ -12,6 +12,103 @@ from litellm.types.utils import (
|
||||
)
|
||||
|
||||
|
||||
class BaseModelResponseIterator:
    """
    Wrap a raw streaming response and yield `GenericStreamingChunk` objects.

    Usable both as a sync iterator (`__iter__` / `__next__`) and as an async
    iterator (`__aiter__` / `__anext__`) over `streaming_response`. Subclasses
    override `chunk_parser` to map a provider's JSON payload onto a chunk.
    """

    def __init__(
        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
    ):
        # `sync_stream` is accepted for interface compatibility but is not
        # stored or read by this base class.
        self.streaming_response = streaming_response
        self.response_iterator = self.streaming_response
        self.json_mode = json_mode

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        """Return an empty, unfinished chunk; subclasses override this."""
        return GenericStreamingChunk(
            text="",
            is_finished=False,
            finish_reason="",
            usage=None,
            index=0,
            tool_use=None,
        )

    # Sync iterator
    def __iter__(self):
        return self

    def _handle_string_chunk(self, str_line: str) -> GenericStreamingChunk:
        """
        Convert one decoded stream line into a chunk.

        - a line containing `[DONE]` yields a finished chunk (reason "stop")
        - a line starting with `data:` has the JSON after the prefix parsed
          and handed to `chunk_parser`
        - anything else yields an empty, unfinished chunk
        """
        # chunk is a str at this point
        if "[DONE]" in str_line:
            return GenericStreamingChunk(
                text="",
                is_finished=True,
                finish_reason="stop",
                usage=None,
                index=0,
                tool_use=None,
            )
        elif str_line.startswith("data:"):
            data_json = json.loads(str_line[5:])
            return self.chunk_parser(chunk=data_json)
        else:
            return GenericStreamingChunk(
                text="",
                is_finished=False,
                finish_reason="",
                usage=None,
                index=0,
                tool_use=None,
            )

    @staticmethod
    def _chunk_to_str(chunk):
        """
        Normalize a raw chunk for `_handle_string_chunk`.

        Bytes are decoded as UTF-8 and anything preceding the first `data:`
        marker is dropped. Non-bytes chunks are returned unchanged.
        """
        # NOTE(review): the "data:" trim is applied only to bytes chunks,
        # on the assumption that this matches the original nesting - confirm.
        if not isinstance(chunk, bytes):
            return chunk
        str_line = chunk.decode("utf-8")  # Convert bytes to string
        index = str_line.find("data:")
        if index != -1:
            str_line = str_line[index:]
        return str_line

    def __next__(self):
        """
        Return the next parsed chunk from the sync stream.

        Raises:
            StopIteration: when the underlying iterator is exhausted.
            RuntimeError: when reading or parsing the chunk raises ValueError
                (this includes JSON and UTF-8 decode errors).
        """
        try:
            chunk = self.response_iterator.__next__()
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}") from e

        try:
            return self._handle_string_chunk(str_line=self._chunk_to_str(chunk))
        except ValueError as e:
            raise RuntimeError(
                f"Error parsing chunk: {e},\nReceived chunk: {chunk}"
            ) from e

    # Async iterator
    def __aiter__(self):
        self.async_response_iterator = self.streaming_response.__aiter__()
        return self

    async def __anext__(self):
        """
        Return the next parsed chunk from the async stream.

        Raises:
            StopAsyncIteration: when the underlying iterator is exhausted.
            RuntimeError: when reading or parsing the chunk raises ValueError
                (this includes JSON and UTF-8 decode errors).
        """
        try:
            chunk = await self.async_response_iterator.__anext__()
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}") from e

        try:
            return self._handle_string_chunk(str_line=self._chunk_to_str(chunk))
        except ValueError as e:
            raise RuntimeError(
                f"Error parsing chunk: {e},\nReceived chunk: {chunk}"
            ) from e
|
||||
|
||||
|
||||
class FakeStreamResponseIterator:
|
||||
def __init__(self, model_response, json_mode: Optional[bool] = False):
|
||||
self.model_response = model_response
|
||||
|
||||
@@ -95,6 +95,16 @@ class BaseConfig(ABC):
|
||||
) -> dict:
|
||||
pass
|
||||
|
||||
def get_complete_url(self, api_base: str, model: str) -> str:
    """
    OPTIONAL

    Build the complete url for the request.

    Some providers need `model` in `api_base`; those override this method.
    This default ignores `model` and hands `api_base` back unchanged.
    """
    complete_url = api_base
    return complete_url
|
||||
|
||||
@abstractmethod
|
||||
def transform_request(
|
||||
self,
|
||||
|
||||
@@ -1,180 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import types
|
||||
from enum import Enum
|
||||
from typing import Callable, Optional
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm.utils import ModelResponse, Usage
|
||||
|
||||
from .prompt_templates.factory import custom_prompt, prompt_factory
|
||||
|
||||
|
||||
class CloudflareError(Exception):
    """Error raised for a failed Cloudflare call, carrying status and message."""

    def __init__(self, status_code, message):
        # A synthetic request/response pair is attached to the error.
        synthetic_request = httpx.Request(
            method="POST", url="https://api.cloudflare.com"
        )
        self.status_code = status_code
        self.message = message
        self.request = synthetic_request
        self.response = httpx.Response(
            status_code=status_code, request=synthetic_request
        )
        # Call the base class constructor with the parameters it needs
        super().__init__(self.message)
|
||||
|
||||
|
||||
class CloudflareConfig:
    """
    Class-level default parameters for Cloudflare requests.

    Constructing an instance with non-None arguments writes those values onto
    the class itself, so they are visible to later `get_config()` calls.
    """

    max_tokens: Optional[int] = None
    stream: Optional[bool] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        stream: Optional[bool] = None,
    ) -> None:
        supplied = (("max_tokens", max_tokens), ("stream", stream))
        for name, value in supplied:
            if value is not None:
                # Stored on the class (not the instance) on purpose.
                setattr(self.__class__, name, value)

    @classmethod
    def get_config(cls):
        """Return the class attributes that are set: no dunders, callables or None."""
        skipped_kinds = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        config = {}
        for name, value in cls.__dict__.items():
            if name.startswith("__"):
                continue
            if isinstance(value, skipped_kinds) or value is None:
                continue
            config[name] = value
        return config
|
||||
|
||||
|
||||
def validate_environment(api_key):
    """
    Build the HTTP headers for a Cloudflare request.

    Raises:
        ValueError: if `api_key` is None.
    """
    if api_key is None:
        raise ValueError(
            "Missing CloudflareError API Key - A call is being made to cloudflare but no key is set either in the environment variables or via params"
        )
    request_headers = {
        "accept": "application/json",
        "content-type": "application/json",
    }
    request_headers["Authorization"] = "Bearer " + api_key
    return request_headers
|
||||
|
||||
|
||||
def completion(
    model: str,
    messages: list,
    api_base: str,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    api_key,
    logging_obj,
    optional_params: dict,
    custom_prompt_dict=None,
    litellm_params=None,
    logger_fn=None,
):
    """
    Send a chat completion request to Cloudflare and fill `model_response`.

    The request URL is `api_base + model`. When `optional_params["stream"]`
    is True the raw line iterator of the HTTP response is returned and no
    status check or usage calculation is done. Otherwise the JSON body's
    `result.response` becomes the message content and token usage is
    computed locally.

    Note: `optional_params` is mutated - defaults from
    `litellm.CloudflareConfig.get_config()` are filled in for missing keys.

    Raises:
        ValueError: if `api_key` is None (from `validate_environment`).
        CloudflareError: on a non-200 status in the non-streaming path.
    """
    # Avoid a shared mutable default argument; None means "no custom prompts".
    if custom_prompt_dict is None:
        custom_prompt_dict = {}

    headers = validate_environment(api_key)

    ## Load Config
    config = litellm.CloudflareConfig.get_config()
    for k, v in config.items():
        if k not in optional_params:
            optional_params[k] = v

    print_verbose(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}")
    if model in custom_prompt_dict:
        # check if the model has a registered custom prompt
        model_prompt_details = custom_prompt_dict[model]
        # NOTE(review): the return value is discarded, so the formatted
        # prompt does not reach the request payload - confirm intent.
        custom_prompt(
            role_dict=model_prompt_details.get("roles", {}),
            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
            bos_token=model_prompt_details.get("bos_token", ""),
            eos_token=model_prompt_details.get("eos_token", ""),
            messages=messages,
        )

    # cloudflare adds the model to the api base
    api_base = api_base + model

    data = {
        "messages": messages,
        **optional_params,
    }

    ## LOGGING
    logging_obj.pre_call(
        input=messages,
        api_key=api_key,
        additional_args={
            "headers": headers,
            "api_base": api_base,
            "complete_input_dict": data,
        },
    )

    ## COMPLETION CALL
    if optional_params.get("stream") is True:
        response = requests.post(
            api_base,
            headers=headers,
            data=json.dumps(data),
            stream=optional_params["stream"],
        )
        return response.iter_lines()

    response = requests.post(api_base, headers=headers, data=json.dumps(data))
    ## LOGGING
    logging_obj.post_call(
        input=messages,
        api_key=api_key,
        original_response=response.text,
        additional_args={"complete_input_dict": data},
    )
    print_verbose(f"raw model_response: {response.text}")
    ## RESPONSE OBJECT
    if response.status_code != 200:
        raise CloudflareError(
            status_code=response.status_code, message=response.text
        )
    completion_response = response.json()

    model_response.choices[0].message.content = completion_response["result"][  # type: ignore
        "response"
    ]

    ## CALCULATING USAGE
    print_verbose(
        f"CALCULATING CLOUDFLARE TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}"
    )
    prompt_tokens = litellm.utils.get_token_count(messages=messages, model=model)
    completion_tokens = len(
        encoding.encode(model_response["choices"][0]["message"].get("content", ""))
    )

    model_response.created = int(time.time())
    model_response.model = "cloudflare/" + model
    usage = Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    setattr(model_response, "usage", usage)
    return model_response
|
||||
|
||||
|
||||
def embedding():
    """Placeholder for Cloudflare embedding calls; does nothing and returns None."""
    # logic for parsing in - calling - parsing out model embedding calls
    return None
|
||||
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
Cloudflare - uses `llm_http_handler.py` to make httpx requests
|
||||
|
||||
Request/Response transformation is handled in `transformation.py`
|
||||
"""
|
||||
@@ -0,0 +1,202 @@
|
||||
import json
|
||||
import time
|
||||
from typing import AsyncIterator, Iterator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.llms.base_llm.transformation import (
|
||||
BaseConfig,
|
||||
BaseLLMException,
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionUsageBlock,
|
||||
GenericStreamingChunk,
|
||||
ModelResponse,
|
||||
Usage,
|
||||
)
|
||||
|
||||
|
||||
class CloudflareError(BaseLLMException):
    """Cloudflare error carrying status, message and a synthetic request/response."""

    def __init__(self, status_code, message):
        synthetic_request = httpx.Request(
            method="POST", url="https://api.cloudflare.com"
        )
        synthetic_response = httpx.Response(
            status_code=status_code, request=synthetic_request
        )
        self.status_code = status_code
        self.message = message
        self.request = synthetic_request
        self.response = synthetic_response
        # Call the base class constructor with the parameters it needs
        super().__init__(
            status_code=status_code,
            message=message,
            request=synthetic_request,
            response=synthetic_response,
        )
|
||||
|
||||
|
||||
class CloudflareChatConfig(BaseConfig):
    """
    Request/response transformation for Cloudflare chat completions.

    Class attributes hold provider-wide defaults. Constructing an instance
    with non-None arguments writes them onto the class itself, so they are
    picked up by later `get_config()` calls.
    """

    max_tokens: Optional[int] = None
    stream: Optional[bool] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        stream: Optional[bool] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                # Stored on the class (not the instance) on purpose.
                setattr(self.__class__, key, value)

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
    ) -> dict:
        """
        Build the HTTP headers for a Cloudflare request.

        The incoming `headers` argument is not merged; a fresh dict is
        returned.

        Raises:
            ValueError: if `api_key` is None.
        """
        if api_key is None:
            raise ValueError(
                "Missing CloudflareError API Key - A call is being made to cloudflare but no key is set either in the environment variables or via params"
            )
        headers = {
            "accept": "application/json",
            # The body is JSON; this value was previously misspelled.
            "content-type": "application/json",
            "Authorization": "Bearer " + api_key,
        }
        return headers

    def get_complete_url(self, api_base: str, model: str) -> str:
        """Return the request URL: cloudflare appends the model to the api base."""
        return api_base + model

    def get_supported_openai_params(self, model: str) -> List[str]:
        """Return the OpenAI parameter names this provider accepts."""
        return [
            "stream",
            "max_tokens",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Copy supported OpenAI params into `optional_params` and return it.

        `max_completion_tokens` is mapped onto `max_tokens`; parameters that
        are not supported are ignored. `drop_params` is not consulted here.
        """
        supported_openai_params = self.get_supported_openai_params(model=model)
        for param, value in non_default_params.items():
            if param == "max_completion_tokens":
                optional_params["max_tokens"] = value
            elif param in supported_openai_params:
                optional_params[param] = value
        return optional_params

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Build the JSON request body: `messages` plus all optional params.

        Note: `optional_params` is mutated - class-level config defaults are
        filled in for keys the caller did not set.
        """
        config = litellm.CloudflareChatConfig.get_config()
        for k, v in config.items():
            if k not in optional_params:
                optional_params[k] = v

        data = {
            "messages": messages,
            **optional_params,
        }
        return data

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        encoding: str,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Populate `model_response` from a non-streaming Cloudflare response.

        The message content is taken from `result.response` in the JSON body;
        token usage is computed locally since it is not read from the body.
        """
        completion_response = raw_response.json()

        model_response.choices[0].message.content = completion_response["result"][  # type: ignore
            "response"
        ]

        prompt_tokens = litellm.utils.get_token_count(messages=messages, model=model)
        # NOTE(review): `encoding` is annotated `str` but is used as an object
        # exposing `.encode(...)` whose result has a length - confirm the type.
        completion_tokens = len(
            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
        )

        model_response.created = int(time.time())
        model_response.model = "cloudflare/" + model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)
        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Wrap an error message and status code in a `CloudflareError`."""
        return CloudflareError(
            status_code=status_code,
            message=error_message,
        )

    def _transform_messages(
        self, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
        """Not implemented for this provider; always raises NotImplementedError."""
        raise NotImplementedError

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        """Return the streaming iterator that parses Cloudflare chunks."""
        return CloudflareChatResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )
|
||||
|
||||
|
||||
class CloudflareChatResponseIterator(BaseModelResponseIterator):
    """Streaming iterator that maps Cloudflare JSON chunks to generic chunks."""

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        """
        Build a `GenericStreamingChunk` from one parsed JSON chunk.

        Text comes from the chunk's "response" key (empty if absent) and the
        index from its "index" key (default 0). The chunk is never marked
        finished here and carries no tool call or usage data.
        """
        try:
            chunk_index = int(chunk.get("index", 0))
            chunk_text = chunk["response"] if "response" in chunk else ""

            no_tool_use: Optional[ChatCompletionToolCallChunk] = None
            no_usage: Optional[ChatCompletionUsageBlock] = None

            return GenericStreamingChunk(
                text=chunk_text,
                tool_use=no_tool_use,
                is_finished=False,
                finish_reason="",
                usage=no_usage,
                index=chunk_index,
                provider_specific_fields=None,
            )
        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
|
||||
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
Cohere /generate API - uses `llm_http_handler.py` to make httpx requests
|
||||
|
||||
Request/Response transformation is handled in `transformation.py`
|
||||
"""
|
||||
@@ -13,7 +13,6 @@ from typing import (
|
||||
)
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingChoice
|
||||
|
||||
import litellm
|
||||
@@ -109,6 +108,11 @@ class BaseLLMHTTPHandler:
|
||||
optional_params=optional_params,
|
||||
)
|
||||
|
||||
api_base = provider_config.get_complete_url(
|
||||
api_base=api_base,
|
||||
model=model,
|
||||
)
|
||||
|
||||
data = provider_config.transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
||||
+9
-24
@@ -86,7 +86,6 @@ from .litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
|
||||
from .llms import (
|
||||
aleph_alpha,
|
||||
baseten,
|
||||
cloudflare,
|
||||
maritalk,
|
||||
nlp_cloud,
|
||||
ollama,
|
||||
@@ -471,6 +470,7 @@ async def acompletion(
|
||||
or custom_llm_provider == "triton"
|
||||
or custom_llm_provider == "clarifai"
|
||||
or custom_llm_provider == "watsonx"
|
||||
or custom_llm_provider == "cloudflare"
|
||||
or custom_llm_provider in litellm.openai_compatible_providers
|
||||
or custom_llm_provider in litellm._custom_providers
|
||||
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
|
||||
@@ -2828,37 +2828,22 @@ def completion( # type: ignore # noqa: PLR0915
|
||||
)
|
||||
|
||||
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
|
||||
response = cloudflare.completion(
|
||||
response = base_llm_http_handler.completion(
|
||||
model=model,
|
||||
stream=stream,
|
||||
messages=messages,
|
||||
acompletion=acompletion,
|
||||
api_base=api_base,
|
||||
custom_prompt_dict=litellm.custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding, # for calculating input/output tokens
|
||||
custom_llm_provider="cloudflare",
|
||||
timeout=timeout,
|
||||
headers=headers,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
logging_obj=logging,
|
||||
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
|
||||
)
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
# don't try to access stream object,
|
||||
response = CustomStreamWrapper(
|
||||
response,
|
||||
model,
|
||||
custom_llm_provider="cloudflare",
|
||||
logging_obj=logging,
|
||||
)
|
||||
|
||||
if optional_params.get("stream", False) or acompletion is True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
)
|
||||
response = response
|
||||
elif (
|
||||
custom_llm_provider == "baseten"
|
||||
or litellm.api_base == "https://app.baseten.co"
|
||||
|
||||
+12
-4
@@ -3274,10 +3274,16 @@ def get_optional_params( # noqa: PLR0915
|
||||
)
|
||||
_check_valid_arg(supported_params=supported_params)
|
||||
|
||||
if max_tokens is not None:
|
||||
optional_params["max_tokens"] = max_tokens
|
||||
if stream is not None:
|
||||
optional_params["stream"] = stream
|
||||
optional_params = litellm.CloudflareChatConfig().map_openai_params(
|
||||
model=model,
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
drop_params=(
|
||||
drop_params
|
||||
if drop_params is not None and isinstance(drop_params, bool)
|
||||
else False
|
||||
),
|
||||
)
|
||||
elif custom_llm_provider == "ollama":
|
||||
supported_params = get_supported_openai_params(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
@@ -6248,6 +6254,8 @@ class ProviderConfigManager:
|
||||
elif litellm.LlmProviders.VERTEX_AI == provider:
|
||||
if "claude" in model:
|
||||
return litellm.VertexAIAnthropicConfig()
|
||||
elif litellm.LlmProviders.CLOUDFLARE == provider:
|
||||
return litellm.CloudflareChatConfig()
|
||||
|
||||
return litellm.OpenAIGPTConfig()
|
||||
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
import io
|
||||
import os
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
|
||||
|
||||
|
||||
# Cloudflare AI test
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("stream", [True, False])
async def test_completion_cloudflare(stream):
    """Call Cloudflare through `litellm.acompletion`, streaming and not; any exception fails the test."""
    request_kwargs = dict(
        model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
        messages=[{"content": "what llm are you", "role": "user"}],
        max_tokens=15,
        stream=stream,
    )
    try:
        litellm.set_verbose = False
        result = await litellm.acompletion(**request_kwargs)
        print(result)
        if stream is not True:
            print(result)
            return
        # Streaming: drain the async iterator so every chunk is parsed.
        async for piece in result:
            print(piece)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
|
||||
@@ -4181,26 +4181,6 @@ def test_completion_together_ai_stream():
|
||||
# test_completion_together_ai_stream()
|
||||
|
||||
|
||||
# Cloudflare AI tests
|
||||
@pytest.mark.skip(reason="Flaky test-cloudflare is very unstable")
def test_completion_cloudflare():
    """Call Cloudflare through the sync `completion` API; any exception fails the test."""
    request_kwargs = dict(
        model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
        messages=[{"content": "what llm are you", "role": "user"}],
        max_tokens=15,
        num_retries=3,
    )
    try:
        litellm.set_verbose = True
        result = completion(**request_kwargs)
        print(result)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# test_completion_cloudflare()
|
||||
|
||||
|
||||
def test_moderation():
|
||||
response = litellm.moderation(input="i'm ishaan cto of litellm")
|
||||
print(response)
|
||||
|
||||
Reference in New Issue
Block a user