[mypy] Enable type checking for test directory (#5017)
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from itertools import count
|
||||
from typing import Dict, Iterable, List, Optional, Union
|
||||
from typing import Callable, Dict, List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import TypeVar, Union
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
@@ -14,6 +16,8 @@ from vllm.utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
T = TypeVar("T", bound=Worker)
|
||||
|
||||
|
||||
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
|
||||
return (seq_len + block_size - 1) // block_size
|
||||
@@ -56,13 +60,13 @@ def zero_kv_cache(cache_engine: CacheEngine):
|
||||
value_blocks.zero_()
|
||||
|
||||
|
||||
def create_worker(cls: type,
|
||||
def create_worker(cls: Callable[..., T],
|
||||
model_name: str,
|
||||
block_size: int,
|
||||
num_gpu_blocks: int,
|
||||
seed: int,
|
||||
is_driver_worker: bool = True,
|
||||
enforce_eager: bool = True):
|
||||
enforce_eager: bool = True) -> T:
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
seed=seed,
|
||||
@@ -159,8 +163,8 @@ def assert_logprobs_dict_allclose(
|
||||
|
||||
def create_sampler_output_list(
|
||||
token_ids: torch.Tensor,
|
||||
probs: Iterable[Optional[torch.Tensor]],
|
||||
logprobs: Iterable[Optional[torch.Tensor]],
|
||||
probs: GenericSequence[Optional[torch.Tensor]],
|
||||
logprobs: GenericSequence[Optional[torch.Tensor]],
|
||||
seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
|
||||
num_steps, batch_size = token_ids.shape
|
||||
token_ids_by_step = token_ids.tolist()
|
||||
|
||||
Reference in New Issue
Block a user