[spec decode] Consolidate speculative decode method name for MTP (#25232)

Signed-off-by: zixi-qi <qizixi@meta.com>
This commit is contained in:
qizixi
2025-09-26 15:27:05 -07:00
committed by GitHub
parent cf89202855
commit c70ac4b8ff
6 changed files with 287 additions and 40 deletions

View File

@@ -32,7 +32,9 @@ logger = init_logger(__name__)
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
"longcat_flash_mtp"]
"longcat_flash_mtp", "mtp"]
MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
"qwen3_next_mtp", "longcat_flash_mtp")
@config
@@ -207,11 +209,16 @@ class SpeculativeConfig:
# can not be detected, it will be considered as the "draft_model" by
# default.
if self.method in MTP_MODEL_TYPES:
logger.warning("method `%s` is deprecated and replaced with mtp.",
self.method)
self.method = "mtp"
if self.model is None and self.num_speculative_tokens is not None:
# TODO(Shangming): Refactor mtp configuration logic when supporting
if (self.target_model_config
and self.target_model_config.hf_text_config.model_type
in ("deepseek_v3", "mimo", "ernie4_5_moe", "qwen3_next")):
if self.method == "mtp":
assert (
self.target_model_config
is not None), "target_model_config must be present for mtp"
# use the draft model from the same model:
self.model = self.target_model_config.model
# Align the quantization of draft model for cases such as
@@ -312,31 +319,13 @@ class SpeculativeConfig:
"mlp_speculator"):
self.method = "mlp_speculator"
elif (self.draft_model_config.hf_config.model_type
in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
self.method = "deepseek_mtp"
in MTP_MODEL_TYPES):
self.method = "mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Deepseek MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"ernie_mtp"):
self.method = "ernie_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Ernie MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"qwen3_next_mtp"):
self.method = "qwen3_next_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Qwen3Next MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
"Enabling num_speculative_tokens > 1 will run" \
"multiple times of forward on same MTP layer" \
",which may result in lower acceptance rate" \
)
elif (self.draft_model_config.hf_config.model_type
in ("longcat_flash_mtp")):
@@ -353,7 +342,7 @@ class SpeculativeConfig:
"Speculative decoding with draft model is not "
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or deepseek_mtp.")
"eagle, or mtp.")
# Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"):
@@ -562,8 +551,7 @@ class SpeculativeConfig:
return self.num_speculative_tokens
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp", "longcat_flash_mtp")
return self.method in ("eagle", "eagle3", "mtp")
def __repr__(self) -> str:
method = self.method