Add throughput benchmarking script (#133)

This commit is contained in:
Woosuk Kwon
2023-05-28 03:20:05 -07:00
committed by GitHub
parent 337871c6fd
commit 211318d44a
12 changed files with 145 additions and 257 deletions

View File

@@ -1,5 +1,6 @@
from typing import List, Optional
from typing import List, Optional, Union
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from tqdm import tqdm
from cacheflow.outputs import RequestOutput
@@ -31,6 +32,11 @@ class LLM:
self.llm_server = LLMServer.from_server_args(server_args)
self.request_counter = Counter()
def get_tokenizer(
self,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_server.tokenizer
def generate(
self,
prompts: List[str],
@@ -41,10 +47,6 @@ class LLM:
if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
# Initialize tqdm.
if use_tqdm:
pbar = tqdm(total=len(prompts), desc="Processed prompts")
# Add requests to the server.
for i in range(len(prompts)):
prompt = prompts[i]
@@ -52,10 +54,24 @@ class LLM:
token_ids = None
else:
token_ids = prompt_token_ids[i]
request_id = str(next(self.request_counter))
self.llm_server.add_request(request_id, prompt, sampling_params,
token_ids)
self._add_request(prompt, sampling_params, token_ids)
return self._run_server(use_tqdm)
def _add_request(
self,
prompt: str,
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]],
) -> None:
request_id = str(next(self.request_counter))
self.llm_server.add_request(request_id, prompt, sampling_params,
prompt_token_ids)
def _run_server(self, use_tqdm: bool) -> List[RequestOutput]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_server.get_num_unfinished_requests()
pbar = tqdm(total=num_requests, desc="Processed prompts")
# Run the server.
outputs: List[RequestOutput] = []
while self.llm_server.has_unfinished_requests():