[Feature] Support per-draft-model MoE backend via --speculative-config (#37880)
Signed-off-by: Andrii Skliar <askliar@nvidia.com> Signed-off-by: [Andrii Skliar] <askliar@nvidia.com> Co-authored-by: Andrii Skliar <askliar@nvidia.com>
This commit is contained in:
@@ -20,12 +20,11 @@ from vllm import LLM, SamplingParams
|
||||
from vllm.assets.base import VLLM_S3_BUCKET_URL
|
||||
from vllm.assets.image import VLM_IMAGES_DIR
|
||||
from vllm.benchmarks.datasets import InstructCoderDataset
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import VllmConfig, replace
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.metrics.reader import Metric
|
||||
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model
|
||||
|
||||
MTP_SIMILARITY_RATE = 0.8
|
||||
|
||||
@@ -919,13 +918,104 @@ def test_draft_model_engine_args_tensor_parallelism():
|
||||
"draft_tensor_parallel_size": 1, # <<< valid arg name
|
||||
},
|
||||
)
|
||||
tgt_vllm_config: VllmConfig = engine_args.create_engine_config()
|
||||
assert tgt_vllm_config.parallel_config.tensor_parallel_size == 2
|
||||
assert tgt_vllm_config.quant_config.get_name() == "fp8"
|
||||
target_config: VllmConfig = engine_args.create_engine_config()
|
||||
assert target_config.parallel_config.tensor_parallel_size == 2
|
||||
assert target_config.quant_config.get_name() == "fp8"
|
||||
|
||||
draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(tgt_vllm_config)
|
||||
assert draft_vllm_config.parallel_config.tensor_parallel_size == 1
|
||||
assert draft_vllm_config.quant_config is None
|
||||
speculative_config = target_config.speculative_config
|
||||
draft_config: VllmConfig = replace(
|
||||
target_config,
|
||||
quant_config=None,
|
||||
parallel_config=replace(
|
||||
speculative_config.draft_parallel_config,
|
||||
rank=target_config.parallel_config.rank,
|
||||
),
|
||||
model_config=speculative_config.draft_model_config,
|
||||
)
|
||||
assert draft_config.parallel_config.tensor_parallel_size == 1
|
||||
assert draft_config.quant_config is None
|
||||
|
||||
|
||||
def _apply_draft_moe_backend(vllm_config: VllmConfig) -> VllmConfig:
|
||||
"""Replicate SpecDecodeBaseProposer._create_draft_vllm_config logic
|
||||
so we can test it without instantiating a full proposer."""
|
||||
spec_cfg = vllm_config.speculative_config
|
||||
if spec_cfg.moe_backend is not None:
|
||||
return replace(
|
||||
vllm_config,
|
||||
kernel_config=replace(
|
||||
vllm_config.kernel_config,
|
||||
moe_backend=spec_cfg.moe_backend,
|
||||
),
|
||||
)
|
||||
return vllm_config
|
||||
|
||||
|
||||
def test_draft_model_moe_backend_override():
|
||||
"""When moe_backend is set in speculative_config, the draft VllmConfig
|
||||
should use it while the target keeps its own setting."""
|
||||
engine_args = EngineArgs(
|
||||
model="Qwen/Qwen3-1.7B",
|
||||
tensor_parallel_size=1,
|
||||
moe_backend="flashinfer_trtllm",
|
||||
speculative_config={
|
||||
"model": "Qwen/Qwen3-0.6B",
|
||||
"method": "draft_model",
|
||||
"num_speculative_tokens": 3,
|
||||
"moe_backend": "triton",
|
||||
},
|
||||
)
|
||||
tgt_config: VllmConfig = engine_args.create_engine_config()
|
||||
assert tgt_config.kernel_config.moe_backend == "flashinfer_trtllm"
|
||||
assert tgt_config.speculative_config.moe_backend == "triton"
|
||||
|
||||
draft_config = _apply_draft_moe_backend(tgt_config)
|
||||
assert draft_config.kernel_config.moe_backend == "triton"
|
||||
# Target config must be unaffected.
|
||||
assert tgt_config.kernel_config.moe_backend == "flashinfer_trtllm"
|
||||
|
||||
|
||||
def test_draft_model_moe_backend_inherits_target():
|
||||
"""When moe_backend is not set in speculative_config, the draft should
|
||||
inherit the target's moe_backend."""
|
||||
engine_args = EngineArgs(
|
||||
model="Qwen/Qwen3-1.7B",
|
||||
tensor_parallel_size=1,
|
||||
moe_backend="flashinfer_cutlass",
|
||||
speculative_config={
|
||||
"model": "Qwen/Qwen3-0.6B",
|
||||
"method": "draft_model",
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
)
|
||||
tgt_config: VllmConfig = engine_args.create_engine_config()
|
||||
assert tgt_config.kernel_config.moe_backend == "flashinfer_cutlass"
|
||||
assert tgt_config.speculative_config.moe_backend is None
|
||||
|
||||
draft_config = _apply_draft_moe_backend(tgt_config)
|
||||
assert draft_config.kernel_config.moe_backend == "flashinfer_cutlass"
|
||||
assert draft_config is tgt_config
|
||||
|
||||
|
||||
def test_draft_model_moe_backend_default_auto():
|
||||
"""When neither target nor draft set moe_backend explicitly, both should
|
||||
default to 'auto'."""
|
||||
engine_args = EngineArgs(
|
||||
model="Qwen/Qwen3-1.7B",
|
||||
tensor_parallel_size=1,
|
||||
speculative_config={
|
||||
"model": "Qwen/Qwen3-0.6B",
|
||||
"method": "draft_model",
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
)
|
||||
tgt_config: VllmConfig = engine_args.create_engine_config()
|
||||
assert tgt_config.kernel_config.moe_backend == "auto"
|
||||
assert tgt_config.speculative_config.moe_backend is None
|
||||
|
||||
draft_config = _apply_draft_moe_backend(tgt_config)
|
||||
assert draft_config.kernel_config.moe_backend == "auto"
|
||||
assert draft_config is tgt_config
|
||||
|
||||
|
||||
def test_draft_model_engine_args_rejects_invalid_tp_argname():
|
||||
|
||||
@@ -9,6 +9,7 @@ from pydantic import Field, SkipValidation, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config import LoadConfig
|
||||
from vllm.config.kernel import MoEBackend
|
||||
from vllm.config.model import ModelConfig
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.utils import config
|
||||
@@ -93,6 +94,11 @@ class SpeculativeConfig:
|
||||
"""Quantization method that was used to quantize the draft model weights.
|
||||
If `None`, we assume the model weights are not quantized. Note that it only
|
||||
takes effect when using the draft model-based speculative method."""
|
||||
moe_backend: MoEBackend | None = None
|
||||
"""MoE backend to use for the draft model. When `None`, the draft model
|
||||
inherits the target model's `--moe-backend` setting. Useful when the
|
||||
drafter and generator require different MoE kernels (e.g. quantized
|
||||
generator with unquantized drafter)."""
|
||||
max_model_len: int | None = Field(default=None, ge=1)
|
||||
"""The maximum model length of the draft model. Used when testing the
|
||||
ability to skip speculation for some sequences."""
|
||||
|
||||
@@ -6,10 +6,10 @@ import torch.nn as nn
|
||||
from typing_extensions import override
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import replace
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer
|
||||
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -50,16 +50,29 @@ class DraftModelProposer(SpecDecodeBaseProposer):
|
||||
"Please pass 'draft_tensor_parallel_size' in the speculative_config."
|
||||
)
|
||||
|
||||
@override
|
||||
def _create_draft_vllm_config(self) -> VllmConfig:
|
||||
base = super()._create_draft_vllm_config()
|
||||
spec = self.speculative_config
|
||||
|
||||
return replace(
|
||||
base,
|
||||
quant_config=None,
|
||||
parallel_config=replace(
|
||||
spec.draft_parallel_config,
|
||||
rank=self.vllm_config.parallel_config.rank,
|
||||
),
|
||||
model_config=spec.draft_model_config,
|
||||
)
|
||||
|
||||
@override
|
||||
def _get_model(self) -> nn.Module:
|
||||
# Draft models may be quantized or on different parallelism,
|
||||
# so we load them with a modified vllm config
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
|
||||
temp_vllm_config = create_vllm_config_for_draft_model(self.vllm_config)
|
||||
draft_vllm_config = self._create_draft_vllm_config()
|
||||
with set_model_tag("draft_model"):
|
||||
model = get_model(
|
||||
vllm_config=temp_vllm_config,
|
||||
vllm_config=draft_vllm_config,
|
||||
prefix="draft_model",
|
||||
)
|
||||
return model
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ast
|
||||
from dataclasses import replace
|
||||
from importlib.util import find_spec
|
||||
from typing import cast
|
||||
|
||||
@@ -13,6 +12,7 @@ from vllm.config import (
|
||||
CUDAGraphMode,
|
||||
VllmConfig,
|
||||
get_layers_from_vllm_config,
|
||||
replace,
|
||||
)
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
@@ -1213,6 +1213,21 @@ class SpecDecodeBaseProposer:
|
||||
model = model.module
|
||||
return model.__class__.__name__
|
||||
|
||||
def _create_draft_vllm_config(self) -> VllmConfig:
|
||||
"""Return a VllmConfig with kernel-level overrides for the proposer.
|
||||
Subclasses may override to apply additional config changes.
|
||||
"""
|
||||
spec_cfg = self.speculative_config
|
||||
if spec_cfg.moe_backend is not None:
|
||||
return replace(
|
||||
self.vllm_config,
|
||||
kernel_config=replace(
|
||||
self.vllm_config.kernel_config,
|
||||
moe_backend=spec_cfg.moe_backend,
|
||||
),
|
||||
)
|
||||
return self.vllm_config
|
||||
|
||||
def _get_model(self) -> nn.Module:
|
||||
"""
|
||||
Default method to call get_model(). Can be overridden by subclasses which
|
||||
@@ -1220,9 +1235,10 @@ class SpecDecodeBaseProposer:
|
||||
"""
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
|
||||
draft_vllm_config = self._create_draft_vllm_config()
|
||||
with set_model_tag("eagle_head"):
|
||||
model = get_model(
|
||||
vllm_config=self.vllm_config,
|
||||
vllm_config=draft_vllm_config,
|
||||
model_config=self.speculative_config.draft_model_config,
|
||||
load_config=self.speculative_config.draft_load_config,
|
||||
)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig, replace
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
@@ -258,30 +257,6 @@ def compute_new_slot_mapping(
|
||||
return new_slot_mapping
|
||||
|
||||
|
||||
def create_vllm_config_for_draft_model(
|
||||
target_model_vllm_config: VllmConfig,
|
||||
) -> VllmConfig:
|
||||
"""The vllm_config is configured for the target model, e.g.
|
||||
its quant_config and parallel_config. But the draft model is potentially
|
||||
quantized differently, and has potentially different tensor_parallel_size.
|
||||
This function creates a new vllm_config configured for the drafter.
|
||||
The vllm_config is useful when loading the draft model with get_model().
|
||||
"""
|
||||
old = target_model_vllm_config
|
||||
assert old.speculative_config is not None, "speculative_config is not set"
|
||||
old_spec_config = old.speculative_config
|
||||
new_parallel_config = replace(
|
||||
old_spec_config.draft_parallel_config, rank=old.parallel_config.rank
|
||||
)
|
||||
new: VllmConfig = replace(
|
||||
old,
|
||||
quant_config=None,
|
||||
parallel_config=new_parallel_config,
|
||||
model_config=old_spec_config.draft_model_config,
|
||||
)
|
||||
return new
|
||||
|
||||
|
||||
def extend_all_queries_by_N(
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
N: int,
|
||||
|
||||
Reference in New Issue
Block a user