[Feature] Support per-draft-model MoE backend via --speculative-config (#37880)

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Signed-off-by: [Andrii Skliar] <askliar@nvidia.com>
Co-authored-by: Andrii Skliar <askliar@nvidia.com>
This commit is contained in:
Andrii Skliar
2026-03-25 15:31:52 +01:00
committed by GitHub
parent a1a2566447
commit cd7643015e
5 changed files with 140 additions and 40 deletions

View File

@@ -20,12 +20,11 @@ from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.benchmarks.datasets import InstructCoderDataset
from vllm.config import VllmConfig
from vllm.config import VllmConfig, replace
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.v1.metrics.reader import Metric
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model
MTP_SIMILARITY_RATE = 0.8
@@ -919,13 +918,104 @@ def test_draft_model_engine_args_tensor_parallelism():
"draft_tensor_parallel_size": 1, # <<< valid arg name
},
)
tgt_vllm_config: VllmConfig = engine_args.create_engine_config()
assert tgt_vllm_config.parallel_config.tensor_parallel_size == 2
assert tgt_vllm_config.quant_config.get_name() == "fp8"
target_config: VllmConfig = engine_args.create_engine_config()
assert target_config.parallel_config.tensor_parallel_size == 2
assert target_config.quant_config.get_name() == "fp8"
draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(tgt_vllm_config)
assert draft_vllm_config.parallel_config.tensor_parallel_size == 1
assert draft_vllm_config.quant_config is None
speculative_config = target_config.speculative_config
draft_config: VllmConfig = replace(
target_config,
quant_config=None,
parallel_config=replace(
speculative_config.draft_parallel_config,
rank=target_config.parallel_config.rank,
),
model_config=speculative_config.draft_model_config,
)
assert draft_config.parallel_config.tensor_parallel_size == 1
assert draft_config.quant_config is None
def _apply_draft_moe_backend(vllm_config: VllmConfig) -> VllmConfig:
"""Replicate SpecDecodeBaseProposer._create_draft_vllm_config logic
so we can test it without instantiating a full proposer."""
spec_cfg = vllm_config.speculative_config
if spec_cfg.moe_backend is not None:
return replace(
vllm_config,
kernel_config=replace(
vllm_config.kernel_config,
moe_backend=spec_cfg.moe_backend,
),
)
return vllm_config
def test_draft_model_moe_backend_override():
"""When moe_backend is set in speculative_config, the draft VllmConfig
should use it while the target keeps its own setting."""
engine_args = EngineArgs(
model="Qwen/Qwen3-1.7B",
tensor_parallel_size=1,
moe_backend="flashinfer_trtllm",
speculative_config={
"model": "Qwen/Qwen3-0.6B",
"method": "draft_model",
"num_speculative_tokens": 3,
"moe_backend": "triton",
},
)
tgt_config: VllmConfig = engine_args.create_engine_config()
assert tgt_config.kernel_config.moe_backend == "flashinfer_trtllm"
assert tgt_config.speculative_config.moe_backend == "triton"
draft_config = _apply_draft_moe_backend(tgt_config)
assert draft_config.kernel_config.moe_backend == "triton"
# Target config must be unaffected.
assert tgt_config.kernel_config.moe_backend == "flashinfer_trtllm"
def test_draft_model_moe_backend_inherits_target():
"""When moe_backend is not set in speculative_config, the draft should
inherit the target's moe_backend."""
engine_args = EngineArgs(
model="Qwen/Qwen3-1.7B",
tensor_parallel_size=1,
moe_backend="flashinfer_cutlass",
speculative_config={
"model": "Qwen/Qwen3-0.6B",
"method": "draft_model",
"num_speculative_tokens": 3,
},
)
tgt_config: VllmConfig = engine_args.create_engine_config()
assert tgt_config.kernel_config.moe_backend == "flashinfer_cutlass"
assert tgt_config.speculative_config.moe_backend is None
draft_config = _apply_draft_moe_backend(tgt_config)
assert draft_config.kernel_config.moe_backend == "flashinfer_cutlass"
assert draft_config is tgt_config
def test_draft_model_moe_backend_default_auto():
"""When neither target nor draft set moe_backend explicitly, both should
default to 'auto'."""
engine_args = EngineArgs(
model="Qwen/Qwen3-1.7B",
tensor_parallel_size=1,
speculative_config={
"model": "Qwen/Qwen3-0.6B",
"method": "draft_model",
"num_speculative_tokens": 3,
},
)
tgt_config: VllmConfig = engine_args.create_engine_config()
assert tgt_config.kernel_config.moe_backend == "auto"
assert tgt_config.speculative_config.moe_backend is None
draft_config = _apply_draft_moe_backend(tgt_config)
assert draft_config.kernel_config.moe_backend == "auto"
assert draft_config is tgt_config
def test_draft_model_engine_args_rejects_invalid_tp_argname():

View File

@@ -9,6 +9,7 @@ from pydantic import Field, SkipValidation, model_validator
from typing_extensions import Self
from vllm.config import LoadConfig
from vllm.config.kernel import MoEBackend
from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
@@ -93,6 +94,11 @@ class SpeculativeConfig:
"""Quantization method that was used to quantize the draft model weights.
If `None`, we assume the model weights are not quantized. Note that it only
takes effect when using the draft model-based speculative method."""
moe_backend: MoEBackend | None = None
"""MoE backend to use for the draft model. When `None`, the draft model
inherits the target model's `--moe-backend` setting. Useful when the
drafter and generator require different MoE kernels (e.g. quantized
generator with unquantized drafter)."""
max_model_len: int | None = Field(default=None, ge=1)
"""The maximum model length of the draft model. Used when testing the
ability to skip speculation for some sequences."""

View File

@@ -6,10 +6,10 @@ import torch.nn as nn
from typing_extensions import override
from vllm.config import VllmConfig
from vllm.config.utils import replace
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model
logger = init_logger(__name__)
@@ -50,16 +50,29 @@ class DraftModelProposer(SpecDecodeBaseProposer):
"Please pass 'draft_tensor_parallel_size' in the speculative_config."
)
@override
def _create_draft_vllm_config(self) -> VllmConfig:
base = super()._create_draft_vllm_config()
spec = self.speculative_config
return replace(
base,
quant_config=None,
parallel_config=replace(
spec.draft_parallel_config,
rank=self.vllm_config.parallel_config.rank,
),
model_config=spec.draft_model_config,
)
@override
def _get_model(self) -> nn.Module:
# Draft models may be quantized or on different parallelism,
# so we load them with a modified vllm config
from vllm.compilation.backends import set_model_tag
temp_vllm_config = create_vllm_config_for_draft_model(self.vllm_config)
draft_vllm_config = self._create_draft_vllm_config()
with set_model_tag("draft_model"):
model = get_model(
vllm_config=temp_vllm_config,
vllm_config=draft_vllm_config,
prefix="draft_model",
)
return model

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
from dataclasses import replace
from importlib.util import find_spec
from typing import cast
@@ -13,6 +12,7 @@ from vllm.config import (
CUDAGraphMode,
VllmConfig,
get_layers_from_vllm_config,
replace,
)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
@@ -1213,6 +1213,21 @@ class SpecDecodeBaseProposer:
model = model.module
return model.__class__.__name__
def _create_draft_vllm_config(self) -> VllmConfig:
"""Return a VllmConfig with kernel-level overrides for the proposer.
Subclasses may override to apply additional config changes.
"""
spec_cfg = self.speculative_config
if spec_cfg.moe_backend is not None:
return replace(
self.vllm_config,
kernel_config=replace(
self.vllm_config.kernel_config,
moe_backend=spec_cfg.moe_backend,
),
)
return self.vllm_config
def _get_model(self) -> nn.Module:
"""
Default method to call get_model(). Can be overridden by subclasses which
@@ -1220,9 +1235,10 @@ class SpecDecodeBaseProposer:
"""
from vllm.compilation.backends import set_model_tag
draft_vllm_config = self._create_draft_vllm_config()
with set_model_tag("eagle_head"):
model = get_model(
vllm_config=self.vllm_config,
vllm_config=draft_vllm_config,
model_config=self.speculative_config.draft_model_config,
load_config=self.speculative_config.draft_load_config,
)

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.config import VllmConfig, replace
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import (
@@ -258,30 +257,6 @@ def compute_new_slot_mapping(
return new_slot_mapping
def create_vllm_config_for_draft_model(
target_model_vllm_config: VllmConfig,
) -> VllmConfig:
"""The vllm_config is configured for the target model, e.g.
its quant_config and parallel_config. But the draft model is potentially
quantized differently, and has potentially different tensor_parallel_size.
This function creates a new vllm_config configured for the drafter.
The vllm_config is useful when loading the draft model with get_model().
"""
old = target_model_vllm_config
assert old.speculative_config is not None, "speculative_config is not set"
old_spec_config = old.speculative_config
new_parallel_config = replace(
old_spec_config.draft_parallel_config, rank=old.parallel_config.rank
)
new: VllmConfig = replace(
old,
quant_config=None,
parallel_config=new_parallel_config,
model_config=old_spec_config.draft_model_config,
)
return new
def extend_all_queries_by_N(
common_attn_metadata: CommonAttentionMetadata,
N: int,