[mypy] Enable type checking for test directory (#5017)

Cyrus Leung
2024-06-15 12:45:31 +08:00
committed by GitHub
parent 1b8a0d71cf
commit 0e9164b40a
92 changed files with 509 additions and 378 deletions
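The hunks below annotate bare parameters and empty-list literals in the long-context LoRA tests so that mypy can check them. As a minimal sketch of the pattern (the `SamplingParams` and `LoRARequest` classes below are stand-ins for the real vLLM types, not the actual implementations):

```python
from typing import List, Optional, Tuple


class SamplingParams:  # stand-in for vllm.SamplingParams
    pass


class LoRARequest:  # stand-in for vllm.lora.request.LoRARequest
    pass


# mypy cannot infer an element type for a bare empty list and reports
# "Need type annotation" (error code: var-annotated), so the tests
# annotate the literal at its definition site:
scores: List[float] = []
scores.append(0.95)

# The same pattern covers richer element types, e.g. the
# (prompt, sampling params, LoRA request) triples built in the batched tests:
batched_prompts: List[Tuple[str, SamplingParams, Optional[LoRARequest]]] = []
batched_prompts.append(("a prompt", SamplingParams(), None))
```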

@@ -77,7 +77,7 @@ def evaluate_json_response(model_response, golden_response):
 def generate(
-    llm,
+    llm: vllm.LLM,
     inputs: Tuple[str, SamplingParams, Optional[LoRARequest]],
 ):
     prompts, sampling_param, lora_request = inputs
@@ -159,7 +159,7 @@ def test_batched_rope_kernel(lora_llm, long_context_infos):
     non-batched generation.
     """
     # Create non batched results first to compare against batched results
-    non_batched_results = []
+    non_batched_results: List[str] = []
     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
@@ -172,7 +172,8 @@ def test_batched_rope_kernel(lora_llm, long_context_infos):
     # Create batched results
     # Each element of the batch must be
     # (prompt, prompt_sampling_params, prompt_lora_request)
-    batched_prompts = []
+    batched_prompts: List[Tuple[str, SamplingParams,
+                                Optional[LoRARequest]]] = []
     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
         batched_prompts.extend([
@@ -196,7 +197,8 @@ def test_self_consistency(lora_llm, long_context_infos):
     num_loras = len(long_context_infos)
     # Create results in order of long_context_infos
-    batched_prompts = []
+    batched_prompts: List[Tuple[str, SamplingParams,
+                                Optional[LoRARequest]]] = []
     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
         batched_prompts.extend([
@@ -244,7 +246,7 @@ def test_quality(lora_llm, long_context_infos):
     The test is expected to run for about 1 minute on a p4de.24xlarge
     instance.
     """
-    scores = []
+    scores: List[float] = []
     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
         for prompt_and_response in prompts_and_responses[context_len]:
@@ -277,7 +279,8 @@ def test_max_len(lora_llm, long_context_infos):
         generate(lora_llm, (bad_prompt, sampling_params, lora_request))
     # Also test batched
-    batched_prompts = []
+    batched_prompts: List[Tuple[str, SamplingParams,
+                                Optional[LoRARequest]]] = []
     for lora_id_with_bad_inputs in long_context_infos:
         for lora_id, info in long_context_infos.items():
             context_len = info["context_length"]
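With these annotations in place, the test directory can be type-checked like any other package. A minimal sketch of invoking mypy programmatically through its public `mypy.api` entry point is shown below; the path and flag are illustrative assumptions, since vLLM's CI drives mypy through its own scripts and configuration:

```python
# Minimal sketch: run mypy over a test directory programmatically.
# The "tests/" path and --ignore-missing-imports flag are illustrative;
# the project's CI may pass a different configuration.
from mypy import api

stdout, stderr, exit_status = api.run(["tests/", "--ignore-missing-imports"])
print(stdout)
if exit_status != 0:
    raise SystemExit("mypy found type errors in the test directory")
```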