[mypy] Enable type checking for test directory (#5017)

This commit is contained in:
Cyrus Leung
2024-06-15 12:45:31 +08:00
committed by GitHub
parent 1b8a0d71cf
commit 0e9164b40a
92 changed files with 509 additions and 378 deletions

View File

@@ -118,16 +118,17 @@ class AsyncLLM:
raise ValueError("The lengths of prompts and "
"sampling_params must be the same.")
async def get_output(prompt, sampling_param) -> str:
async def get_output(prompt, sampling_param) -> RequestOutput:
request_id = random_uuid()
results_generator = self.llm_engine.generate(
prompt, sampling_param, request_id)
final_output = None
async for request_output in results_generator:
final_output = request_output
assert final_output is not None
return final_output
outputs = []
outputs: List[RequestOutput] = []
try:
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
@@ -208,8 +209,8 @@ def maybe_assert_ngram_worker(llm):
def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]]]:
tokens = []
token_ids = []
tokens: List[str] = []
token_ids: List[List[int]] = []
for llm in llm_generator():
maybe_assert_ngram_worker(llm)
@@ -300,8 +301,8 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
nvmlInit()
start_time = time.time()
while True:
output = {}
output_raw = {}
output: Dict[int, str] = {}
output_raw: Dict[int, float] = {}
for device in devices:
dev_handle = nvmlDeviceGetHandleByIndex(device)
mem_info = nvmlDeviceGetMemoryInfo(dev_handle)