[Model][Spec Decode] Nemotron-H MTP and Mamba Speculative Decoding Support (#33726)
Signed-off-by: Shahar Mor <smor@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Shahar Mor <smor@nvidia.com>
Co-authored-by: Roi Koren <roik@nvidia.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
committed by GitHub
parent a9e15e040d
commit f5972a872f
```diff
@@ -36,6 +36,7 @@ MTPModelTypes = Literal[
     "glm4_moe_lite_mtp",
     "glm_ocr_mtp",
     "ernie_mtp",
+    "nemotron_h_mtp",
     "exaone_moe_mtp",
     "qwen3_next_mtp",
     "qwen3_5_mtp",
```
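With `nemotron_h_mtp` registered among the MTP model types, a Nemotron-H checkpoint that ships MTP layers can act as its own draft model. A minimal offline-inference sketch, assuming vLLM is installed and a suitable checkpoint is available locally (the model path below is a placeholder, not a released model name):

```python
# Hedged sketch: enabling MTP speculative decoding for a Nemotron-H model
# through vLLM's offline API. The model path is a placeholder.
from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/nemotron-h-mtp-checkpoint",  # placeholder path
    speculative_config={
        "method": "mtp",              # draft tokens come from the model's own MTP head(s)
        "num_speculative_tokens": 1,  # one draft token proposed per step
    },
)

out = llm.generate(["The capital of France is"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
```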
```diff
@@ -255,6 +256,19 @@ class SpeculativeConfig:
                 {"n_predict": n_predict, "architectures": ["ErnieMTPModel"]}
             )
 
+        if (
+            hf_config.model_type == "nemotron_h"
+            and hasattr(hf_config, "num_nextn_predict_layers")
+            and hf_config.num_nextn_predict_layers > 0
+        ):
+            # Check if this is an MTP variant
+            hf_config.model_type = "nemotron_h_mtp"
+        if hf_config.model_type == "nemotron_h_mtp":
+            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
+            hf_config.update(
+                {"n_predict": n_predict, "architectures": ["NemotronHMTPModel"]}
+            )
+
         if hf_config.model_type == "qwen3_next":
             hf_config.model_type = "qwen3_next_mtp"
         if hf_config.model_type == "qwen3_next_mtp":
```
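In effect, the override promotes a plain `nemotron_h` config into the MTP draft variant whenever the checkpoint declares `num_nextn_predict_layers`, mirroring the neighboring ernie and qwen3_next branches. A stand-alone sketch of that promotion, using `SimpleNamespace` in place of a real `PretrainedConfig` (the helper name is illustrative, not part of the diff):

```python
from types import SimpleNamespace

def promote_nemotron_h_to_mtp(cfg):
    # Mirrors the override above: a nemotron_h config that declares MTP
    # layers becomes the nemotron_h_mtp draft-model variant.
    if (
        cfg.model_type == "nemotron_h"
        and getattr(cfg, "num_nextn_predict_layers", 0) > 0
    ):
        cfg.model_type = "nemotron_h_mtp"
    if cfg.model_type == "nemotron_h_mtp":
        cfg.n_predict = getattr(cfg, "num_nextn_predict_layers", 1)
        cfg.architectures = ["NemotronHMTPModel"]
    return cfg

cfg = SimpleNamespace(model_type="nemotron_h", num_nextn_predict_layers=1)
print(promote_nemotron_h_to_mtp(cfg).model_type)  # nemotron_h_mtp
print(cfg.architectures)                          # ['NemotronHMTPModel']
```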
```diff
@@ -325,7 +339,7 @@ class SpeculativeConfig:
             if self.target_model_config is None:
                 raise ValueError("target_model_config must be present for mtp")
             if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
-                # FIXME(luccafong): cudgraph with v32 MTP is not supported,
+                # FIXME(luccafong): cudagraph with v32 MTP is not supported,
                 # remove this when the issue is fixed.
                 self.enforce_eager = True
             # use the draft model from the same model:
```
```diff
@@ -427,7 +441,7 @@ class SpeculativeConfig:
                 self.method = "mtp"
                 if self.num_speculative_tokens > 1:
                     logger.warning(
-                        "Enabling num_speculative_tokens > 1 will run"
+                        "Enabling num_speculative_tokens > 1 will run "
                         "multiple times of forward on same MTP layer"
                         ",which may result in lower acceptance rate"
                     )
```
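The one-character fix above matters because Python joins adjacent string literals at compile time with no separator, so the original message ran the words together:

```python
# Adjacent string literals are concatenated verbatim, which is why the
# missing trailing space produced "...will runmultiple times...".
before = (
    "Enabling num_speculative_tokens > 1 will run"
    "multiple times of forward on same MTP layer"
)
after = (
    "Enabling num_speculative_tokens > 1 will run "
    "multiple times of forward on same MTP layer"
)
print(before)  # ...will runmultiple times of forward...
print(after)   # ...will run multiple times of forward...
```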
```diff
@@ -712,6 +726,7 @@ class SpeculativeConfig:
             "hunyuan_vl",
             "hunyuan_v1_dense",
             "afmoe",
+            "nemotron_h",
         ]
         if (
             self.method == "eagle3"
```
```diff
@@ -395,6 +395,15 @@ class VllmConfig:
         ]
         return hash_str
 
+    @property
+    def num_speculative_tokens(self) -> int:
+        if (
+            self.speculative_config is not None
+            and self.speculative_config.num_speculative_tokens is not None
+        ):
+            return self.speculative_config.num_speculative_tokens
+        return 0
+
     @property
     def needs_dp_coordinator(self) -> bool:
         """
```
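The new `num_speculative_tokens` property centralizes the None-check that call sites would otherwise repeat. A small usage sketch (the function name is illustrative; `vllm_config` is assumed to be an already-built `VllmConfig`):

```python
def draft_tokens_per_step(vllm_config) -> int:
    # Before this change a caller had to write:
    #   cfg = vllm_config.speculative_config
    #   return cfg.num_speculative_tokens if cfg is not None else 0
    # The property folds that guard into VllmConfig itself and returns 0
    # whenever speculative decoding is disabled.
    return vllm_config.num_speculative_tokens
```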