[mypy] Enable type checking for test directory (#5017)
This commit is contained in:
@@ -118,16 +118,17 @@ class AsyncLLM:
|
||||
raise ValueError("The lengths of prompts and "
|
||||
"sampling_params must be the same.")
|
||||
|
||||
async def get_output(prompt, sampling_param) -> str:
|
||||
async def get_output(prompt, sampling_param) -> RequestOutput:
|
||||
request_id = random_uuid()
|
||||
results_generator = self.llm_engine.generate(
|
||||
prompt, sampling_param, request_id)
|
||||
final_output = None
|
||||
async for request_output in results_generator:
|
||||
final_output = request_output
|
||||
assert final_output is not None
|
||||
return final_output
|
||||
|
||||
outputs = []
|
||||
outputs: List[RequestOutput] = []
|
||||
try:
|
||||
for i in range(num_requests):
|
||||
prompt = prompts[i] if prompts is not None else None
|
||||
@@ -208,8 +209,8 @@ def maybe_assert_ngram_worker(llm):
|
||||
def get_output_from_llm_generator(
|
||||
llm_generator, prompts,
|
||||
sampling_params) -> Tuple[List[str], List[List[int]]]:
|
||||
tokens = []
|
||||
token_ids = []
|
||||
tokens: List[str] = []
|
||||
token_ids: List[List[int]] = []
|
||||
for llm in llm_generator():
|
||||
maybe_assert_ngram_worker(llm)
|
||||
|
||||
@@ -300,8 +301,8 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
|
||||
nvmlInit()
|
||||
start_time = time.time()
|
||||
while True:
|
||||
output = {}
|
||||
output_raw = {}
|
||||
output: Dict[int, str] = {}
|
||||
output_raw: Dict[int, float] = {}
|
||||
for device in devices:
|
||||
dev_handle = nvmlDeviceGetHandleByIndex(device)
|
||||
mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -38,14 +40,14 @@ def test_get_token_ids_to_score(k: int):
|
||||
device='cuda',
|
||||
)
|
||||
|
||||
expected_output = [
|
||||
expected_output: List[List[int]] = [
|
||||
[],
|
||||
]
|
||||
for i in range(proposal_token_ids.shape[0]):
|
||||
expected_output.append(proposal_token_ids[:i + 1].tolist())
|
||||
|
||||
scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
|
||||
actual_output = scorer._get_token_ids_to_score(proposal_token_ids) # pylint: disable=protected-access
|
||||
actual_output = scorer._get_token_ids_to_score(proposal_token_ids.tolist()) # pylint: disable=protected-access
|
||||
|
||||
actual_output = [
|
||||
x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import random
|
||||
from typing import Dict, List
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
from vllm.worker.worker import Worker
|
||||
@@ -210,7 +211,7 @@ def test_same_output_for_multi_step():
|
||||
|
||||
# Run single-step repeatedly.
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
single_step_output = []
|
||||
single_step_output: List[SamplerOutput] = []
|
||||
continuations = [[1] for _ in prompts]
|
||||
set_random_seed(seed)
|
||||
|
||||
@@ -232,11 +233,15 @@ def test_same_output_for_multi_step():
|
||||
continuations[i].append(seq_group_output.samples[0].output_token)
|
||||
|
||||
# Get token ids and logprobs for comparison.
|
||||
multi_step_output_logprobs = [[] for _ in prompts]
|
||||
single_step_output_logprobs = [[] for _ in prompts]
|
||||
multi_step_output_logprobs: List[List[Dict[int,
|
||||
Logprob]]] = [[]
|
||||
for _ in prompts]
|
||||
single_step_output_logprobs: List[List[Dict[int,
|
||||
Logprob]]] = [[]
|
||||
for _ in prompts]
|
||||
|
||||
multi_step_output_token_ids = [[] for _ in prompts]
|
||||
single_step_output_token_ids = [[] for _ in prompts]
|
||||
multi_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
|
||||
single_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
|
||||
for i, _ in enumerate(prompts):
|
||||
for multi_step, single_step in zip(multi_step_output,
|
||||
single_step_output):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import random
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, List
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -7,7 +8,7 @@ import torch
|
||||
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
|
||||
SpecDecodeWorkerMetrics)
|
||||
@@ -103,7 +104,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k))
|
||||
|
||||
seen_contexts = []
|
||||
seen_contexts: List[List[int]] = []
|
||||
|
||||
call_args_list = target_worker.execute_model.call_args_list
|
||||
assert len(call_args_list) == 1
|
||||
@@ -116,7 +117,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
|
||||
for seq_data in seq_group_metadata.seq_data.values():
|
||||
seen_contexts.append(seq_data.get_token_ids())
|
||||
|
||||
expected_seen_contexts = []
|
||||
expected_seen_contexts: List[List[int]] = []
|
||||
|
||||
for prompt, prev_generated, draft_tokens in zip(
|
||||
prompts, prev_output_tokens, proposal_token_ids.tolist()):
|
||||
@@ -310,8 +311,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
|
||||
next(iter(seq_group_metadata.seq_data.keys()))
|
||||
for seq_group_metadata in seq_group_metadata_list
|
||||
]
|
||||
actual_output_by_seq = {seq_id: [] for seq_id in seq_ids}
|
||||
expected_output_by_seq = {seq_id: [] for seq_id in seq_ids}
|
||||
actual_output_by_seq: Dict[int, List[SequenceOutput]] = {
|
||||
seq_id: []
|
||||
for seq_id in seq_ids
|
||||
}
|
||||
expected_output_by_seq: Dict[int, List[SequenceOutput]] = {
|
||||
seq_id: []
|
||||
for seq_id in seq_ids
|
||||
}
|
||||
|
||||
for step in output:
|
||||
for seq_group in step:
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from itertools import count
|
||||
from typing import Dict, Iterable, List, Optional, Union
|
||||
from typing import Callable, Dict, List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import TypeVar, Union
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
@@ -14,6 +16,8 @@ from vllm.utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
T = TypeVar("T", bound=Worker)
|
||||
|
||||
|
||||
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
|
||||
return (seq_len + block_size - 1) // block_size
|
||||
@@ -56,13 +60,13 @@ def zero_kv_cache(cache_engine: CacheEngine):
|
||||
value_blocks.zero_()
|
||||
|
||||
|
||||
def create_worker(cls: type,
|
||||
def create_worker(cls: Callable[..., T],
|
||||
model_name: str,
|
||||
block_size: int,
|
||||
num_gpu_blocks: int,
|
||||
seed: int,
|
||||
is_driver_worker: bool = True,
|
||||
enforce_eager: bool = True):
|
||||
enforce_eager: bool = True) -> T:
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
seed=seed,
|
||||
@@ -159,8 +163,8 @@ def assert_logprobs_dict_allclose(
|
||||
|
||||
def create_sampler_output_list(
|
||||
token_ids: torch.Tensor,
|
||||
probs: Iterable[Optional[torch.Tensor]],
|
||||
logprobs: Iterable[Optional[torch.Tensor]],
|
||||
probs: GenericSequence[Optional[torch.Tensor]],
|
||||
logprobs: GenericSequence[Optional[torch.Tensor]],
|
||||
seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
|
||||
num_steps, batch_size = token_ids.shape
|
||||
token_ids_by_step = token_ids.tolist()
|
||||
|
||||
Reference in New Issue
Block a user