use PyTorch SDPA for correctness (no sink bias in FMHA kernel yet)

This commit is contained in:
2026-05-31 03:42:03 +00:00
parent 171a9e0d10
commit cd073ad867

View File

@@ -486,10 +486,23 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
k_full, v_full = kv_cache.get() # (1, seq_len, hd) each — RoPE'd, K=V
# -- FMHA: (n_h, T, hd) × (1, seq_len, hd) → (n_h, T, hd) --
from dsv4.kernels.attention.production import dsv4_attention
q_input = q_heads.permute(1, 0, 2) # (n_h, T, hd)
# k_full, v_full are (1, seq_len, hd) — already in (n_kv, N, hd) format
attn_out = dsv4_attention(q_input, k_full, v_full) # (n_h, T, hd)
# Use PyTorch SDPA for correctness verification (production FMHA has no sink bias)
# dsv4_attention can be swapped back once sink bias is integrated into the kernel
USE_SDPA = True # Set False to use production FMHA kernel
if USE_SDPA:
# PyTorch scaled_dot_product_attention
# q: (n_h, T, hd), k: (1, seq_len, hd), v: (1, seq_len, hd)
# Need to expand K/V for GQA: (1, seq_len, hd) → (n_h, seq_len, hd)
k_expanded = k_full.expand(n_h, -1, -1) # (n_h, seq_len, hd)
v_expanded = v_full.expand(n_h, -1, -1)
scale = 1.0 / math.sqrt(hd)
# For decode (T=1), use SDPA with is_causal=False (no causal mask needed)
attn_out = torch.nn.functional.scaled_dot_product_attention(
q_input, k_expanded, v_expanded,
scale=scale, is_causal=False) # (n_h, T, hd)
else:
from dsv4.kernels.attention.production import dsv4_attention
attn_out = dsv4_attention(q_input, k_full, v_full) # (n_h, T, hd)
attn_out = attn_out.permute(1, 0, 2) # (T, n_h, hd)
# -- Inverse RoPE on attention output (paper §2.3.3) --