fix: FMHA K/V tensor shape (cache is already (n_kv, N, hd); stop permuting it), add q_a_norm and kv_norm RMSNorm steps

This commit is contained in:
2026-05-31 03:04:53 +00:00
parent 3f12bbc374
commit a262492e51

View File

@@ -429,21 +429,33 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
# -- RMSNorm (pre-norm before attention) --
x_normed = attn_norm.forward(x_in) # (T, H) BF16
# -- Q projection: q_a (low-rank down) → q_b (low-rank up) --
# -- Q projection: q_a (low-rank down) → q_a_norm → q_b (low-rank up) --
c_Q = nvfp4_linear(x_normed,
w[f"{pre}.q_a_proj.weight"],
w[f"{pre}.q_a_proj.weight_scale"],
w[f"{pre}.q_a_proj.weight_scale_2"]) # (T, dc)
# Q norm (RMSNorm after q_a, before q_b)
q_norm_w = w.get(f"{pre}.q_a_norm.weight")
if q_norm_w is not None:
c_Q_f = c_Q.float()
c_Q_rms = c_Q_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
c_Q = (c_Q_f * c_Q_rms * q_norm_w.float()).bfloat16()
q = nvfp4_linear(c_Q,
w[f"{pre}.q_b_proj.weight"],
w[f"{pre}.q_b_proj.weight_scale"],
w[f"{pre}.q_b_proj.weight_scale_2"]) # (T, n_h * hd)
# -- KV projection (MQA: 1 KV head) --
# -- KV projection (MQA: 1 KV head) + KV norm --
kv = nvfp4_linear(x_normed,
w[f"{pre}.kv_proj.weight"],
w[f"{pre}.kv_proj.weight_scale"],
w[f"{pre}.kv_proj.weight_scale_2"]) # (T, hd) — 1 KV head, no split
w[f"{pre}.kv_proj.weight_scale_2"]) # (T, hd)
# KV norm (RMSNorm after kv_proj)
kv_norm_w = w.get(f"{pre}.kv_norm.weight")
if kv_norm_w is not None:
kv_f = kv.float()
kv_rms = kv_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
kv = (kv_f * kv_rms * kv_norm_w.float()).bfloat16()
# -- Reshape for attention --
q_heads = q.reshape(T, n_h, hd) # (T, n_h, hd)
@@ -469,9 +481,8 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
# -- FMHA: (n_h, T, hd) × (1, seq_len, hd) → (n_h, T, hd) --
from dsv4.kernels.attention.production import dsv4_attention
q_input = q_heads.permute(1, 0, 2) # (n_h, T, hd)
k_input = k_full.permute(1, 0, 2) # (1, seq_len, hd) — already RoPE'd
v_input = v_full.permute(1, 0, 2) # (1, seq_len, hd) — K=V, RoPE'd
attn_out = dsv4_attention(q_input, k_input, v_input) # (n_h, T, hd)
# k_full, v_full are (1, seq_len, hd) — already in (n_kv, N, hd) format
attn_out = dsv4_attention(q_input, k_full, v_full) # (n_h, T, hd)
attn_out = attn_out.permute(1, 0, 2) # (T, n_h, hd)
# -- Inverse RoPE on attention output (paper §2.3.3) --