Add throughput benchmarking script (#133)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
from tqdm import tqdm
|
||||
|
||||
from cacheflow.outputs import RequestOutput
|
||||
@@ -31,6 +32,11 @@ class LLM:
|
||||
self.llm_server = LLMServer.from_server_args(server_args)
|
||||
self.request_counter = Counter()
|
||||
|
||||
def get_tokenizer(
|
||||
self,
|
||||
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
||||
return self.llm_server.tokenizer
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
@@ -41,10 +47,6 @@ class LLM:
|
||||
if sampling_params is None:
|
||||
# Use default sampling params.
|
||||
sampling_params = SamplingParams()
|
||||
# Initialize tqdm.
|
||||
if use_tqdm:
|
||||
pbar = tqdm(total=len(prompts), desc="Processed prompts")
|
||||
|
||||
# Add requests to the server.
|
||||
for i in range(len(prompts)):
|
||||
prompt = prompts[i]
|
||||
@@ -52,10 +54,24 @@ class LLM:
|
||||
token_ids = None
|
||||
else:
|
||||
token_ids = prompt_token_ids[i]
|
||||
request_id = str(next(self.request_counter))
|
||||
self.llm_server.add_request(request_id, prompt, sampling_params,
|
||||
token_ids)
|
||||
self._add_request(prompt, sampling_params, token_ids)
|
||||
return self._run_server(use_tqdm)
|
||||
|
||||
def _add_request(
|
||||
self,
|
||||
prompt: str,
|
||||
sampling_params: SamplingParams,
|
||||
prompt_token_ids: Optional[List[int]],
|
||||
) -> None:
|
||||
request_id = str(next(self.request_counter))
|
||||
self.llm_server.add_request(request_id, prompt, sampling_params,
|
||||
prompt_token_ids)
|
||||
|
||||
def _run_server(self, use_tqdm: bool) -> List[RequestOutput]:
|
||||
# Initialize tqdm.
|
||||
if use_tqdm:
|
||||
num_requests = self.llm_server.get_num_unfinished_requests()
|
||||
pbar = tqdm(total=num_requests, desc="Processed prompts")
|
||||
# Run the server.
|
||||
outputs: List[RequestOutput] = []
|
||||
while self.llm_server.has_unfinished_requests():
|
||||
|
||||
Reference in New Issue
Block a user