[mypy] Enable type checking for test directory (#5017)
This commit is contained in:
@@ -77,7 +77,7 @@ def evaluate_json_response(model_response, golden_response):
 
 
 def generate(
-    llm,
+    llm: vllm.LLM,
     inputs: Tuple[str, SamplingParams, Optional[LoRARequest]],
 ):
     prompts, sampling_param, lora_request = inputs
@@ -159,7 +159,7 @@ def test_batched_rope_kernel(lora_llm, long_context_infos):
     non-batched generation.
     """
     # Create non batched results first to compare against batched results
-    non_batched_results = []
+    non_batched_results: List[str] = []
 
     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
@@ -172,7 +172,8 @@ def test_batched_rope_kernel(lora_llm, long_context_infos):
     # Create batched results
     # Each element of the batch must be
     # (prompt, prompt_sampling_params, prompt_lora_request)
-    batched_prompts = []
+    batched_prompts: List[Tuple[str, SamplingParams,
+                                Optional[LoRARequest]]] = []
     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
         batched_prompts.extend([
@@ -196,7 +197,8 @@ def test_self_consistency(lora_llm, long_context_infos):
     num_loras = len(long_context_infos)
 
     # Create results in order of long_context_infos
-    batched_prompts = []
+    batched_prompts: List[Tuple[str, SamplingParams,
+                                Optional[LoRARequest]]] = []
     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
         batched_prompts.extend([
@@ -244,7 +246,7 @@ def test_quality(lora_llm, long_context_infos):
     The test is expected to run for about 1 minute on a p4de.24xlarge
     instance.
     """
-    scores = []
+    scores: List[float] = []
     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
         for prompt_and_response in prompts_and_responses[context_len]:
@@ -277,7 +279,8 @@ def test_max_len(lora_llm, long_context_infos):
             generate(lora_llm, (bad_prompt, sampling_params, lora_request))
 
     # Also test batched
-    batched_prompts = []
+    batched_prompts: List[Tuple[str, SamplingParams,
+                                Optional[LoRARequest]]] = []
     for lora_id_with_bad_inputs in long_context_infos:
         for lora_id, info in long_context_infos.items():
             context_len = info["context_length"]
|
||||
Reference in New Issue
Block a user