[CI] Enable mypy import following for vllm/spec_decode (#33282)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
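The change this diff repeats across vllm/v1/spec_decode: VllmConfig.speculative_config is optional, and once mypy follows imports into these modules it flags every unchecked attribute access on it. Below is a minimal, self-contained sketch of the narrowing idiom being applied; the dataclasses and field names are simplified stand-ins, not vLLM's real definitions.

from dataclasses import dataclass


@dataclass
class SpeculativeConfig:
    num_speculative_tokens: int


@dataclass
class VllmConfig:
    # None unless speculative decoding is enabled.
    speculative_config: SpeculativeConfig | None = None


def build_proposer(vllm_config: VllmConfig) -> int:
    # Without this assert, mypy reports something like:
    #   Item "None" of "SpeculativeConfig | None" has no attribute
    #   "num_speculative_tokens"
    assert vllm_config.speculative_config is not None, "Speculative config must be set"
    # Bind to a local so the non-None type is carried through the body.
    spec_config = vllm_config.speculative_config
    return spec_config.num_speculative_tokens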
@@ -32,7 +32,6 @@ SEPARATE_GROUPS = [
     "vllm/model_executor",
     # v1 related
     "vllm/v1/kv_offload",
-    "vllm/v1/spec_decode",
 ]

 # TODO(woosuk): Include the code from Megatron and HuggingFace.
@@ -5,7 +5,6 @@ from typing import Any
 import torch

 from vllm.config import VllmConfig, get_layers_from_vllm_config
-from vllm.config.speculative import SpeculativeConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.model_loader import get_model
@@ -56,7 +55,7 @@ class DraftModelProposer(SpecDecodeBaseProposer):
         )

     def _raise_if_padded_drafter_batch_disabled(self):
-        if self.vllm_config.speculative_config.disable_padded_drafter_batch:
+        if self.speculative_config.disable_padded_drafter_batch:
             raise NotImplementedError(
                 "Speculative Decoding with draft models only supports "
                 "padded drafter batch. Please don't pass --disable-padded-drafter-batch"
@@ -64,7 +63,7 @@ class DraftModelProposer(SpecDecodeBaseProposer):
             )

     def _raise_if_vocab_size_mismatch(self):
-        self.vllm_config.speculative_config.verify_equal_vocab_size_if_draft_model()
+        self.speculative_config.verify_equal_vocab_size_if_draft_model()

     def _raise_if_draft_tp_mismatch(self):
         # Note(Tomas Ruiz) If we run the target model with TP > 1 and
@@ -73,7 +72,7 @@ class DraftModelProposer(SpecDecodeBaseProposer):
         # (because TP=1), then the torch compile cache is overwritten and corrupted.
         # We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414
         # To prevent this error, we assert that both TP sizes must be the same.
-        spec_cfg: SpeculativeConfig = self.vllm_config.speculative_config
+        spec_cfg = self.speculative_config
         tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
         draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
         if draft_tp != tgt_tp:
@@ -190,12 +189,14 @@ def create_vllm_config_for_draft_model(
     The vllm_config is useful when loading the draft model with get_model().
     """
     old = target_model_vllm_config
-    new_parallel_config = old.speculative_config.draft_parallel_config.replace(
+    assert old.speculative_config is not None, "speculative_config is not set"
+    old_spec_config = old.speculative_config
+    new_parallel_config = old_spec_config.draft_parallel_config.replace(
         rank=old.parallel_config.rank
     )
     new: VllmConfig = old.replace(
         quant_config=None,  # quant_config is recomputed in __init__()
-        model_config=old.speculative_config.draft_model_config,
+        model_config=old_spec_config.draft_model_config,
         parallel_config=new_parallel_config,
     )
     return new
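In create_vllm_config_for_draft_model above, the draft config is derived by copying the target config and overriding a few fields via .replace(). vLLM's config objects provide that method themselves; the stdlib dataclasses.replace gives the same copy-with-overrides behavior, shown here as a hedged stand-in with made-up config fields.

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ParallelCfg:
    tensor_parallel_size: int
    rank: int


@dataclass(frozen=True)
class Cfg:
    model_name: str
    parallel: ParallelCfg


target = Cfg(model_name="target-7b", parallel=ParallelCfg(tensor_parallel_size=2, rank=0))
draft_parallel = ParallelCfg(tensor_parallel_size=2, rank=0)

# Copy the target config, overriding only what differs for the draft model
# and propagating the current rank, mirroring
# draft_parallel_config.replace(rank=old.parallel_config.rank) in the diff.
draft = replace(
    target,
    model_name="draft-1b",
    parallel=replace(draft_parallel, rank=target.parallel.rank),
)
print(draft)  # Cfg(model_name='draft-1b', parallel=ParallelCfg(tensor_parallel_size=2, rank=0))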
@@ -3,6 +3,7 @@
 import ast
 from dataclasses import replace
 from importlib.util import find_spec
+from typing import cast

 import numpy as np
 import torch
@@ -20,6 +21,7 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
+from vllm.model_executor.models.interfaces import SupportsMultiModal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.platforms import current_platform
@@ -62,8 +64,8 @@ class SpecDecodeBaseProposer:
         runner=None,
     ):
         self.vllm_config = vllm_config
+        assert vllm_config.speculative_config is not None
         self.speculative_config = vllm_config.speculative_config
-        assert self.speculative_config is not None
         self.draft_model_config = self.speculative_config.draft_model_config
         self.method = self.speculative_config.method
         self.pass_hidden_states_to_model = pass_hidden_states_to_model
@@ -206,6 +208,7 @@ class SpecDecodeBaseProposer:

         # Parse the speculative token tree.
         spec_token_tree = self.speculative_config.speculative_token_tree
+        assert spec_token_tree is not None
         self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
         tree_depth = len(self.tree_choices[-1])
         # Precompute per-level properties of the tree.
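For context on the new assert above: speculative_token_tree is an optional string, and mypy needs the None case excluded before it is handed to ast.literal_eval. A standalone run of the parsing step, with an illustrative tree (the values are invented, not taken from this diff):

import ast

# Hypothetical serialized token tree: each tuple is a path of child
# indices from the root, with the deepest path last.
spec_token_tree = "[(0,), (1,), (0, 0), (0, 1), (0, 0, 0)]"

tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
tree_depth = len(tree_choices[-1])  # depth = length of the deepest path
print(tree_choices)  # [(0,), (1,), (0, 0), (0, 1), (0, 0, 0)]
print(tree_depth)    # 3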
@@ -1077,9 +1080,12 @@ class SpecDecodeBaseProposer:
         return model.__class__.__name__

     def load_model(self, target_model: nn.Module) -> None:
-        draft_model_config = self.vllm_config.speculative_config.draft_model_config
+        draft_model_config = self.speculative_config.draft_model_config
         target_attn_layer_names = set(
-            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
+            get_layers_from_vllm_config(
+                self.vllm_config,
+                AttentionLayerBase,  # type: ignore[type-abstract]
+            ).keys()
         )
         # FIXME: support hybrid kv for draft model
         target_indexer_layer_names = set(
@@ -1096,7 +1102,10 @@ class SpecDecodeBaseProposer:
         )

         draft_attn_layer_names = (
-            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
+            get_layers_from_vllm_config(
+                self.vllm_config,
+                AttentionLayerBase,  # type: ignore[type-abstract]
+            ).keys()
             - target_attn_layer_names
         )
         indexer_layers = get_layers_from_vllm_config(
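The two `# type: ignore[type-abstract]` comments above address a mypy rule that import following newly exposes: an abstract class such as AttentionLayerBase cannot be passed where type[T] is expected, because type[T] implies the class is safe to instantiate. get_layers_from_vllm_config presumably only uses the class for isinstance-style filtering, so the ignore is sound. A self-contained reproduction with hypothetical names:

from abc import ABC, abstractmethod
from typing import TypeVar

T = TypeVar("T")


class LayerBase(ABC):
    @abstractmethod
    def forward(self) -> None: ...


def find_layers(layer_type: type[T]) -> dict[str, T]:
    # Only filters a registry by isinstance(); never instantiates layer_type.
    return {}


# mypy: error: Only concrete class can be given where "type[LayerBase]"
# is expected  [type-abstract]
layers = find_layers(LayerBase)  # type: ignore[type-abstract]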
@@ -1136,6 +1145,7 @@ class SpecDecodeBaseProposer:

         if supports_multimodal(target_model):
             # handle multimodality
+            assert hasattr(target_model, "config")
             if self.get_model_name(target_model) in [
                 "Qwen2_5_VLForConditionalGeneration",
                 "Qwen3VLForConditionalGeneration",
@@ -1152,16 +1162,21 @@ class SpecDecodeBaseProposer:
                 self.model.config.image_token_index = (
                     target_model.config.image_token_index
                 )
-            target_language_model = target_model.get_language_model()
+            target_language_model = cast(
+                SupportsMultiModal, target_model
+            ).get_language_model()
         else:
             target_language_model = target_model

         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1:
-            if hasattr(target_language_model.model, "embed_tokens"):
-                target_embed_tokens = target_language_model.model.embed_tokens
-            elif hasattr(target_language_model.model, "embedding"):
-                target_embed_tokens = target_language_model.model.embedding
+            inner_model = getattr(target_language_model, "model", None)
+            if inner_model is None:
+                raise AttributeError("Target model does not have 'model' attribute")
+            if hasattr(inner_model, "embed_tokens"):
+                target_embed_tokens = inner_model.embed_tokens
+            elif hasattr(inner_model, "embedding"):
+                target_embed_tokens = inner_model.embedding
             else:
                 raise AttributeError(
                     "Target model does not have 'embed_tokens' or 'embedding' attribute"
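Two typing idioms in the hunk above are worth calling out. cast(SupportsMultiModal, target_model) has no runtime cost; it records for mypy what supports_multimodal(target_model) already established at runtime, so .get_language_model() type-checks on a value declared as nn.Module. The getattr(..., "model", None) rewrite surfaces the possibly-missing submodule as a value mypy can narrow, instead of probing a nested attribute chain with hasattr. A condensed sketch with simplified, hypothetical types:

from typing import Any, cast


class SupportsLM:
    """Hypothetical stand-in for an interface like SupportsMultiModal."""

    def get_language_model(self) -> Any: ...


def pick_language_model(target_model: object, is_multimodal: bool) -> object:
    if is_multimodal:
        # The runtime check (is_multimodal) already guarantees the interface;
        # cast() only informs the type checker.
        return cast(SupportsLM, target_model).get_language_model()
    return target_model


def get_inner_model(target_language_model: object) -> object:
    # getattr with a None default makes the missing-attribute case explicit.
    inner_model = getattr(target_language_model, "model", None)
    if inner_model is None:
        raise AttributeError("Target model does not have 'model' attribute")
    return inner_model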
@@ -27,11 +27,13 @@ class MedusaProposer:
     ):
         # Save config parameters
         self.vllm_config = vllm_config
+        assert vllm_config.speculative_config is not None, (
+            "Speculative config must be set"
+        )
+        self.spec_config = vllm_config.speculative_config
         self.device = device
         self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
-        self.hidden_size = (
-            vllm_config.speculative_config.draft_model_config.get_hidden_size()
-        )
+        self.hidden_size = self.spec_config.draft_model_config.get_hidden_size()
         self.dtype = vllm_config.model_config.dtype

     def propose(
@@ -58,7 +60,7 @@ class MedusaProposer:
         with set_model_tag("medusa_head"):
             self.model = get_model(
                 vllm_config=self.vllm_config,
-                model_config=self.vllm_config.speculative_config.draft_model_config,
+                model_config=self.spec_config.draft_model_config,
             )
         assert not (
             is_mixture_of_experts(self.model)
@@ -15,6 +15,7 @@ class SuffixDecodingProposer:

     def __init__(self, vllm_config: VllmConfig):
         config = vllm_config.speculative_config
+        assert config is not None, "Speculative config must be set"
         self.num_speculative_tokens = config.num_speculative_tokens
         self.max_tree_depth = config.suffix_decoding_max_tree_depth
         self.max_spec_factor = config.suffix_decoding_max_spec_factor