[Feature] Support per-draft-model MoE backend via --speculative-config (#37880)
Signed-off-by: Andrii Skliar <askliar@nvidia.com> Signed-off-by: [Andrii Skliar] <askliar@nvidia.com> Co-authored-by: Andrii Skliar <askliar@nvidia.com>
This commit is contained in:
@@ -20,12 +20,11 @@ from vllm import LLM, SamplingParams
|
||||
from vllm.assets.base import VLLM_S3_BUCKET_URL
|
||||
from vllm.assets.image import VLM_IMAGES_DIR
|
||||
from vllm.benchmarks.datasets import InstructCoderDataset
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import VllmConfig, replace
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.metrics.reader import Metric
|
||||
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model
|
||||
|
||||
MTP_SIMILARITY_RATE = 0.8
|
||||
|
||||
@@ -919,13 +918,104 @@ def test_draft_model_engine_args_tensor_parallelism():
|
||||
"draft_tensor_parallel_size": 1, # <<< valid arg name
|
||||
},
|
||||
)
|
||||
tgt_vllm_config: VllmConfig = engine_args.create_engine_config()
|
||||
assert tgt_vllm_config.parallel_config.tensor_parallel_size == 2
|
||||
assert tgt_vllm_config.quant_config.get_name() == "fp8"
|
||||
target_config: VllmConfig = engine_args.create_engine_config()
|
||||
assert target_config.parallel_config.tensor_parallel_size == 2
|
||||
assert target_config.quant_config.get_name() == "fp8"
|
||||
|
||||
draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(tgt_vllm_config)
|
||||
assert draft_vllm_config.parallel_config.tensor_parallel_size == 1
|
||||
assert draft_vllm_config.quant_config is None
|
||||
speculative_config = target_config.speculative_config
|
||||
draft_config: VllmConfig = replace(
|
||||
target_config,
|
||||
quant_config=None,
|
||||
parallel_config=replace(
|
||||
speculative_config.draft_parallel_config,
|
||||
rank=target_config.parallel_config.rank,
|
||||
),
|
||||
model_config=speculative_config.draft_model_config,
|
||||
)
|
||||
assert draft_config.parallel_config.tensor_parallel_size == 1
|
||||
assert draft_config.quant_config is None
|
||||
|
||||
|
||||
def _apply_draft_moe_backend(vllm_config: VllmConfig) -> VllmConfig:
|
||||
"""Replicate SpecDecodeBaseProposer._create_draft_vllm_config logic
|
||||
so we can test it without instantiating a full proposer."""
|
||||
spec_cfg = vllm_config.speculative_config
|
||||
if spec_cfg.moe_backend is not None:
|
||||
return replace(
|
||||
vllm_config,
|
||||
kernel_config=replace(
|
||||
vllm_config.kernel_config,
|
||||
moe_backend=spec_cfg.moe_backend,
|
||||
),
|
||||
)
|
||||
return vllm_config
|
||||
|
||||
|
||||
def test_draft_model_moe_backend_override():
|
||||
"""When moe_backend is set in speculative_config, the draft VllmConfig
|
||||
should use it while the target keeps its own setting."""
|
||||
engine_args = EngineArgs(
|
||||
model="Qwen/Qwen3-1.7B",
|
||||
tensor_parallel_size=1,
|
||||
moe_backend="flashinfer_trtllm",
|
||||
speculative_config={
|
||||
"model": "Qwen/Qwen3-0.6B",
|
||||
"method": "draft_model",
|
||||
"num_speculative_tokens": 3,
|
||||
"moe_backend": "triton",
|
||||
},
|
||||
)
|
||||
tgt_config: VllmConfig = engine_args.create_engine_config()
|
||||
assert tgt_config.kernel_config.moe_backend == "flashinfer_trtllm"
|
||||
assert tgt_config.speculative_config.moe_backend == "triton"
|
||||
|
||||
draft_config = _apply_draft_moe_backend(tgt_config)
|
||||
assert draft_config.kernel_config.moe_backend == "triton"
|
||||
# Target config must be unaffected.
|
||||
assert tgt_config.kernel_config.moe_backend == "flashinfer_trtllm"
|
||||
|
||||
|
||||
def test_draft_model_moe_backend_inherits_target():
|
||||
"""When moe_backend is not set in speculative_config, the draft should
|
||||
inherit the target's moe_backend."""
|
||||
engine_args = EngineArgs(
|
||||
model="Qwen/Qwen3-1.7B",
|
||||
tensor_parallel_size=1,
|
||||
moe_backend="flashinfer_cutlass",
|
||||
speculative_config={
|
||||
"model": "Qwen/Qwen3-0.6B",
|
||||
"method": "draft_model",
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
)
|
||||
tgt_config: VllmConfig = engine_args.create_engine_config()
|
||||
assert tgt_config.kernel_config.moe_backend == "flashinfer_cutlass"
|
||||
assert tgt_config.speculative_config.moe_backend is None
|
||||
|
||||
draft_config = _apply_draft_moe_backend(tgt_config)
|
||||
assert draft_config.kernel_config.moe_backend == "flashinfer_cutlass"
|
||||
assert draft_config is tgt_config
|
||||
|
||||
|
||||
def test_draft_model_moe_backend_default_auto():
|
||||
"""When neither target nor draft set moe_backend explicitly, both should
|
||||
default to 'auto'."""
|
||||
engine_args = EngineArgs(
|
||||
model="Qwen/Qwen3-1.7B",
|
||||
tensor_parallel_size=1,
|
||||
speculative_config={
|
||||
"model": "Qwen/Qwen3-0.6B",
|
||||
"method": "draft_model",
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
)
|
||||
tgt_config: VllmConfig = engine_args.create_engine_config()
|
||||
assert tgt_config.kernel_config.moe_backend == "auto"
|
||||
assert tgt_config.speculative_config.moe_backend is None
|
||||
|
||||
draft_config = _apply_draft_moe_backend(tgt_config)
|
||||
assert draft_config.kernel_config.moe_backend == "auto"
|
||||
assert draft_config is tgt_config
|
||||
|
||||
|
||||
def test_draft_model_engine_args_rejects_invalid_tp_argname():
|
||||
|
||||
Reference in New Issue
Block a user