[Qwen3] Enable dual-chunk-attention support for Qwen3 models. (#21924)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Author: Tao He
Date: 2025-08-07 10:58:08 +08:00
Committed by: GitHub
Parent: 6b47ef24de
Commit: 7377131a2c
2 changed files with 60 additions and 31 deletions
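
For context: dual chunk attention (DCA) is the training-free long-context scheme used by Qwen's 1M-token checkpoints; it splits the sequence into chunks and combines intra-chunk, inter-chunk, and successive-chunk attention so a pretrained model can attend far beyond its trained window. This commit threads the model's `dual_chunk_attention_config` through the Qwen3 MoE attention stack (RoPE and the `Attention` layer). Below is a sketch of the config block being plumbed, with keys and values modeled on the published Qwen2.5-1M configs; a real Qwen3 DCA checkpoint may use different values:

# Hypothetical excerpt of a checkpoint's config.json, shown as a Python dict.
# Keys follow Qwen2.5-1M's published "dual_chunk_attention_config"; the
# values here are that model's and are only illustrative for Qwen3.
dual_chunk_attention_config = {
    "chunk_size": 262144,        # width of each attention chunk
    "local_size": 8192,          # local window shared between adjacent chunks
    "original_max_position_embeddings": 262144,
}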

@@ -185,6 +185,7 @@ class Qwen3MoeAttention(nn.Module):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -208,6 +209,7 @@
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
+        self.dual_chunk_attention_config = dual_chunk_attention_config
 
         self.qkv_proj = QKVParallelLinear(hidden_size,
                                           self.head_dim,
@@ -229,14 +231,21 @@
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config,
-                              prefix=f"{prefix}.attn")
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+            **{
+                "layer_idx": extract_layer_index(prefix),
+                "dual_chunk_attention_config": dual_chunk_attention_config,
+            } if dual_chunk_attention_config else {},
+        )
 
         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
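
Worth noting in the hunk above: the `**{...} if dual_chunk_attention_config else {}` splat keeps the two extra keyword arguments out of the `Attention` constructor entirely when DCA is disabled, so existing checkpoints hit an unchanged call. A minimal, self-contained sketch of that pattern (the names here are illustrative stand-ins, not vLLM APIs):

from typing import Any, Optional

def make_attention(prefix: str,
                   dca_config: Optional[dict[str, Any]] = None) -> dict:
    """Toy stand-in for the Attention constructor call in the diff."""
    def attention(**kwargs: Any) -> dict:
        return kwargs

    # The conditional is evaluated before the splat: when dca_config is
    # None, an empty dict is splatted and no extra kwargs are passed.
    return attention(
        prefix=f"{prefix}.attn",
        **{
            "layer_idx": 0,  # stands in for extract_layer_index(prefix)
            "dual_chunk_attention_config": dca_config,
        } if dca_config else {},
    )

assert "dual_chunk_attention_config" not in make_attention("layers.0")
assert make_attention("layers.0", {"chunk_size": 8})["layer_idx"] == 0
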
@@ -280,6 +289,9 @@ class Qwen3MoeDecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
+        dual_chunk_attention_config = getattr(config,
+                                              "dual_chunk_attention_config",
+                                              None)
         self.self_attn = Qwen3MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
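
The `getattr(..., None)` fallback above is what makes the feature opt-in: configs without a `dual_chunk_attention_config` key (every non-DCA Qwen3 checkpoint) yield `None`, which in turn disables the conditional kwargs shown earlier. A toy illustration with stand-in config objects rather than real HF `PretrainedConfig` instances:

from types import SimpleNamespace

# Stand-ins for HF model configs; only the attribute being read matters.
plain_cfg = SimpleNamespace(max_position_embeddings=40960)
dca_cfg = SimpleNamespace(max_position_embeddings=1010000,
                          dual_chunk_attention_config={"chunk_size": 262144,
                                                       "local_size": 8192})

for cfg in (plain_cfg, dca_cfg):
    dca = getattr(cfg, "dual_chunk_attention_config", None)
    print("DCA enabled" if dca else "DCA disabled")
# -> DCA disabled
# -> DCA enabled
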
@@ -293,6 +305,7 @@
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
         # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
         # `mlp_only_layers` in the config.
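
With this plumbing in place, serving a DCA-enabled Qwen3 checkpoint should mirror the existing Qwen2.5-1M recipe. A hedged usage sketch follows: the `DUAL_CHUNK_FLASH_ATTN` backend selector comes from that recipe, while the model name and engine arguments are placeholders, since this commit predates any published Qwen3 DCA checkpoint:

# Hedged sketch of offline inference with the dual-chunk FlashAttention
# backend, mirroring vLLM's Qwen2.5-1M recipe. Model name is a placeholder.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "DUAL_CHUNK_FLASH_ATTN"

from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-8B",    # placeholder; needs a checkpoint whose
                              # config.json carries dual_chunk_attention_config
    max_model_len=1048576,    # the long-context length DCA is meant to serve
    tensor_parallel_size=4,   # sized for the long KV cache; adjust to hardware
    enforce_eager=True,       # the 1M recipe runs eager; an assumption here
)

out = llm.generate(["Summarize the following document: ..."],
                   SamplingParams(max_tokens=256))
print(out[0].outputs[0].text)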