mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 00:48:01 +00:00
batch completions for vllm now works too
This commit is contained in:
BIN
Binary file not shown.
Vendored
BIN
Binary file not shown.
BIN
Binary file not shown.
Vendored
BIN
Binary file not shown.
BIN
Binary file not shown.
Vendored
BIN
Binary file not shown.
BIN
Binary file not shown.
Vendored
BIN
Binary file not shown.
BIN
Binary file not shown.
Vendored
BIN
Binary file not shown.
BIN
Binary file not shown.
Vendored
BIN
Binary file not shown.
BIN
Binary file not shown.
Vendored
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+85
-5
@@ -6,7 +6,7 @@ import time
|
||||
from typing import Callable
|
||||
from litellm.utils import ModelResponse
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
llm = None
|
||||
class VLLMError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
@@ -16,10 +16,12 @@ class VLLMError(Exception):
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
# check if vllm is installed
|
||||
def validate_environment():
|
||||
def validate_environment(model: str, llm: any=None):
|
||||
try:
|
||||
from vllm import LLM, SamplingParams
|
||||
return LLM, SamplingParams
|
||||
if llm is None:
|
||||
llm = LLM(model=model)
|
||||
return llm, SamplingParams
|
||||
except:
|
||||
raise VLLMError(status_code=0, message="The vllm package is not installed in your environment. Run - `pip install vllm` before proceeding.")
|
||||
|
||||
@@ -35,9 +37,8 @@ def completion(
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
LLM, SamplingParams = validate_environment()
|
||||
try:
|
||||
llm = LLM(model=model)
|
||||
llm, SamplingParams = validate_environment(model=model)
|
||||
except Exception as e:
|
||||
raise VLLMError(status_code=0, message=str(e))
|
||||
sampling_params = SamplingParams(**optional_params)
|
||||
@@ -92,6 +93,85 @@ def completion(
|
||||
}
|
||||
return model_response
|
||||
|
||||
def batch_completions(
|
||||
model: str,
|
||||
messages: list,
|
||||
optional_params=None,
|
||||
custom_prompt_dict={}
|
||||
):
|
||||
"""
|
||||
Example usage:
|
||||
import litellm
|
||||
import os
|
||||
from litellm import batch_completion
|
||||
|
||||
|
||||
responses = batch_completion(
|
||||
model="vllm/facebook/opt-125m",
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "good morning? "
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what's the time? "
|
||||
}
|
||||
]
|
||||
]
|
||||
)
|
||||
"""
|
||||
global llm
|
||||
try:
|
||||
llm, SamplingParams = validate_environment(model=model, llm=llm)
|
||||
except Exception as e:
|
||||
if "data parallel group is already initialized" in e:
|
||||
pass
|
||||
else:
|
||||
raise VLLMError(status_code=0, message=str(e))
|
||||
sampling_params = SamplingParams(**optional_params)
|
||||
prompts = []
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
for message in messages:
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=message
|
||||
)
|
||||
prompts.append(prompt)
|
||||
else:
|
||||
for message in messages:
|
||||
prompt = prompt_factory(model=model, messages=message)
|
||||
prompts.append(prompt)
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
final_outputs = []
|
||||
for output in outputs:
|
||||
model_response = ModelResponse()
|
||||
## RESPONSE OBJECT
|
||||
model_response["choices"][0]["message"]["content"] = output.outputs[0].text
|
||||
|
||||
## CALCULATING USAGE
|
||||
prompt_tokens = len(output.prompt_token_ids)
|
||||
completion_tokens = len(output.outputs[0].token_ids)
|
||||
|
||||
model_response["created"] = time.time()
|
||||
model_response["model"] = model
|
||||
model_response["usage"] = {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
}
|
||||
final_outputs.append(model_response)
|
||||
return final_outputs
|
||||
|
||||
def embedding():
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
pass
|
||||
|
||||
+61
-16
@@ -693,7 +693,7 @@ def completion(
|
||||
encoding=encoding,
|
||||
logging_obj=logging
|
||||
)
|
||||
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True: ## [BETA]
|
||||
# don't try to access stream object,
|
||||
response = CustomStreamWrapper(
|
||||
@@ -828,23 +828,68 @@ def completion_with_retries(*args, **kwargs):
|
||||
return retryer(completion, *args, **kwargs)
|
||||
|
||||
|
||||
def batch_completion(*args, **kwargs):
|
||||
batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")
|
||||
def batch_completion(
|
||||
model: str,
|
||||
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
|
||||
messages: List = [],
|
||||
functions: List = [],
|
||||
function_call: str = "", # optional params
|
||||
temperature: float = 1,
|
||||
top_p: float = 1,
|
||||
n: int = 1,
|
||||
stream: bool = False,
|
||||
stop=None,
|
||||
max_tokens: float = float("inf"),
|
||||
presence_penalty: float = 0,
|
||||
frequency_penalty=0,
|
||||
logit_bias: dict = {},
|
||||
user: str = "",
|
||||
# used by text-bison only
|
||||
top_k=40,
|
||||
custom_llm_provider=None,):
|
||||
args = locals()
|
||||
batch_messages = messages
|
||||
completions = []
|
||||
with ThreadPoolExecutor() as executor:
|
||||
for message_list in batch_messages:
|
||||
if len(args) > 1:
|
||||
args_modified = list(args)
|
||||
args_modified[1] = message_list
|
||||
future = executor.submit(completion, *args_modified)
|
||||
else:
|
||||
kwargs_modified = dict(kwargs)
|
||||
kwargs_modified["messages"] = message_list
|
||||
future = executor.submit(completion, *args, **kwargs_modified)
|
||||
completions.append(future)
|
||||
model = model
|
||||
custom_llm_provider = None
|
||||
if model.split("/", 1)[0] in litellm.provider_list:
|
||||
custom_llm_provider = model.split("/", 1)[0]
|
||||
model = model.split("/", 1)[1]
|
||||
if custom_llm_provider == "vllm":
|
||||
optional_params = get_optional_params(
|
||||
functions=functions,
|
||||
function_call=function_call,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
n=n,
|
||||
stream=stream,
|
||||
stop=stop,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
user=user,
|
||||
# params to identify the model
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
top_k=top_k,
|
||||
)
|
||||
results = vllm.batch_completions(model=model, messages=batch_messages, custom_prompt_dict=litellm.custom_prompt_dict, optional_params=optional_params)
|
||||
else:
|
||||
def chunks(lst, n):
|
||||
"""Yield successive n-sized chunks from lst."""
|
||||
for i in range(0, len(lst), n):
|
||||
yield lst[i:i + n]
|
||||
with ThreadPoolExecutor(max_workers=100) as executor:
|
||||
for sub_batch in chunks(batch_messages, 100):
|
||||
for message_list in sub_batch:
|
||||
kwargs_modified = args
|
||||
kwargs_modified["messages"] = message_list
|
||||
future = executor.submit(completion, **kwargs_modified)
|
||||
completions.append(future)
|
||||
|
||||
# Retrieve the results from the futures
|
||||
results = [future.result() for future in completions]
|
||||
# Retrieve the results from the futures
|
||||
results = [future.result() for future in completions]
|
||||
return results
|
||||
|
||||
|
||||
|
||||
@@ -9,10 +9,11 @@ sys.path.insert(
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
from litellm import batch_completion
|
||||
|
||||
litellm.set_verbose=True
|
||||
messages = [[{"role": "user", "content": "Hey, how's it going"}] for _ in range(5)]
|
||||
print(messages[0:5])
|
||||
print(len(messages))
|
||||
# model = "vllm/facebook/opt-125m"
|
||||
model = "gpt-3.5-turbo"
|
||||
|
||||
result = batch_completion(model=model, messages=messages)
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "litellm"
|
||||
version = "0.1.549"
|
||||
version = "0.1.555"
|
||||
description = "Library to easily interface with LLM API providers"
|
||||
authors = ["BerriAI"]
|
||||
license = "MIT License"
|
||||
|
||||
Reference in New Issue
Block a user