Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -4,7 +4,6 @@
|
||||
Run `pytest tests/models/test_transformers.py`.
|
||||
"""
|
||||
from contextlib import nullcontext
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -14,8 +13,8 @@ from .utils import check_logprobs_close
|
||||
|
||||
|
||||
def check_implementation(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
example_prompts: list[str],
|
||||
model: str,
|
||||
**kwargs,
|
||||
@@ -47,8 +46,8 @@ def check_implementation(
|
||||
("ArthurZ/Ilama-3.2-1B", "auto"), # CUSTOM CODE
|
||||
]) # trust_remote_code=True by default
|
||||
def test_models(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
example_prompts: list[str],
|
||||
model: str,
|
||||
model_impl: str,
|
||||
@@ -71,8 +70,8 @@ def test_models(
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_distributed(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
example_prompts,
|
||||
):
|
||||
kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2}
|
||||
@@ -92,7 +91,7 @@ def test_distributed(
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_quantization(
|
||||
vllm_runner: Type[VllmRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
example_prompts: list[str],
|
||||
model: str,
|
||||
quantization_kwargs: dict[str, str],
|
||||
|
||||
Reference in New Issue
Block a user