Update deprecated Python 3.8 typing (#13971)
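
This change applies PEP 585: since Python 3.9 the built-in container
types (list, tuple, dict, set, ...) are subscriptable in annotations,
and the typing.List / typing.Tuple aliases are deprecated. Optional has
no built-in replacement until the X | None union syntax of Python 3.10,
so its import stays. A minimal sketch of the before/after annotation
style (the function below is illustrative only, not taken from this
diff):

    from typing import Optional

    # Before (3.8 style): from typing import List, Optional, Tuple
    #   def first_text(lines: List[Tuple[int, str]]) -> Optional[str]: ...
    # After (3.9+ style): builtin generics; only Optional is still imported.
    def first_text(lines: list[tuple[int, str]]) -> Optional[str]:
        """Return the text of the first (lineno, text) pair, or None."""
        for _, text in lines:
            return text
        return None

    print(first_text([(1, "hello"), (2, "world")]))  # hello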
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import ast
-from typing import List, Optional, Tuple
+from typing import Optional
 
 import numpy as np
 import pytest
@@ -86,7 +86,7 @@ def evaluate_json_response(model_response, golden_response):
 
 def generate(
     llm: vllm.LLM,
-    inputs: Tuple[str, SamplingParams, Optional[LoRARequest]],
+    inputs: tuple[str, SamplingParams, Optional[LoRARequest]],
 ):
     prompts, sampling_param, lora_request = inputs
     outputs = llm.generate(prompts, sampling_param, lora_request=lora_request)
@@ -95,7 +95,7 @@ def generate(
 
 def batched_generate(
     llm: vllm.LLM,
-    inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]],
+    inputs: list[tuple[str, SamplingParams, Optional[LoRARequest]]],
 ):
     for input in inputs:
         prompt, sampling_param, lora_req = input
@@ -164,7 +164,7 @@ def test_batched_rope_kernel(lora_llm, long_context_infos):
     non-batched generation.
     """
     # Create non batched results first to compare against batched results
-    non_batched_results: List[str] = []
+    non_batched_results: list[str] = []
 
     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
@@ -177,7 +177,7 @@ def test_batched_rope_kernel(lora_llm, long_context_infos):
     # Create batched results
     # Each element of the batch must be
     # (prompt, prompt_sampling_params, prompt_lora_request)
-    batched_prompts: List[Tuple[str, SamplingParams,
+    batched_prompts: list[tuple[str, SamplingParams,
                                 Optional[LoRARequest]]] = []
     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
@@ -202,7 +202,7 @@ def test_self_consistency(lora_llm, long_context_infos):
     num_loras = len(long_context_infos)
 
     # Create results in order of long_context_infos
-    batched_prompts: List[Tuple[str, SamplingParams,
+    batched_prompts: list[tuple[str, SamplingParams,
                                 Optional[LoRARequest]]] = []
     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
@@ -251,7 +251,7 @@ def test_quality(lora_llm, long_context_infos):
     The test is expected to run for about 1 minute on a p4de.24xlarge
     instance.
     """
-    scores: List[float] = []
+    scores: list[float] = []
     for lora_id, info in long_context_infos.items():
         context_len = info["context_length"]
         for prompt_and_response in prompts_and_responses[context_len]:
@@ -284,7 +284,7 @@ def test_max_len(lora_llm, long_context_infos):
             generate(lora_llm, (bad_prompt, sampling_params, lora_request))
 
     # Also test batched
-    batched_prompts: List[Tuple[str, SamplingParams,
+    batched_prompts: list[tuple[str, SamplingParams,
                                 Optional[LoRARequest]]] = []
    for lora_id_with_bad_inputs in long_context_infos:
        for lora_id, info in long_context_infos.items():