[mypy] Enable type checking for test directory (#5017)
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
import copy
|
||||
import weakref
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
|
||||
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
||||
@@ -71,7 +71,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
|
||||
sample_len)
|
||||
|
||||
# Run model sample_len times.
|
||||
model_outputs = []
|
||||
model_outputs: List[SamplerOutput] = []
|
||||
for _ in range(sample_len):
|
||||
model_output = super().execute_model(
|
||||
execute_model_req=copied_execute_model_req)
|
||||
@@ -132,7 +132,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
|
||||
|
||||
# Shallow-copy the list of SequenceGroupMetadata. This allows us to
|
||||
# append tokens and change is_prompt without external side-effects.
|
||||
new_seq_group_metadata_list = []
|
||||
new_seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
|
||||
for old_seq_group_metadata in seq_group_metadata_list:
|
||||
# We must shallow-copy seq_group_metadata as is_prompt could change.
|
||||
@@ -140,7 +140,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
|
||||
new_seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
# We must shallow-copy seq_data as we will append token ids
|
||||
new_seq_data = {}
|
||||
new_seq_data: Dict[int, SequenceData] = {}
|
||||
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
|
||||
new_seq_data[seq_id] = copy.copy(old_seq_data)
|
||||
new_seq_data[
|
||||
|
||||
@@ -48,7 +48,7 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
) -> Tuple[Optional[List[SamplerOutput]], bool]:
|
||||
) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]:
|
||||
"""NGram match algo to pick proposal candidate. Returns the list of
|
||||
sampler output, one per SequenceGroupMetadata.
|
||||
|
||||
@@ -58,8 +58,8 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
|
||||
self._raise_if_unsupported(execute_model_req)
|
||||
|
||||
has_spec_out = False
|
||||
token_id_list = []
|
||||
token_prob_list = []
|
||||
token_id_list: List[Optional[torch.Tensor]] = []
|
||||
token_prob_list: List[Optional[torch.Tensor]] = []
|
||||
for idx, seq_group_metadata in enumerate(
|
||||
execute_model_req.seq_group_metadata_list):
|
||||
seq_data = next(iter(seq_group_metadata.seq_data.values()))
|
||||
|
||||
@@ -7,8 +7,8 @@ from vllm.config import SpeculativeConfig
|
||||
from vllm.distributed.communication_op import broadcast_tensor_dict
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
|
||||
SamplerOutput, SequenceGroupMetadata)
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeScorer, SpeculativeScores)
|
||||
@@ -516,13 +516,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
topk_indices_by_step = topk_indices_by_step.tolist()
|
||||
|
||||
# Construct the output on a per-step, per-sequence basis.
|
||||
sampler_output_list = []
|
||||
sampler_output_list: List[SamplerOutput] = []
|
||||
for step_index in range(num_steps):
|
||||
if all(token_id == -1
|
||||
for token_id in accepted_token_ids_by_step[step_index]):
|
||||
break
|
||||
|
||||
step_output_token_ids = []
|
||||
step_output_token_ids: List[CompletionSequenceGroupOutput] = []
|
||||
for sequence_index in range(batch_size):
|
||||
# Each sequence may have a different num_logprobs; retrieve it.
|
||||
num_logprobs = num_logprobs_per_seq[sequence_index]
|
||||
|
||||
@@ -26,10 +26,10 @@ def get_all_num_logprobs(
|
||||
sequence.
|
||||
"""
|
||||
|
||||
all_num_logprobs = []
|
||||
all_num_logprobs: List[int] = []
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
num_logprobs = seq_group_metadata.sampling_params.logprobs
|
||||
if seq_group_metadata.sampling_params.logprobs is None:
|
||||
if num_logprobs is None:
|
||||
num_logprobs = 0
|
||||
all_num_logprobs.append(num_logprobs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user