[New Model] BAGEL support (AR only) (#28439)

Signed-off-by: princepride <wangzhipeng628@gmail.com>
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
汪志鹏
2025-12-15 14:58:23 +08:00
committed by GitHub
parent e3a1cd1c59
commit 1adeb3b84c
11 changed files with 777 additions and 0 deletions

View File

@@ -122,6 +122,8 @@ class Qwen2Attention(nn.Module):
prefix: str = "",
attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: dict[str, Any] | None = None,
qk_norm: bool = False,
rms_norm_eps: float = 1e-6,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -144,6 +146,7 @@ class Qwen2Attention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.dual_chunk_attention_config = dual_chunk_attention_config
self.qk_norm = qk_norm
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -162,6 +165,11 @@ class Qwen2Attention(nn.Module):
prefix=f"{prefix}.o_proj",
)
# QK Normalization support (used in BAGEL and some other models)
if self.qk_norm:
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.rotary_emb = get_rope(
self.head_dim,
max_position=max_position,
@@ -197,6 +205,23 @@ class Qwen2Attention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Apply QK normalization if enabled (before RoPE)
if self.qk_norm:
# Reshape to apply per-head normalization
# q shape: (total_tokens, q_size) -> (total_tokens, num_heads, head_dim)
total_tokens = q.shape[0]
q = q.view(total_tokens, self.num_heads, self.head_dim)
k = k.view(total_tokens, self.num_kv_heads, self.head_dim)
# Apply normalization
q = self.q_norm(q)
k = self.k_norm(k)
# Reshape back
q = q.view(total_tokens, self.q_size)
k = k.view(total_tokens, self.kv_size)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
@@ -227,6 +252,9 @@ class Qwen2DecoderLayer(nn.Module):
else:
attn_type = AttentionType.ENCODER_ONLY
# Check if QK normalization is enabled (used in BAGEL and some other models)
qk_norm = getattr(config, "qk_norm", False)
self.self_attn = Qwen2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
@@ -238,6 +266,8 @@ class Qwen2DecoderLayer(nn.Module):
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config,
qk_norm=qk_norm,
rms_norm_eps=config.rms_norm_eps,
)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
@@ -480,6 +510,8 @@ class Qwen2Model(nn.Module):
continue
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)