[spec decode] Consolidate speculative decode method name for MTP (#25232)
Signed-off-by: zixi-qi <qizixi@meta.com>
This commit is contained in:
@@ -32,7 +32,9 @@ logger = init_logger(__name__)
|
||||
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
|
||||
"mlp_speculator", "draft_model", "deepseek_mtp",
|
||||
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
|
||||
"longcat_flash_mtp"]
|
||||
"longcat_flash_mtp", "mtp"]
|
||||
MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
|
||||
"qwen3_next_mtp", "longcat_flash_mtp")
|
||||
|
||||
|
||||
@config
|
||||
@@ -207,11 +209,16 @@ class SpeculativeConfig:
|
||||
# can not be detected, it will be considered as the "draft_model" by
|
||||
# default.
|
||||
|
||||
if self.method in MTP_MODEL_TYPES:
|
||||
logger.warning("method `%s` is deprecated and replaced with mtp.",
|
||||
self.method)
|
||||
self.method = "mtp"
|
||||
|
||||
if self.model is None and self.num_speculative_tokens is not None:
|
||||
# TODO(Shangming): Refactor mtp configuration logic when supporting
|
||||
if (self.target_model_config
|
||||
and self.target_model_config.hf_text_config.model_type
|
||||
in ("deepseek_v3", "mimo", "ernie4_5_moe", "qwen3_next")):
|
||||
if self.method == "mtp":
|
||||
assert (
|
||||
self.target_model_config
|
||||
is not None), "target_model_config must be present for mtp"
|
||||
# use the draft model from the same model:
|
||||
self.model = self.target_model_config.model
|
||||
# Align the quantization of draft model for cases such as
|
||||
@@ -312,31 +319,13 @@ class SpeculativeConfig:
|
||||
"mlp_speculator"):
|
||||
self.method = "mlp_speculator"
|
||||
elif (self.draft_model_config.hf_config.model_type
|
||||
in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
|
||||
self.method = "deepseek_mtp"
|
||||
in MTP_MODEL_TYPES):
|
||||
self.method = "mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"All Deepseek MTP models only have " \
|
||||
"one layer. Might need some code changes " \
|
||||
"to support multiple layers."
|
||||
)
|
||||
elif (self.draft_model_config.hf_config.model_type ==
|
||||
"ernie_mtp"):
|
||||
self.method = "ernie_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"All Ernie MTP models only have " \
|
||||
"one layer. Might need some code changes " \
|
||||
"to support multiple layers."
|
||||
)
|
||||
elif (self.draft_model_config.hf_config.model_type ==
|
||||
"qwen3_next_mtp"):
|
||||
self.method = "qwen3_next_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"All Qwen3Next MTP models only have " \
|
||||
"one layer. Might need some code changes " \
|
||||
"to support multiple layers."
|
||||
"Enabling num_speculative_tokens > 1 will run" \
|
||||
"multiple times of forward on same MTP layer" \
|
||||
",which may result in lower acceptance rate" \
|
||||
)
|
||||
elif (self.draft_model_config.hf_config.model_type
|
||||
in ("longcat_flash_mtp")):
|
||||
@@ -353,7 +342,7 @@ class SpeculativeConfig:
|
||||
"Speculative decoding with draft model is not "
|
||||
"supported yet. Please consider using other "
|
||||
"speculative decoding methods such as ngram, medusa, "
|
||||
"eagle, or deepseek_mtp.")
|
||||
"eagle, or mtp.")
|
||||
|
||||
# Replace hf_config for EAGLE draft_model
|
||||
if self.method in ("eagle", "eagle3"):
|
||||
@@ -562,8 +551,7 @@ class SpeculativeConfig:
|
||||
return self.num_speculative_tokens
|
||||
|
||||
def use_eagle(self) -> bool:
|
||||
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
|
||||
"qwen3_next_mtp", "longcat_flash_mtp")
|
||||
return self.method in ("eagle", "eagle3", "mtp")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
method = self.method
|
||||
|
||||
Reference in New Issue
Block a user