batch completions for vllm now works too

This commit is contained in:
Krrish Dholakia
2023-09-06 18:52:34 -07:00
parent 4a263f6ab7
commit 35cf6ef0a1
21 changed files with 149 additions and 23 deletions
Binary file not shown.
BIN
View File
Binary file not shown.
Binary file not shown.
BIN
View File
Binary file not shown.
Binary file not shown.
BIN
View File
Binary file not shown.
Binary file not shown.
BIN
View File
Binary file not shown.
Binary file not shown.
BIN
View File
Binary file not shown.
Binary file not shown.
BIN
View File
Binary file not shown.
Binary file not shown.
BIN
View File
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+85 -5
View File
@@ -6,7 +6,7 @@ import time
from typing import Callable
from litellm.utils import ModelResponse
from .prompt_templates.factory import prompt_factory, custom_prompt
llm = None
class VLLMError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
@@ -16,10 +16,12 @@ class VLLMError(Exception):
) # Call the base class constructor with the parameters it needs
# check if vllm is installed
def validate_environment():
def validate_environment(model: str, llm: any=None):
try:
from vllm import LLM, SamplingParams
return LLM, SamplingParams
if llm is None:
llm = LLM(model=model)
return llm, SamplingParams
except:
raise VLLMError(status_code=0, message="The vllm package is not installed in your environment. Run - `pip install vllm` before proceeding.")
@@ -35,9 +37,8 @@ def completion(
litellm_params=None,
logger_fn=None,
):
LLM, SamplingParams = validate_environment()
try:
llm = LLM(model=model)
llm, SamplingParams = validate_environment(model=model)
except Exception as e:
raise VLLMError(status_code=0, message=str(e))
sampling_params = SamplingParams(**optional_params)
@@ -92,6 +93,85 @@ def completion(
}
return model_response
def batch_completions(
model: str,
messages: list,
optional_params=None,
custom_prompt_dict={}
):
"""
Example usage:
import litellm
import os
from litellm import batch_completion
responses = batch_completion(
model="vllm/facebook/opt-125m",
messages = [
[
{
"role": "user",
"content": "good morning? "
}
],
[
{
"role": "user",
"content": "what's the time? "
}
]
]
)
"""
global llm
try:
llm, SamplingParams = validate_environment(model=model, llm=llm)
except Exception as e:
if "data parallel group is already initialized" in e:
pass
else:
raise VLLMError(status_code=0, message=str(e))
sampling_params = SamplingParams(**optional_params)
prompts = []
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
for message in messages:
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=message
)
prompts.append(prompt)
else:
for message in messages:
prompt = prompt_factory(model=model, messages=message)
prompts.append(prompt)
outputs = llm.generate(prompts, sampling_params)
final_outputs = []
for output in outputs:
model_response = ModelResponse()
## RESPONSE OBJECT
model_response["choices"][0]["message"]["content"] = output.outputs[0].text
## CALCULATING USAGE
prompt_tokens = len(output.prompt_token_ids)
completion_tokens = len(output.outputs[0].token_ids)
model_response["created"] = time.time()
model_response["model"] = model
model_response["usage"] = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
final_outputs.append(model_response)
return final_outputs
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
+61 -16
View File
@@ -693,7 +693,7 @@ def completion(
encoding=encoding,
logging_obj=logging
)
if "stream" in optional_params and optional_params["stream"] == True: ## [BETA]
# don't try to access stream object,
response = CustomStreamWrapper(
@@ -828,23 +828,68 @@ def completion_with_retries(*args, **kwargs):
return retryer(completion, *args, **kwargs)
def batch_completion(*args, **kwargs):
batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")
def batch_completion(
model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: List = [],
functions: List = [],
function_call: str = "", # optional params
temperature: float = 1,
top_p: float = 1,
n: int = 1,
stream: bool = False,
stop=None,
max_tokens: float = float("inf"),
presence_penalty: float = 0,
frequency_penalty=0,
logit_bias: dict = {},
user: str = "",
# used by text-bison only
top_k=40,
custom_llm_provider=None,):
args = locals()
batch_messages = messages
completions = []
with ThreadPoolExecutor() as executor:
for message_list in batch_messages:
if len(args) > 1:
args_modified = list(args)
args_modified[1] = message_list
future = executor.submit(completion, *args_modified)
else:
kwargs_modified = dict(kwargs)
kwargs_modified["messages"] = message_list
future = executor.submit(completion, *args, **kwargs_modified)
completions.append(future)
model = model
custom_llm_provider = None
if model.split("/", 1)[0] in litellm.provider_list:
custom_llm_provider = model.split("/", 1)[0]
model = model.split("/", 1)[1]
if custom_llm_provider == "vllm":
optional_params = get_optional_params(
functions=functions,
function_call=function_call,
temperature=temperature,
top_p=top_p,
n=n,
stream=stream,
stop=stop,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
user=user,
# params to identify the model
model=model,
custom_llm_provider=custom_llm_provider,
top_k=top_k,
)
results = vllm.batch_completions(model=model, messages=batch_messages, custom_prompt_dict=litellm.custom_prompt_dict, optional_params=optional_params)
else:
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i:i + n]
with ThreadPoolExecutor(max_workers=100) as executor:
for sub_batch in chunks(batch_messages, 100):
for message_list in sub_batch:
kwargs_modified = args
kwargs_modified["messages"] = message_list
future = executor.submit(completion, **kwargs_modified)
completions.append(future)
# Retrieve the results from the futures
results = [future.result() for future in completions]
# Retrieve the results from the futures
results = [future.result() for future in completions]
return results
+2 -1
View File
@@ -9,10 +9,11 @@ sys.path.insert(
) # Adds the parent directory to the system path
import litellm
from litellm import batch_completion
litellm.set_verbose=True
messages = [[{"role": "user", "content": "Hey, how's it going"}] for _ in range(5)]
print(messages[0:5])
print(len(messages))
# model = "vllm/facebook/opt-125m"
model = "gpt-3.5-turbo"
result = batch_completion(model=model, messages=messages)
+1 -1
View File
@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "0.1.549"
version = "0.1.555"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"