Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import random
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -47,9 +47,9 @@ def vllm_model_apc(vllm_runner, monkeypatch):
|
||||
|
||||
|
||||
def _get_test_sampling_params(
|
||||
prompt_list: List[str],
|
||||
prompt_list: list[str],
|
||||
seed: Optional[int] = 42,
|
||||
) -> Tuple[List[SamplingParams], List[int]]:
|
||||
) -> tuple[list[SamplingParams], list[int]]:
|
||||
"""Generate random sampling params for a batch."""
|
||||
|
||||
def get_mostly_n_gt1() -> int:
|
||||
@@ -81,7 +81,7 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
|
||||
|
||||
# Validate each request response
|
||||
for out, n in zip(outputs, n_list):
|
||||
completion_counts: Dict[str, int] = {}
|
||||
completion_counts: dict[str, int] = {}
|
||||
# Assert correct number of completions
|
||||
assert len(out.outputs) == n, (
|
||||
f"{len(out.outputs)} completions; {n} expected.")
|
||||
|
||||
Reference in New Issue
Block a user