Add the support for the qwen3 next model (a hybrid attention model). (#24526)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Tao He
2025-09-11 15:32:09 +08:00
committed by GitHub
parent 2048c4e379
commit e93f4cc9e3
29 changed files with 2476 additions and 61 deletions

View File

@@ -1508,7 +1508,8 @@ class ModelConfig:
if (self.hf_text_config.model_type == "deepseek_mtp"
or self.hf_config.model_type == "mimo_mtp"
or self.hf_config.model_type == "glm4_moe_mtp"
or self.hf_config.model_type == "ernie_mtp"):
or self.hf_config.model_type == "ernie_mtp"
or self.hf_config.model_type == "qwen3_next_mtp"):
total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 0)
else:
@@ -1571,15 +1572,28 @@ class ModelConfig:
if attn_type_list:
return sum(t == 1 for t in attn_type_list[start:end])
if layers_block_type_value is None and attn_type_list is None:
# Hybrid model Qwen3Next
layer_types_value = getattr(self.hf_config, "layer_types", None)
if layer_types_value is not None:
if getattr(block_type, "value", block_type) == "attention":
return sum(t == "full_attention"
for t in layer_types_value[start:end])
elif getattr(block_type, "value",
block_type) == "linear_attention":
return sum(t == "linear_attention"
for t in layer_types_value[start:end])
else:
return sum(t == getattr(block_type, "value", block_type)
for t in layer_types_value[start:end])
if (layers_block_type_value is None and attn_type_list is None
and layer_types_value is None):
raise ValueError(
"The model is an hybrid without a"
"layers_block_type or an attn_type_list in the hf_config,"
"cannot determine the num of "
"layers_block_type or an attn_type_list, or a layer_types "
"in the hf_config, cannot determine the num of "
f"{block_type.value} layers")
return sum(t == 1 for t in attn_type_list[start:end])
def get_mamba_chunk_size(self) -> Optional[int]:
"""
Returns the mamba chunk size if it exists
@@ -1866,7 +1880,7 @@ class DeviceConfig:
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp"]
"ernie_mtp", "qwen3_next_mtp"]
@config
@@ -2007,7 +2021,15 @@ class SpeculativeConfig:
"n_predict": n_predict,
"architectures": ["ErnieMTPModel"]
})
return hf_config
if hf_config.model_type == "qwen3_next":
hf_config.model_type = "qwen3_next_mtp"
if hf_config.model_type == "qwen3_next_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["Qwen3NextMTP"]
})
return hf_config
@@ -2028,9 +2050,13 @@ class SpeculativeConfig:
(self.target_model_config.hf_text_config.model_type \
== "deepseek_v3" or
self.target_model_config.hf_text_config.model_type in
("mimo","ernie4_5_moe")):
("mimo","ernie4_5_moe", "qwen3_next")):
# use the draft model from the same model:
self.model = self.target_model_config.model
# Align the quantization of draft model for cases such as
# --quantization fp8 with a bf16 checkpoint.
if not self.quantization:
self.quantization = self.target_model_config.quantization
elif self.method in ("ngram", "[ngram]"):
self.model = "ngram"
else:
@@ -2140,6 +2166,15 @@ class SpeculativeConfig:
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"qwen3_next_mtp"):
self.method = "qwen3_next_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Qwen3Next MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
else:
self.method = "draft_model"
raise NotImplementedError(
@@ -2355,7 +2390,8 @@ class SpeculativeConfig:
return self.num_speculative_tokens
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp")
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp")
def __repr__(self) -> str:
method = self.method