[mypy] Enable type checking for test directory (#5017)

This commit is contained in:
Cyrus Leung
2024-06-15 12:45:31 +08:00
committed by GitHub
parent 1b8a0d71cf
commit 0e9164b40a
92 changed files with 509 additions and 378 deletions

View File

@@ -118,16 +118,17 @@ class AsyncLLM:
raise ValueError("The lengths of prompts and "
"sampling_params must be the same.")
async def get_output(prompt, sampling_param) -> str:
async def get_output(prompt, sampling_param) -> RequestOutput:
request_id = random_uuid()
results_generator = self.llm_engine.generate(
prompt, sampling_param, request_id)
final_output = None
async for request_output in results_generator:
final_output = request_output
assert final_output is not None
return final_output
outputs = []
outputs: List[RequestOutput] = []
try:
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
@@ -208,8 +209,8 @@ def maybe_assert_ngram_worker(llm):
def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]]]:
tokens = []
token_ids = []
tokens: List[str] = []
token_ids: List[List[int]] = []
for llm in llm_generator():
maybe_assert_ngram_worker(llm)
@@ -300,8 +301,8 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
nvmlInit()
start_time = time.time()
while True:
output = {}
output_raw = {}
output: Dict[int, str] = {}
output_raw: Dict[int, float] = {}
for device in devices:
dev_handle = nvmlDeviceGetHandleByIndex(device)
mem_info = nvmlDeviceGetMemoryInfo(dev_handle)

View File

@@ -1,3 +1,5 @@
from typing import List
import pytest
import torch
@@ -38,14 +40,14 @@ def test_get_token_ids_to_score(k: int):
device='cuda',
)
expected_output = [
expected_output: List[List[int]] = [
[],
]
for i in range(proposal_token_ids.shape[0]):
expected_output.append(proposal_token_ids[:i + 1].tolist())
scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
actual_output = scorer._get_token_ids_to_score(proposal_token_ids) # pylint: disable=protected-access
actual_output = scorer._get_token_ids_to_score(proposal_token_ids.tolist()) # pylint: disable=protected-access
actual_output = [
x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output

View File

@@ -1,11 +1,12 @@
import random
from typing import Dict, List
from unittest.mock import MagicMock
import pytest
import torch
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
@@ -210,7 +211,7 @@ def test_same_output_for_multi_step():
# Run single-step repeatedly.
zero_kv_cache(worker.cache_engine)
single_step_output = []
single_step_output: List[SamplerOutput] = []
continuations = [[1] for _ in prompts]
set_random_seed(seed)
@@ -232,11 +233,15 @@ def test_same_output_for_multi_step():
continuations[i].append(seq_group_output.samples[0].output_token)
# Get token ids and logprobs for comparison.
multi_step_output_logprobs = [[] for _ in prompts]
single_step_output_logprobs = [[] for _ in prompts]
multi_step_output_logprobs: List[List[Dict[int,
Logprob]]] = [[]
for _ in prompts]
single_step_output_logprobs: List[List[Dict[int,
Logprob]]] = [[]
for _ in prompts]
multi_step_output_token_ids = [[] for _ in prompts]
single_step_output_token_ids = [[] for _ in prompts]
multi_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
single_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
for i, _ in enumerate(prompts):
for multi_step, single_step in zip(multi_step_output,
single_step_output):

View File

@@ -1,5 +1,6 @@
import random
from types import SimpleNamespace
from typing import Dict, List
from unittest.mock import MagicMock
import pytest
@@ -7,7 +8,7 @@ import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics)
@@ -103,7 +104,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
seen_contexts = []
seen_contexts: List[List[int]] = []
call_args_list = target_worker.execute_model.call_args_list
assert len(call_args_list) == 1
@@ -116,7 +117,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
for seq_data in seq_group_metadata.seq_data.values():
seen_contexts.append(seq_data.get_token_ids())
expected_seen_contexts = []
expected_seen_contexts: List[List[int]] = []
for prompt, prev_generated, draft_tokens in zip(
prompts, prev_output_tokens, proposal_token_ids.tolist()):
@@ -310,8 +311,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
next(iter(seq_group_metadata.seq_data.keys()))
for seq_group_metadata in seq_group_metadata_list
]
actual_output_by_seq = {seq_id: [] for seq_id in seq_ids}
expected_output_by_seq = {seq_id: [] for seq_id in seq_ids}
actual_output_by_seq: Dict[int, List[SequenceOutput]] = {
seq_id: []
for seq_id in seq_ids
}
expected_output_by_seq: Dict[int, List[SequenceOutput]] = {
seq_id: []
for seq_id in seq_ids
}
for step in output:
for seq_group in step:

View File

@@ -1,5 +1,7 @@
from itertools import count
from typing import Dict, Iterable, List, Optional, Union
from typing import Callable, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import TypeVar, Union
from unittest.mock import MagicMock
import torch
@@ -14,6 +16,8 @@ from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker
T = TypeVar("T", bound=Worker)
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
return (seq_len + block_size - 1) // block_size
@@ -56,13 +60,13 @@ def zero_kv_cache(cache_engine: CacheEngine):
value_blocks.zero_()
def create_worker(cls: type,
def create_worker(cls: Callable[..., T],
model_name: str,
block_size: int,
num_gpu_blocks: int,
seed: int,
is_driver_worker: bool = True,
enforce_eager: bool = True):
enforce_eager: bool = True) -> T:
engine_args = EngineArgs(
model=model_name,
seed=seed,
@@ -159,8 +163,8 @@ def assert_logprobs_dict_allclose(
def create_sampler_output_list(
token_ids: torch.Tensor,
probs: Iterable[Optional[torch.Tensor]],
logprobs: Iterable[Optional[torch.Tensor]],
probs: GenericSequence[Optional[torch.Tensor]],
logprobs: GenericSequence[Optional[torch.Tensor]],
seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
num_steps, batch_size = token_ids.shape
token_ids_by_step = token_ids.tolist()