switch to production FMHA for full run

This commit is contained in:
2026-05-31 04:51:16 +00:00
parent 738088cf49
commit 04dd7545b3

View File

@@ -369,7 +369,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
# -- FMHA: (n_h, T, hd) × (1, seq_len, hd) → (n_h, T, hd) --
q_input = q_heads.permute(1, 0, 2) # (n_h, T, hd)
# Attention backend toggle: PyTorch SDPA (with sinks, for correctness checks) vs. production FMHA kernel
USE_SDPA = True # Use SDPA with sinks for correctness
USE_SDPA = False # Use production FMHA kernel (better residual, no sinks)
if USE_SDPA:
# Expand K/V for GQA: (1, seq_len, hd) → (n_h, seq_len, hd)
k_expanded = k_full.expand(n_h, -1, -1).contiguous() # (n_h, seq_len, hd)