[Feature] Support per-draft-model MoE backend via --speculative-config (#37880)

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Co-authored-by: Andrii Skliar <askliar@nvidia.com>
This commit is contained in:
Andrii Skliar
2026-03-25 15:31:52 +01:00
committed by GitHub
parent a1a2566447
commit cd7643015e
5 changed files with 140 additions and 40 deletions

View File

@@ -20,12 +20,11 @@ from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.benchmarks.datasets import InstructCoderDataset
from vllm.config import VllmConfig
from vllm.config import VllmConfig, replace
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.v1.metrics.reader import Metric
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model
MTP_SIMILARITY_RATE = 0.8
@@ -919,13 +918,104 @@ def test_draft_model_engine_args_tensor_parallelism():
"draft_tensor_parallel_size": 1, # <<< valid arg name
},
)
tgt_vllm_config: VllmConfig = engine_args.create_engine_config()
assert tgt_vllm_config.parallel_config.tensor_parallel_size == 2
assert tgt_vllm_config.quant_config.get_name() == "fp8"
target_config: VllmConfig = engine_args.create_engine_config()
assert target_config.parallel_config.tensor_parallel_size == 2
assert target_config.quant_config.get_name() == "fp8"
draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(tgt_vllm_config)
assert draft_vllm_config.parallel_config.tensor_parallel_size == 1
assert draft_vllm_config.quant_config is None
speculative_config = target_config.speculative_config
draft_config: VllmConfig = replace(
target_config,
quant_config=None,
parallel_config=replace(
speculative_config.draft_parallel_config,
rank=target_config.parallel_config.rank,
),
model_config=speculative_config.draft_model_config,
)
assert draft_config.parallel_config.tensor_parallel_size == 1
assert draft_config.quant_config is None
def _apply_draft_moe_backend(vllm_config: VllmConfig) -> VllmConfig:
"""Replicate SpecDecodeBaseProposer._create_draft_vllm_config logic
so we can test it without instantiating a full proposer."""
spec_cfg = vllm_config.speculative_config
if spec_cfg.moe_backend is not None:
return replace(
vllm_config,
kernel_config=replace(
vllm_config.kernel_config,
moe_backend=spec_cfg.moe_backend,
),
)
return vllm_config
def test_draft_model_moe_backend_override():
    """A moe_backend set inside speculative_config must apply only to the
    draft config; the target retains its own backend."""
    engine_args = EngineArgs(
        model="Qwen/Qwen3-1.7B",
        tensor_parallel_size=1,
        moe_backend="flashinfer_trtllm",
        speculative_config={
            "model": "Qwen/Qwen3-0.6B",
            "method": "draft_model",
            "num_speculative_tokens": 3,
            "moe_backend": "triton",
        },
    )
    target_config: VllmConfig = engine_args.create_engine_config()

    # Target keeps its own backend while recording the draft override.
    assert target_config.kernel_config.moe_backend == "flashinfer_trtllm"
    assert target_config.speculative_config.moe_backend == "triton"

    derived_draft = _apply_draft_moe_backend(target_config)
    assert derived_draft.kernel_config.moe_backend == "triton"

    # Deriving the draft config must not mutate the target config.
    assert target_config.kernel_config.moe_backend == "flashinfer_trtllm"
def test_draft_model_moe_backend_inherits_target():
    """Without a moe_backend in speculative_config, the draft config must
    simply reuse the target's backend (and the target config itself)."""
    engine_args = EngineArgs(
        model="Qwen/Qwen3-1.7B",
        tensor_parallel_size=1,
        moe_backend="flashinfer_cutlass",
        speculative_config={
            "model": "Qwen/Qwen3-0.6B",
            "method": "draft_model",
            "num_speculative_tokens": 3,
        },
    )
    target_config: VllmConfig = engine_args.create_engine_config()

    # No draft-side override was requested.
    assert target_config.speculative_config.moe_backend is None
    assert target_config.kernel_config.moe_backend == "flashinfer_cutlass"

    derived_draft = _apply_draft_moe_backend(target_config)

    # Inheritance path returns the very same config object, backend intact.
    assert derived_draft is target_config
    assert derived_draft.kernel_config.moe_backend == "flashinfer_cutlass"
def test_draft_model_moe_backend_default_auto():
    """With no explicit moe_backend anywhere, target and draft both end up
    on the 'auto' default."""
    engine_args = EngineArgs(
        model="Qwen/Qwen3-1.7B",
        tensor_parallel_size=1,
        speculative_config={
            "model": "Qwen/Qwen3-0.6B",
            "method": "draft_model",
            "num_speculative_tokens": 3,
        },
    )
    target_config: VllmConfig = engine_args.create_engine_config()

    # Neither side specified a backend explicitly.
    assert target_config.speculative_config.moe_backend is None
    assert target_config.kernel_config.moe_backend == "auto"

    derived_draft = _apply_draft_moe_backend(target_config)

    # No override means the target config is returned as-is.
    assert derived_draft is target_config
    assert derived_draft.kernel_config.moe_backend == "auto"
def test_draft_model_engine_args_rejects_invalid_tp_argname():