[Model][V1] Support Ernie MTP (#22169)
Signed-off-by: zhouchong <zhouchong03@baidu.com> Co-authored-by: zhouchong <zhouchong03@baidu.com>
This commit is contained in:
@@ -1463,7 +1463,8 @@ class ModelConfig:
|
||||
from vllm.distributed.utils import get_pp_indices
|
||||
if (self.hf_text_config.model_type == "deepseek_mtp"
|
||||
or self.hf_config.model_type == "mimo_mtp"
|
||||
or self.hf_config.model_type == "glm4_moe_mtp"):
|
||||
or self.hf_config.model_type == "glm4_moe_mtp"
|
||||
or self.hf_config.model_type == "ernie_mtp"):
|
||||
total_num_hidden_layers = getattr(self.hf_text_config,
|
||||
"num_nextn_predict_layers", 0)
|
||||
else:
|
||||
@@ -1911,7 +1912,8 @@ class DeviceConfig:
|
||||
|
||||
|
||||
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
|
||||
"mlp_speculator", "draft_model", "deepseek_mtp"]
|
||||
"mlp_speculator", "draft_model", "deepseek_mtp",
|
||||
"ernie_mtp"]
|
||||
|
||||
|
||||
@config
|
||||
@@ -2044,6 +2046,16 @@ class SpeculativeConfig:
|
||||
"architectures": ["Glm4MoeMTPModel"]
|
||||
})
|
||||
|
||||
if hf_config.model_type == "ernie4_5_moe":
|
||||
hf_config.model_type = "ernie_mtp"
|
||||
if hf_config.model_type == "ernie_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update({
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["ErnieMTPModel"]
|
||||
})
|
||||
return hf_config
|
||||
|
||||
return hf_config
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -2062,8 +2074,8 @@ class SpeculativeConfig:
|
||||
if self.target_model_config and \
|
||||
(self.target_model_config.hf_text_config.model_type \
|
||||
== "deepseek_v3" or
|
||||
self.target_model_config.hf_text_config.model_type \
|
||||
== "mimo"):
|
||||
self.target_model_config.hf_text_config.model_type in
|
||||
("mimo","ernie4_5_moe")):
|
||||
# use the draft model from the same model:
|
||||
self.model = self.target_model_config.model
|
||||
elif self.method in ("ngram", "[ngram]"):
|
||||
@@ -2161,6 +2173,15 @@ class SpeculativeConfig:
|
||||
"one layer. Might need some code changes " \
|
||||
"to support multiple layers."
|
||||
)
|
||||
elif (self.draft_model_config.hf_config.model_type ==
|
||||
"ernie_mtp"):
|
||||
self.method = "ernie_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"All Ernie MTP models only have " \
|
||||
"one layer. Might need some code changes " \
|
||||
"to support multiple layers."
|
||||
)
|
||||
else:
|
||||
self.method = "draft_model"
|
||||
raise NotImplementedError(
|
||||
@@ -2376,7 +2397,7 @@ class SpeculativeConfig:
|
||||
return self.num_speculative_tokens
|
||||
|
||||
def use_eagle(self) -> bool:
|
||||
return self.method in ("eagle", "eagle3", "deepseek_mtp")
|
||||
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
method = self.method
|
||||
|
||||
Reference in New Issue
Block a user