[mypy] Enable type checking for test directory (#5017)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import random
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, List
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -7,7 +8,7 @@ import torch
|
||||
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
|
||||
SpecDecodeWorkerMetrics)
|
||||
@@ -103,7 +104,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k))
|
||||
|
||||
seen_contexts = []
|
||||
seen_contexts: List[List[int]] = []
|
||||
|
||||
call_args_list = target_worker.execute_model.call_args_list
|
||||
assert len(call_args_list) == 1
|
||||
@@ -116,7 +117,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
|
||||
for seq_data in seq_group_metadata.seq_data.values():
|
||||
seen_contexts.append(seq_data.get_token_ids())
|
||||
|
||||
expected_seen_contexts = []
|
||||
expected_seen_contexts: List[List[int]] = []
|
||||
|
||||
for prompt, prev_generated, draft_tokens in zip(
|
||||
prompts, prev_output_tokens, proposal_token_ids.tolist()):
|
||||
@@ -310,8 +311,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
|
||||
next(iter(seq_group_metadata.seq_data.keys()))
|
||||
for seq_group_metadata in seq_group_metadata_list
|
||||
]
|
||||
actual_output_by_seq = {seq_id: [] for seq_id in seq_ids}
|
||||
expected_output_by_seq = {seq_id: [] for seq_id in seq_ids}
|
||||
actual_output_by_seq: Dict[int, List[SequenceOutput]] = {
|
||||
seq_id: []
|
||||
for seq_id in seq_ids
|
||||
}
|
||||
expected_output_by_seq: Dict[int, List[SequenceOutput]] = {
|
||||
seq_id: []
|
||||
for seq_id in seq_ids
|
||||
}
|
||||
|
||||
for step in output:
|
||||
for seq_group in step:
|
||||
|
||||
Reference in New Issue
Block a user