Fix latency benchmark script (#118)

This commit is contained in:
Woosuk Kwon
2023-05-22 17:03:40 -07:00
committed by GitHub
parent 19d2899439
commit 3f942acfe1
2 changed files with 43 additions and 31 deletions

View File

@@ -35,18 +35,26 @@ class LLM:
self,
prompts: List[str],
sampling_params: Optional[SamplingParams] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
) -> List[RequestOutput]:
if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
# Initialize tqdm.
if use_tqdm:
pbar = tqdm(total=len(prompts), desc="Processed prompts")
# Add requests to the server.
for prompt in prompts:
for i in range(len(prompts)):
prompt = prompts[i]
if prompt_token_ids is None:
token_ids = None
else:
token_ids = prompt_token_ids[i]
request_id = str(next(self.request_counter))
self.llm_server.add_request(request_id, prompt, sampling_params)
self.llm_server.add_request(request_id, prompt, sampling_params,
token_ids)
# Run the server.
outputs: List[RequestOutput] = []