[CI] Enable mypy import following for vllm/spec_decode (#33282)

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
Lucas Kabela
2026-01-29 22:43:32 -08:00
committed by GitHub
parent d334dd26c4
commit 726d89720c
5 changed files with 38 additions and 20 deletions

View File

@@ -32,7 +32,6 @@ SEPARATE_GROUPS = [
"vllm/model_executor",
# v1 related
"vllm/v1/kv_offload",
"vllm/v1/spec_decode",
]
# TODO(woosuk): Include the code from Megatron and HuggingFace.

View File

@@ -5,7 +5,6 @@ from typing import Any
import torch
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.speculative import SpeculativeConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.model_loader import get_model
@@ -56,7 +55,7 @@ class DraftModelProposer(SpecDecodeBaseProposer):
)
def _raise_if_padded_drafter_batch_disabled(self):
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
if self.speculative_config.disable_padded_drafter_batch:
raise NotImplementedError(
"Speculative Decoding with draft models only supports "
"padded drafter batch. Please don't pass --disable-padded-drafter-batch"
@@ -64,7 +63,7 @@ class DraftModelProposer(SpecDecodeBaseProposer):
)
def _raise_if_vocab_size_mismatch(self):
self.vllm_config.speculative_config.verify_equal_vocab_size_if_draft_model()
self.speculative_config.verify_equal_vocab_size_if_draft_model()
def _raise_if_draft_tp_mismatch(self):
# Note(Tomas Ruiz) If we run the target model with TP > 1 and
@@ -73,7 +72,7 @@ class DraftModelProposer(SpecDecodeBaseProposer):
# (because TP=1), then the torch compile cache is overwritten and corrupted.
# We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414
# To prevent this error, we assert that both TP sizes must be the same.
spec_cfg: SpeculativeConfig = self.vllm_config.speculative_config
spec_cfg = self.speculative_config
tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
if draft_tp != tgt_tp:
@@ -190,12 +189,14 @@ def create_vllm_config_for_draft_model(
The vllm_config is useful when loading the draft model with get_model().
"""
old = target_model_vllm_config
new_parallel_config = old.speculative_config.draft_parallel_config.replace(
assert old.speculative_config is not None, "speculative_config is not set"
old_spec_config = old.speculative_config
new_parallel_config = old_spec_config.draft_parallel_config.replace(
rank=old.parallel_config.rank
)
new: VllmConfig = old.replace(
quant_config=None, # quant_config is recomputed in __init__()
model_config=old.speculative_config.draft_model_config,
model_config=old_spec_config.draft_model_config,
parallel_config=new_parallel_config,
)
return new

View File

@@ -3,6 +3,7 @@
import ast
from dataclasses import replace
from importlib.util import find_spec
from typing import cast
import numpy as np
import torch
@@ -20,6 +21,7 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
@@ -62,8 +64,8 @@ class SpecDecodeBaseProposer:
runner=None,
):
self.vllm_config = vllm_config
assert vllm_config.speculative_config is not None
self.speculative_config = vllm_config.speculative_config
assert self.speculative_config is not None
self.draft_model_config = self.speculative_config.draft_model_config
self.method = self.speculative_config.method
self.pass_hidden_states_to_model = pass_hidden_states_to_model
@@ -206,6 +208,7 @@ class SpecDecodeBaseProposer:
# Parse the speculative token tree.
spec_token_tree = self.speculative_config.speculative_token_tree
assert spec_token_tree is not None
self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
tree_depth = len(self.tree_choices[-1])
# Precompute per-level properties of the tree.
@@ -1077,9 +1080,12 @@ class SpecDecodeBaseProposer:
return model.__class__.__name__
def load_model(self, target_model: nn.Module) -> None:
draft_model_config = self.vllm_config.speculative_config.draft_model_config
draft_model_config = self.speculative_config.draft_model_config
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
get_layers_from_vllm_config(
self.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
).keys()
)
# FIXME: support hybrid kv for draft model
target_indexer_layer_names = set(
@@ -1096,7 +1102,10 @@ class SpecDecodeBaseProposer:
)
draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
get_layers_from_vllm_config(
self.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
).keys()
- target_attn_layer_names
)
indexer_layers = get_layers_from_vllm_config(
@@ -1136,6 +1145,7 @@ class SpecDecodeBaseProposer:
if supports_multimodal(target_model):
# handle multimodality
assert hasattr(target_model, "config")
if self.get_model_name(target_model) in [
"Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
@@ -1152,16 +1162,21 @@ class SpecDecodeBaseProposer:
self.model.config.image_token_index = (
target_model.config.image_token_index
)
target_language_model = target_model.get_language_model()
target_language_model = cast(
SupportsMultiModal, target_model
).get_language_model()
else:
target_language_model = target_model
# share embed_tokens with the target model if needed
if get_pp_group().world_size == 1:
if hasattr(target_language_model.model, "embed_tokens"):
target_embed_tokens = target_language_model.model.embed_tokens
elif hasattr(target_language_model.model, "embedding"):
target_embed_tokens = target_language_model.model.embedding
inner_model = getattr(target_language_model, "model", None)
if inner_model is None:
raise AttributeError("Target model does not have 'model' attribute")
if hasattr(inner_model, "embed_tokens"):
target_embed_tokens = inner_model.embed_tokens
elif hasattr(inner_model, "embedding"):
target_embed_tokens = inner_model.embedding
else:
raise AttributeError(
"Target model does not have 'embed_tokens' or 'embedding' attribute"

View File

@@ -27,11 +27,13 @@ class MedusaProposer:
):
# Save config parameters
self.vllm_config = vllm_config
assert vllm_config.speculative_config is not None, (
"Speculative config must be set"
)
self.spec_config = vllm_config.speculative_config
self.device = device
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.hidden_size = (
vllm_config.speculative_config.draft_model_config.get_hidden_size()
)
self.hidden_size = self.spec_config.draft_model_config.get_hidden_size()
self.dtype = vllm_config.model_config.dtype
def propose(
@@ -58,7 +60,7 @@ class MedusaProposer:
with set_model_tag("medusa_head"):
self.model = get_model(
vllm_config=self.vllm_config,
model_config=self.vllm_config.speculative_config.draft_model_config,
model_config=self.spec_config.draft_model_config,
)
assert not (
is_mixture_of_experts(self.model)

View File

@@ -15,6 +15,7 @@ class SuffixDecodingProposer:
def __init__(self, vllm_config: VllmConfig):
config = vllm_config.speculative_config
assert config is not None, "Speculative config must be set"
self.num_speculative_tokens = config.num_speculative_tokens
self.max_tree_depth = config.suffix_decoding_max_tree_depth
self.max_spec_factor = config.suffix_decoding_max_spec_factor