Fix latency benchmark script (#118)
This commit is contained in:
@@ -35,18 +35,26 @@ class LLM:
|
||||
self,
|
||||
prompts: List[str],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
prompt_token_ids: Optional[List[List[int]]] = None,
|
||||
use_tqdm: bool = True,
|
||||
) -> List[RequestOutput]:
|
||||
if sampling_params is None:
|
||||
# Use default sampling params.
|
||||
sampling_params = SamplingParams()
|
||||
# Initialize tqdm.
|
||||
if use_tqdm:
|
||||
pbar = tqdm(total=len(prompts), desc="Processed prompts")
|
||||
|
||||
# Add requests to the server.
|
||||
for prompt in prompts:
|
||||
for i in range(len(prompts)):
|
||||
prompt = prompts[i]
|
||||
if prompt_token_ids is None:
|
||||
token_ids = None
|
||||
else:
|
||||
token_ids = prompt_token_ids[i]
|
||||
request_id = str(next(self.request_counter))
|
||||
self.llm_server.add_request(request_id, prompt, sampling_params)
|
||||
self.llm_server.add_request(request_id, prompt, sampling_params,
|
||||
token_ids)
|
||||
|
||||
# Run the server.
|
||||
outputs: List[RequestOutput] = []
|
||||
|
||||
Reference in New Issue
Block a user