[CI] Enable mypy import following for vllm/spec_decode (#33282)

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
Lucas Kabela
2026-01-29 22:43:32 -08:00
committed by GitHub
parent d334dd26c4
commit 726d89720c
5 changed files with 38 additions and 20 deletions

View File

@@ -32,7 +32,6 @@ SEPARATE_GROUPS = [
     "vllm/model_executor",
     # v1 related
     "vllm/v1/kv_offload",
-    "vllm/v1/spec_decode",
 ]

 # TODO(woosuk): Include the code from Megatron and HuggingFace.

View File

@@ -5,7 +5,6 @@ from typing import Any

 import torch

 from vllm.config import VllmConfig, get_layers_from_vllm_config
-from vllm.config.speculative import SpeculativeConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.model_loader import get_model
@@ -56,7 +55,7 @@ class DraftModelProposer(SpecDecodeBaseProposer):
         )

     def _raise_if_padded_drafter_batch_disabled(self):
-        if self.vllm_config.speculative_config.disable_padded_drafter_batch:
+        if self.speculative_config.disable_padded_drafter_batch:
             raise NotImplementedError(
                 "Speculative Decoding with draft models only supports "
                 "padded drafter batch. Please don't pass --disable-padded-drafter-batch"
@@ -64,7 +63,7 @@ class DraftModelProposer(SpecDecodeBaseProposer):
         )

     def _raise_if_vocab_size_mismatch(self):
-        self.vllm_config.speculative_config.verify_equal_vocab_size_if_draft_model()
+        self.speculative_config.verify_equal_vocab_size_if_draft_model()

     def _raise_if_draft_tp_mismatch(self):
         # Note(Tomas Ruiz) If we run the target model with TP > 1 and
@@ -73,7 +72,7 @@ class DraftModelProposer(SpecDecodeBaseProposer):
         # (because TP=1), then the torch compile cache is overwritten and corrupted.
         # We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414
         # To prevent this error, we assert that both TP sizes must be the same.
-        spec_cfg: SpeculativeConfig = self.vllm_config.speculative_config
+        spec_cfg = self.speculative_config
         tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
         draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
         if draft_tp != tgt_tp:
@@ -190,12 +189,14 @@ def create_vllm_config_for_draft_model(
     The vllm_config is useful when loading the draft model with get_model().
     """
     old = target_model_vllm_config
-    new_parallel_config = old.speculative_config.draft_parallel_config.replace(
+    assert old.speculative_config is not None, "speculative_config is not set"
+    old_spec_config = old.speculative_config
+    new_parallel_config = old_spec_config.draft_parallel_config.replace(
         rank=old.parallel_config.rank
     )
     new: VllmConfig = old.replace(
         quant_config=None,  # quant_config is recomputed in __init__()
-        model_config=old.speculative_config.draft_model_config,
+        model_config=old_spec_config.draft_model_config,
         parallel_config=new_parallel_config,
     )
     return new

View File

@@ -3,6 +3,7 @@
 import ast
 from dataclasses import replace
 from importlib.util import find_spec
+from typing import cast

 import numpy as np
 import torch
@@ -20,6 +21,7 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
+from vllm.model_executor.models.interfaces import SupportsMultiModal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.platforms import current_platform
@@ -62,8 +64,8 @@ class SpecDecodeBaseProposer:
         runner=None,
     ):
         self.vllm_config = vllm_config
-        assert vllm_config.speculative_config is not None
         self.speculative_config = vllm_config.speculative_config
+        assert self.speculative_config is not None
         self.draft_model_config = self.speculative_config.draft_model_config
         self.method = self.speculative_config.method
         self.pass_hidden_states_to_model = pass_hidden_states_to_model
@@ -206,6 +208,7 @@ class SpecDecodeBaseProposer:

         # Parse the speculative token tree.
         spec_token_tree = self.speculative_config.speculative_token_tree
+        assert spec_token_tree is not None
         self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
         tree_depth = len(self.tree_choices[-1])
         # Precompute per-level properties of the tree.
@@ -1077,9 +1080,12 @@ class SpecDecodeBaseProposer:
         return model.__class__.__name__

     def load_model(self, target_model: nn.Module) -> None:
-        draft_model_config = self.vllm_config.speculative_config.draft_model_config
+        draft_model_config = self.speculative_config.draft_model_config
         target_attn_layer_names = set(
-            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
+            get_layers_from_vllm_config(
+                self.vllm_config,
+                AttentionLayerBase,  # type: ignore[type-abstract]
+            ).keys()
         )
         # FIXME: support hybrid kv for draft model
         target_indexer_layer_names = set(
@@ -1096,7 +1102,10 @@ class SpecDecodeBaseProposer:
         )
         draft_attn_layer_names = (
-            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
+            get_layers_from_vllm_config(
+                self.vllm_config,
+                AttentionLayerBase,  # type: ignore[type-abstract]
+            ).keys()
             - target_attn_layer_names
         )

         indexer_layers = get_layers_from_vllm_config(
@@ -1136,6 +1145,7 @@ class SpecDecodeBaseProposer:

         if supports_multimodal(target_model):
             # handle multimodality
+            assert hasattr(target_model, "config")
             if self.get_model_name(target_model) in [
                 "Qwen2_5_VLForConditionalGeneration",
                 "Qwen3VLForConditionalGeneration",
@@ -1152,16 +1162,21 @@ class SpecDecodeBaseProposer:
                 self.model.config.image_token_index = (
                     target_model.config.image_token_index
                 )
-            target_language_model = target_model.get_language_model()
+            target_language_model = cast(
+                SupportsMultiModal, target_model
+            ).get_language_model()
         else:
             target_language_model = target_model

         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1:
-            if hasattr(target_language_model.model, "embed_tokens"):
-                target_embed_tokens = target_language_model.model.embed_tokens
-            elif hasattr(target_language_model.model, "embedding"):
-                target_embed_tokens = target_language_model.model.embedding
+            inner_model = getattr(target_language_model, "model", None)
+            if inner_model is None:
+                raise AttributeError("Target model does not have 'model' attribute")
+            if hasattr(inner_model, "embed_tokens"):
+                target_embed_tokens = inner_model.embed_tokens
+            elif hasattr(inner_model, "embedding"):
+                target_embed_tokens = inner_model.embedding
             else:
                 raise AttributeError(
                     "Target model does not have 'embed_tokens' or 'embedding' attribute"

View File

@@ -27,11 +27,13 @@ class MedusaProposer:
     ):
         # Save config parameters
         self.vllm_config = vllm_config
+        assert vllm_config.speculative_config is not None, (
+            "Speculative config must be set"
+        )
+        self.spec_config = vllm_config.speculative_config
         self.device = device
         self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
-        self.hidden_size = (
-            vllm_config.speculative_config.draft_model_config.get_hidden_size()
-        )
+        self.hidden_size = self.spec_config.draft_model_config.get_hidden_size()
         self.dtype = vllm_config.model_config.dtype

     def propose(
@@ -58,7 +60,7 @@ class MedusaProposer:
         with set_model_tag("medusa_head"):
             self.model = get_model(
                 vllm_config=self.vllm_config,
-                model_config=self.vllm_config.speculative_config.draft_model_config,
+                model_config=self.spec_config.draft_model_config,
             )
         assert not (
             is_mixture_of_experts(self.model)

View File

@@ -15,6 +15,7 @@ class SuffixDecodingProposer:
     def __init__(self, vllm_config: VllmConfig):
         config = vllm_config.speculative_config
+        assert config is not None, "Speculative config must be set"
        self.num_speculative_tokens = config.num_speculative_tokens
         self.max_tree_depth = config.suffix_decoding_max_tree_depth
         self.max_spec_factor = config.suffix_decoding_max_spec_factor