[New Model] BAGEL support (AR only) (#28439)
Signed-off-by: princepride <wangzhipeng628@gmail.com> Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -122,6 +122,8 @@ class Qwen2Attention(nn.Module):
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
dual_chunk_attention_config: dict[str, Any] | None = None,
|
||||
qk_norm: bool = False,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -144,6 +146,7 @@ class Qwen2Attention(nn.Module):
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.dual_chunk_attention_config = dual_chunk_attention_config
|
||||
self.qk_norm = qk_norm
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
@@ -162,6 +165,11 @@ class Qwen2Attention(nn.Module):
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
# QK Normalization support (used in BAGEL and some other models)
|
||||
if self.qk_norm:
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
max_position=max_position,
|
||||
@@ -197,6 +205,23 @@ class Qwen2Attention(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
# Apply QK normalization if enabled (before RoPE)
|
||||
if self.qk_norm:
|
||||
# Reshape to apply per-head normalization
|
||||
# q shape: (total_tokens, q_size) -> (total_tokens, num_heads, head_dim)
|
||||
total_tokens = q.shape[0]
|
||||
q = q.view(total_tokens, self.num_heads, self.head_dim)
|
||||
k = k.view(total_tokens, self.num_kv_heads, self.head_dim)
|
||||
|
||||
# Apply normalization
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# Reshape back
|
||||
q = q.view(total_tokens, self.q_size)
|
||||
k = k.view(total_tokens, self.kv_size)
|
||||
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
@@ -227,6 +252,9 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
else:
|
||||
attn_type = AttentionType.ENCODER_ONLY
|
||||
|
||||
# Check if QK normalization is enabled (used in BAGEL and some other models)
|
||||
qk_norm = getattr(config, "qk_norm", False)
|
||||
|
||||
self.self_attn = Qwen2Attention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
@@ -238,6 +266,8 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
prefix=f"{prefix}.self_attn",
|
||||
attn_type=attn_type,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
qk_norm=qk_norm,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
)
|
||||
self.mlp = Qwen2MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
@@ -480,6 +510,8 @@ class Qwen2Model(nn.Module):
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
Reference in New Issue
Block a user