From 726d89720ca7ad4109c30fe7c5ce44456affc365 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Thu, 29 Jan 2026 22:43:32 -0800 Subject: [PATCH] [CI] Enable mypy import following for `vllm/spec_decode` (#33282) Signed-off-by: Lucas Kabela --- tools/pre_commit/mypy.py | 1 - vllm/v1/spec_decode/draft_model.py | 13 +++++----- vllm/v1/spec_decode/eagle.py | 33 +++++++++++++++++++------- vllm/v1/spec_decode/medusa.py | 10 ++++---- vllm/v1/spec_decode/suffix_decoding.py | 1 + 5 files changed, 38 insertions(+), 20 deletions(-) diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py index 539261cf2..12f6aa327 100755 --- a/tools/pre_commit/mypy.py +++ b/tools/pre_commit/mypy.py @@ -32,7 +32,6 @@ SEPARATE_GROUPS = [ "vllm/model_executor", # v1 related "vllm/v1/kv_offload", - "vllm/v1/spec_decode", ] # TODO(woosuk): Include the code from Megatron and HuggingFace. diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 7d631aa89..9c6754013 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -5,7 +5,6 @@ from typing import Any import torch from vllm.config import VllmConfig, get_layers_from_vllm_config -from vllm.config.speculative import SpeculativeConfig from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention from vllm.model_executor.model_loader import get_model @@ -56,7 +55,7 @@ class DraftModelProposer(SpecDecodeBaseProposer): ) def _raise_if_padded_drafter_batch_disabled(self): - if self.vllm_config.speculative_config.disable_padded_drafter_batch: + if self.speculative_config.disable_padded_drafter_batch: raise NotImplementedError( "Speculative Decoding with draft models only supports " "padded drafter batch. Please don't pass --disable-padded-drafter-batch" @@ -64,7 +63,7 @@ class DraftModelProposer(SpecDecodeBaseProposer): ) def _raise_if_vocab_size_mismatch(self): - self.vllm_config.speculative_config.verify_equal_vocab_size_if_draft_model() + self.speculative_config.verify_equal_vocab_size_if_draft_model() def _raise_if_draft_tp_mismatch(self): # Note(Tomas Ruiz) If we run the target model with TP > 1 and @@ -73,7 +72,7 @@ class DraftModelProposer(SpecDecodeBaseProposer): # (because TP=1), then the torch compile cache is overwritten and corrupted. # We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414 # To prevent this error, we assert that both TP sizes must be the same. - spec_cfg: SpeculativeConfig = self.vllm_config.speculative_config + spec_cfg = self.speculative_config tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size if draft_tp != tgt_tp: @@ -190,12 +189,14 @@ def create_vllm_config_for_draft_model( The vllm_config is useful when loading the draft model with get_model(). """ old = target_model_vllm_config - new_parallel_config = old.speculative_config.draft_parallel_config.replace( + assert old.speculative_config is not None, "speculative_config is not set" + old_spec_config = old.speculative_config + new_parallel_config = old_spec_config.draft_parallel_config.replace( rank=old.parallel_config.rank ) new: VllmConfig = old.replace( quant_config=None, # quant_config is recomputed in __init__() - model_config=old.speculative_config.draft_model_config, + model_config=old_spec_config.draft_model_config, parallel_config=new_parallel_config, ) return new diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 2ebb0ba43..f157f529c 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -3,6 +3,7 @@ import ast from dataclasses import replace from importlib.util import find_spec +from typing import cast import numpy as np import torch @@ -20,6 +21,7 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache +from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.platforms import current_platform @@ -62,8 +64,8 @@ class SpecDecodeBaseProposer: runner=None, ): self.vllm_config = vllm_config + assert vllm_config.speculative_config is not None self.speculative_config = vllm_config.speculative_config - assert self.speculative_config is not None self.draft_model_config = self.speculative_config.draft_model_config self.method = self.speculative_config.method self.pass_hidden_states_to_model = pass_hidden_states_to_model @@ -206,6 +208,7 @@ class SpecDecodeBaseProposer: # Parse the speculative token tree. spec_token_tree = self.speculative_config.speculative_token_tree + assert spec_token_tree is not None self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree) tree_depth = len(self.tree_choices[-1]) # Precompute per-level properties of the tree. @@ -1077,9 +1080,12 @@ class SpecDecodeBaseProposer: return model.__class__.__name__ def load_model(self, target_model: nn.Module) -> None: - draft_model_config = self.vllm_config.speculative_config.draft_model_config + draft_model_config = self.speculative_config.draft_model_config target_attn_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() + get_layers_from_vllm_config( + self.vllm_config, + AttentionLayerBase, # type: ignore[type-abstract] + ).keys() ) # FIXME: support hybrid kv for draft model target_indexer_layer_names = set( @@ -1096,7 +1102,10 @@ class SpecDecodeBaseProposer: ) draft_attn_layer_names = ( - get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() + get_layers_from_vllm_config( + self.vllm_config, + AttentionLayerBase, # type: ignore[type-abstract] + ).keys() - target_attn_layer_names ) indexer_layers = get_layers_from_vllm_config( @@ -1136,6 +1145,7 @@ class SpecDecodeBaseProposer: if supports_multimodal(target_model): # handle multimodality + assert hasattr(target_model, "config") if self.get_model_name(target_model) in [ "Qwen2_5_VLForConditionalGeneration", "Qwen3VLForConditionalGeneration", @@ -1152,16 +1162,21 @@ class SpecDecodeBaseProposer: self.model.config.image_token_index = ( target_model.config.image_token_index ) - target_language_model = target_model.get_language_model() + target_language_model = cast( + SupportsMultiModal, target_model + ).get_language_model() else: target_language_model = target_model # share embed_tokens with the target model if needed if get_pp_group().world_size == 1: - if hasattr(target_language_model.model, "embed_tokens"): - target_embed_tokens = target_language_model.model.embed_tokens - elif hasattr(target_language_model.model, "embedding"): - target_embed_tokens = target_language_model.model.embedding + inner_model = getattr(target_language_model, "model", None) + if inner_model is None: + raise AttributeError("Target model does not have 'model' attribute") + if hasattr(inner_model, "embed_tokens"): + target_embed_tokens = inner_model.embed_tokens + elif hasattr(inner_model, "embedding"): + target_embed_tokens = inner_model.embedding else: raise AttributeError( "Target model does not have 'embed_tokens' or 'embedding' attribute" diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index 2e9330bf6..80b0f0a98 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -27,11 +27,13 @@ class MedusaProposer: ): # Save config parameters self.vllm_config = vllm_config + assert vllm_config.speculative_config is not None, ( + "Speculative config must be set" + ) + self.spec_config = vllm_config.speculative_config self.device = device self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens - self.hidden_size = ( - vllm_config.speculative_config.draft_model_config.get_hidden_size() - ) + self.hidden_size = self.spec_config.draft_model_config.get_hidden_size() self.dtype = vllm_config.model_config.dtype def propose( @@ -58,7 +60,7 @@ class MedusaProposer: with set_model_tag("medusa_head"): self.model = get_model( vllm_config=self.vllm_config, - model_config=self.vllm_config.speculative_config.draft_model_config, + model_config=self.spec_config.draft_model_config, ) assert not ( is_mixture_of_experts(self.model) diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py index c5f8e6f86..fee5d9746 100644 --- a/vllm/v1/spec_decode/suffix_decoding.py +++ b/vllm/v1/spec_decode/suffix_decoding.py @@ -15,6 +15,7 @@ class SuffixDecodingProposer: def __init__(self, vllm_config: VllmConfig): config = vllm_config.speculative_config + assert config is not None, "Speculative config must be set" self.num_speculative_tokens = config.num_speculative_tokens self.max_tree_depth = config.suffix_decoding_max_tree_depth self.max_spec_factor = config.suffix_decoding_max_spec_factor