[mypy] Enable type checking for test directory (#5017)

This commit is contained in:
Cyrus Leung
2024-06-15 12:45:31 +08:00
committed by GitHub
parent 1b8a0d71cf
commit 0e9164b40a
92 changed files with 509 additions and 378 deletions

View File

@@ -1,10 +1,10 @@
import copy
import weakref
from typing import List, Tuple
from typing import Dict, List, Tuple
import torch
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
@@ -71,7 +71,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
sample_len)
# Run model sample_len times.
model_outputs = []
model_outputs: List[SamplerOutput] = []
for _ in range(sample_len):
model_output = super().execute_model(
execute_model_req=copied_execute_model_req)
@@ -132,7 +132,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
# Shallow-copy the list of SequenceGroupMetadata. This allows us to
# append tokens and change is_prompt without external side-effects.
new_seq_group_metadata_list = []
new_seq_group_metadata_list: List[SequenceGroupMetadata] = []
for old_seq_group_metadata in seq_group_metadata_list:
# We must shallow-copy seq_group_metadata as is_prompt could change.
@@ -140,7 +140,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
new_seq_group_metadata_list.append(seq_group_metadata)
# We must shallow-copy seq_data as we will append token ids
new_seq_data = {}
new_seq_data: Dict[int, SequenceData] = {}
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
new_seq_data[seq_id] = copy.copy(old_seq_data)
new_seq_data[

View File

@@ -48,7 +48,7 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
) -> Tuple[Optional[List[SamplerOutput]], bool]:
) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]:
"""NGram match algo to pick proposal candidate. Returns the list of
sampler output, one per SequenceGroupMetadata.
@@ -58,8 +58,8 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
self._raise_if_unsupported(execute_model_req)
has_spec_out = False
token_id_list = []
token_prob_list = []
token_id_list: List[Optional[torch.Tensor]] = []
token_prob_list: List[Optional[torch.Tensor]] = []
for idx, seq_group_metadata in enumerate(
execute_model_req.seq_group_metadata_list):
seq_data = next(iter(seq_group_metadata.seq_data.values()))

View File

@@ -7,8 +7,8 @@ from vllm.config import SpeculativeConfig
from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
SamplerOutput, SequenceGroupMetadata)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
@@ -516,13 +516,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
topk_indices_by_step = topk_indices_by_step.tolist()
# Construct the output on a per-step, per-sequence basis.
sampler_output_list = []
sampler_output_list: List[SamplerOutput] = []
for step_index in range(num_steps):
if all(token_id == -1
for token_id in accepted_token_ids_by_step[step_index]):
break
step_output_token_ids = []
step_output_token_ids: List[CompletionSequenceGroupOutput] = []
for sequence_index in range(batch_size):
# Each sequence may have a different num_logprobs; retrieve it.
num_logprobs = num_logprobs_per_seq[sequence_index]

View File

@@ -26,10 +26,10 @@ def get_all_num_logprobs(
sequence.
"""
all_num_logprobs = []
all_num_logprobs: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
num_logprobs = seq_group_metadata.sampling_params.logprobs
if seq_group_metadata.sampling_params.logprobs is None:
if num_logprobs is None:
num_logprobs = 0
all_num_logprobs.append(num_logprobs)