[Model][Spec Decode] Nemotron-H MTP and Mamba Speculative Decoding Support (#33726)

Signed-off-by: Shahar Mor <smor@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Shahar Mor <smor@nvidia.com>
Co-authored-by: Roi Koren <roik@nvidia.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Benjamin Chislett
2026-02-24 12:49:56 -05:00
committed by GitHub
parent a9e15e040d
commit f5972a872f
19 changed files with 799 additions and 157 deletions

View File

@@ -36,6 +36,7 @@ MTPModelTypes = Literal[
"glm4_moe_lite_mtp",
"glm_ocr_mtp",
"ernie_mtp",
"nemotron_h_mtp",
"exaone_moe_mtp",
"qwen3_next_mtp",
"qwen3_5_mtp",
@@ -255,6 +256,19 @@ class SpeculativeConfig:
{"n_predict": n_predict, "architectures": ["ErnieMTPModel"]}
)
if (
hf_config.model_type == "nemotron_h"
and hasattr(hf_config, "num_nextn_predict_layers")
and hf_config.num_nextn_predict_layers > 0
):
# Check if this is an MTP variant
hf_config.model_type = "nemotron_h_mtp"
if hf_config.model_type == "nemotron_h_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
hf_config.update(
{"n_predict": n_predict, "architectures": ["NemotronHMTPModel"]}
)
if hf_config.model_type == "qwen3_next":
hf_config.model_type = "qwen3_next_mtp"
if hf_config.model_type == "qwen3_next_mtp":
@@ -325,7 +339,7 @@ class SpeculativeConfig:
if self.target_model_config is None:
raise ValueError("target_model_config must be present for mtp")
if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
# FIXME(luccafong): cudgraph with v32 MTP is not supported,
# FIXME(luccafong): cudagraph with v32 MTP is not supported,
# remove this when the issue is fixed.
self.enforce_eager = True
# use the draft model from the same model:
@@ -427,7 +441,7 @@ class SpeculativeConfig:
self.method = "mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"Enabling num_speculative_tokens > 1 will run"
"Enabling num_speculative_tokens > 1 will run "
"multiple times of forward on same MTP layer"
",which may result in lower acceptance rate"
)
@@ -712,6 +726,7 @@ class SpeculativeConfig:
"hunyuan_vl",
"hunyuan_v1_dense",
"afmoe",
"nemotron_h",
]
if (
self.method == "eagle3"

View File

@@ -395,6 +395,15 @@ class VllmConfig:
]
return hash_str
@property
def num_speculative_tokens(self) -> int:
    """Return the configured number of speculative tokens.

    Reads ``self.speculative_config.num_speculative_tokens``. Falls back
    to 0 when there is no speculative config at all, or when the config
    exists but its token count is ``None``.
    """
    spec_config = self.speculative_config
    # No speculative decoding configured.
    if spec_config is None:
        return 0
    token_count = spec_config.num_speculative_tokens
    # Config present but the count has not been set.
    return 0 if token_count is None else token_count
@property
def needs_dp_coordinator(self) -> bool:
"""