[CI] Enable mypy import following for vllm/spec_decode (#33282)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
@@ -32,7 +32,6 @@ SEPARATE_GROUPS = [
|
||||
"vllm/model_executor",
|
||||
# v1 related
|
||||
"vllm/v1/kv_offload",
|
||||
"vllm/v1/spec_decode",
|
||||
]
|
||||
|
||||
# TODO(woosuk): Include the code from Megatron and HuggingFace.
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.config.speculative import SpeculativeConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
@@ -56,7 +55,7 @@ class DraftModelProposer(SpecDecodeBaseProposer):
|
||||
)
|
||||
|
||||
def _raise_if_padded_drafter_batch_disabled(self):
|
||||
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
|
||||
if self.speculative_config.disable_padded_drafter_batch:
|
||||
raise NotImplementedError(
|
||||
"Speculative Decoding with draft models only supports "
|
||||
"padded drafter batch. Please don't pass --disable-padded-drafter-batch"
|
||||
@@ -64,7 +63,7 @@ class DraftModelProposer(SpecDecodeBaseProposer):
|
||||
)
|
||||
|
||||
def _raise_if_vocab_size_mismatch(self):
|
||||
self.vllm_config.speculative_config.verify_equal_vocab_size_if_draft_model()
|
||||
self.speculative_config.verify_equal_vocab_size_if_draft_model()
|
||||
|
||||
def _raise_if_draft_tp_mismatch(self):
|
||||
# Note(Tomas Ruiz) If we run the target model with TP > 1 and
|
||||
@@ -73,7 +72,7 @@ class DraftModelProposer(SpecDecodeBaseProposer):
|
||||
# (because TP=1), then the torch compile cache is overwritten and corrupted.
|
||||
# We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414
|
||||
# To prevent this error, we assert that both TP sizes must be the same.
|
||||
spec_cfg: SpeculativeConfig = self.vllm_config.speculative_config
|
||||
spec_cfg = self.speculative_config
|
||||
tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
|
||||
draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
|
||||
if draft_tp != tgt_tp:
|
||||
@@ -190,12 +189,14 @@ def create_vllm_config_for_draft_model(
|
||||
The vllm_config is useful when loading the draft model with get_model().
|
||||
"""
|
||||
old = target_model_vllm_config
|
||||
new_parallel_config = old.speculative_config.draft_parallel_config.replace(
|
||||
assert old.speculative_config is not None, "speculative_config is not set"
|
||||
old_spec_config = old.speculative_config
|
||||
new_parallel_config = old_spec_config.draft_parallel_config.replace(
|
||||
rank=old.parallel_config.rank
|
||||
)
|
||||
new: VllmConfig = old.replace(
|
||||
quant_config=None, # quant_config is recomputed in __init__()
|
||||
model_config=old.speculative_config.draft_model_config,
|
||||
model_config=old_spec_config.draft_model_config,
|
||||
parallel_config=new_parallel_config,
|
||||
)
|
||||
return new
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import ast
|
||||
from dataclasses import replace
|
||||
from importlib.util import find_spec
|
||||
from typing import cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -20,6 +21,7 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.models import supports_multimodal
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.platforms import current_platform
|
||||
@@ -62,8 +64,8 @@ class SpecDecodeBaseProposer:
|
||||
runner=None,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
assert vllm_config.speculative_config is not None
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
assert self.speculative_config is not None
|
||||
self.draft_model_config = self.speculative_config.draft_model_config
|
||||
self.method = self.speculative_config.method
|
||||
self.pass_hidden_states_to_model = pass_hidden_states_to_model
|
||||
@@ -206,6 +208,7 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
# Parse the speculative token tree.
|
||||
spec_token_tree = self.speculative_config.speculative_token_tree
|
||||
assert spec_token_tree is not None
|
||||
self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
|
||||
tree_depth = len(self.tree_choices[-1])
|
||||
# Precompute per-level properties of the tree.
|
||||
@@ -1077,9 +1080,12 @@ class SpecDecodeBaseProposer:
|
||||
return model.__class__.__name__
|
||||
|
||||
def load_model(self, target_model: nn.Module) -> None:
|
||||
draft_model_config = self.vllm_config.speculative_config.draft_model_config
|
||||
draft_model_config = self.speculative_config.draft_model_config
|
||||
target_attn_layer_names = set(
|
||||
get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
|
||||
get_layers_from_vllm_config(
|
||||
self.vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
).keys()
|
||||
)
|
||||
# FIXME: support hybrid kv for draft model
|
||||
target_indexer_layer_names = set(
|
||||
@@ -1096,7 +1102,10 @@ class SpecDecodeBaseProposer:
|
||||
)
|
||||
|
||||
draft_attn_layer_names = (
|
||||
get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
|
||||
get_layers_from_vllm_config(
|
||||
self.vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
).keys()
|
||||
- target_attn_layer_names
|
||||
)
|
||||
indexer_layers = get_layers_from_vllm_config(
|
||||
@@ -1136,6 +1145,7 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
if supports_multimodal(target_model):
|
||||
# handle multimodality
|
||||
assert hasattr(target_model, "config")
|
||||
if self.get_model_name(target_model) in [
|
||||
"Qwen2_5_VLForConditionalGeneration",
|
||||
"Qwen3VLForConditionalGeneration",
|
||||
@@ -1152,16 +1162,21 @@ class SpecDecodeBaseProposer:
|
||||
self.model.config.image_token_index = (
|
||||
target_model.config.image_token_index
|
||||
)
|
||||
target_language_model = target_model.get_language_model()
|
||||
target_language_model = cast(
|
||||
SupportsMultiModal, target_model
|
||||
).get_language_model()
|
||||
else:
|
||||
target_language_model = target_model
|
||||
|
||||
# share embed_tokens with the target model if needed
|
||||
if get_pp_group().world_size == 1:
|
||||
if hasattr(target_language_model.model, "embed_tokens"):
|
||||
target_embed_tokens = target_language_model.model.embed_tokens
|
||||
elif hasattr(target_language_model.model, "embedding"):
|
||||
target_embed_tokens = target_language_model.model.embedding
|
||||
inner_model = getattr(target_language_model, "model", None)
|
||||
if inner_model is None:
|
||||
raise AttributeError("Target model does not have 'model' attribute")
|
||||
if hasattr(inner_model, "embed_tokens"):
|
||||
target_embed_tokens = inner_model.embed_tokens
|
||||
elif hasattr(inner_model, "embedding"):
|
||||
target_embed_tokens = inner_model.embedding
|
||||
else:
|
||||
raise AttributeError(
|
||||
"Target model does not have 'embed_tokens' or 'embedding' attribute"
|
||||
|
||||
@@ -27,11 +27,13 @@ class MedusaProposer:
|
||||
):
|
||||
# Save config parameters
|
||||
self.vllm_config = vllm_config
|
||||
assert vllm_config.speculative_config is not None, (
|
||||
"Speculative config must be set"
|
||||
)
|
||||
self.spec_config = vllm_config.speculative_config
|
||||
self.device = device
|
||||
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
self.hidden_size = (
|
||||
vllm_config.speculative_config.draft_model_config.get_hidden_size()
|
||||
)
|
||||
self.hidden_size = self.spec_config.draft_model_config.get_hidden_size()
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
|
||||
def propose(
|
||||
@@ -58,7 +60,7 @@ class MedusaProposer:
|
||||
with set_model_tag("medusa_head"):
|
||||
self.model = get_model(
|
||||
vllm_config=self.vllm_config,
|
||||
model_config=self.vllm_config.speculative_config.draft_model_config,
|
||||
model_config=self.spec_config.draft_model_config,
|
||||
)
|
||||
assert not (
|
||||
is_mixture_of_experts(self.model)
|
||||
|
||||
@@ -15,6 +15,7 @@ class SuffixDecodingProposer:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
config = vllm_config.speculative_config
|
||||
assert config is not None, "Speculative config must be set"
|
||||
self.num_speculative_tokens = config.num_speculative_tokens
|
||||
self.max_tree_depth = config.suffix_decoding_max_tree_depth
|
||||
self.max_spec_factor = config.suffix_decoding_max_spec_factor
|
||||
|
||||
Reference in New Issue
Block a user