fix: pass K/V cache to FMHA unpermuted (it is already (n_kv, seq_len, hd)); add q_a_norm and kv_norm RMSNorm
This commit is contained in:
@@ -429,21 +429,33 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
# -- RMSNorm (pre-norm before attention) --
|
||||
x_normed = attn_norm.forward(x_in) # (T, H) BF16
|
||||
|
||||
# -- Q projection: q_a (low-rank down) → q_b (low-rank up) --
|
||||
# -- Q projection: q_a (low-rank down) → q_a_norm → q_b (low-rank up) --
|
||||
c_Q = nvfp4_linear(x_normed,
|
||||
w[f"{pre}.q_a_proj.weight"],
|
||||
w[f"{pre}.q_a_proj.weight_scale"],
|
||||
w[f"{pre}.q_a_proj.weight_scale_2"]) # (T, dc)
|
||||
# Q norm (RMSNorm after q_a, before q_b)
|
||||
q_norm_w = w.get(f"{pre}.q_a_norm.weight")
|
||||
if q_norm_w is not None:
|
||||
c_Q_f = c_Q.float()
|
||||
c_Q_rms = c_Q_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
c_Q = (c_Q_f * c_Q_rms * q_norm_w.float()).bfloat16()
|
||||
q = nvfp4_linear(c_Q,
|
||||
w[f"{pre}.q_b_proj.weight"],
|
||||
w[f"{pre}.q_b_proj.weight_scale"],
|
||||
w[f"{pre}.q_b_proj.weight_scale_2"]) # (T, n_h * hd)
|
||||
|
||||
# -- KV projection (MQA: 1 KV head) --
|
||||
# -- KV projection (MQA: 1 KV head) + KV norm --
|
||||
kv = nvfp4_linear(x_normed,
|
||||
w[f"{pre}.kv_proj.weight"],
|
||||
w[f"{pre}.kv_proj.weight_scale"],
|
||||
w[f"{pre}.kv_proj.weight_scale_2"]) # (T, hd) — 1 KV head, no split
|
||||
w[f"{pre}.kv_proj.weight_scale_2"]) # (T, hd)
|
||||
# KV norm (RMSNorm after kv_proj)
|
||||
kv_norm_w = w.get(f"{pre}.kv_norm.weight")
|
||||
if kv_norm_w is not None:
|
||||
kv_f = kv.float()
|
||||
kv_rms = kv_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
kv = (kv_f * kv_rms * kv_norm_w.float()).bfloat16()
|
||||
|
||||
# -- Reshape for attention --
|
||||
q_heads = q.reshape(T, n_h, hd) # (T, n_h, hd)
|
||||
@@ -469,9 +481,8 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
# -- FMHA: (n_h, T, hd) × (1, seq_len, hd) → (n_h, T, hd) --
|
||||
from dsv4.kernels.attention.production import dsv4_attention
|
||||
q_input = q_heads.permute(1, 0, 2) # (n_h, T, hd)
|
||||
k_input = k_full.permute(1, 0, 2) # (1, seq_len, hd) — already RoPE'd
|
||||
v_input = v_full.permute(1, 0, 2) # (1, seq_len, hd) — K=V, RoPE'd
|
||||
attn_out = dsv4_attention(q_input, k_input, v_input) # (n_h, T, hd)
|
||||
# k_full, v_full are already (n_kv=1, seq_len, hd), so they are passed through without permuting
|
||||
attn_out = dsv4_attention(q_input, k_full, v_full) # (n_h, T, hd)
|
||||
attn_out = attn_out.permute(1, 0, 2) # (T, n_h, hd)
|
||||
|
||||
# -- Inverse RoPE on attention output (paper §2.3.3) --
|
||||
|
||||
Reference in New Issue
Block a user