[Model][V1] Support Ernie MTP (#22169)

Signed-off-by: zhouchong <zhouchong03@baidu.com>
Co-authored-by: zhouchong <zhouchong03@baidu.com>
This commit is contained in:
xyxinyang
2025-08-20 20:41:55 +08:00
committed by GitHub
parent 50df09fe13
commit 7cd17e22d7
6 changed files with 320 additions and 7 deletions

View File

@@ -1463,7 +1463,8 @@ class ModelConfig:
from vllm.distributed.utils import get_pp_indices
if (self.hf_text_config.model_type == "deepseek_mtp"
or self.hf_config.model_type == "mimo_mtp"
or self.hf_config.model_type == "glm4_moe_mtp"):
or self.hf_config.model_type == "glm4_moe_mtp"
or self.hf_config.model_type == "ernie_mtp"):
total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 0)
else:
@@ -1911,7 +1912,8 @@ class DeviceConfig:
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp"]
"mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp"]
@config
@@ -2044,6 +2046,16 @@ class SpeculativeConfig:
"architectures": ["Glm4MoeMTPModel"]
})
if hf_config.model_type == "ernie4_5_moe":
hf_config.model_type = "ernie_mtp"
if hf_config.model_type == "ernie_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["ErnieMTPModel"]
})
return hf_config
return hf_config
def __post_init__(self):
@@ -2062,8 +2074,8 @@ class SpeculativeConfig:
if self.target_model_config and \
(self.target_model_config.hf_text_config.model_type \
== "deepseek_v3" or
self.target_model_config.hf_text_config.model_type \
== "mimo"):
self.target_model_config.hf_text_config.model_type in
("mimo", "ernie4_5_moe")):
# use the draft model from the same model:
self.model = self.target_model_config.model
elif self.method in ("ngram", "[ngram]"):
@@ -2161,6 +2173,15 @@ class SpeculativeConfig:
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"ernie_mtp"):
self.method = "ernie_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Ernie MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
else:
self.method = "draft_model"
raise NotImplementedError(
@@ -2376,7 +2397,7 @@ class SpeculativeConfig:
return self.num_speculative_tokens
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp")
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp")
def __repr__(self) -> str:
method = self.method