Switch to the production FMHA kernel for the full run
This commit is contained in:
@@ -369,7 +369,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
# -- FMHA: (n_h, T, hd) × (1, seq_len, hd) → (n_h, T, hd) --
|
||||
q_input = q_heads.permute(1, 0, 2) # (n_h, T, hd)
|
||||
# Use PyTorch SDPA for correctness verification
|
||||
- USE_SDPA = True # Use SDPA with sinks for correctness
|
||||
+ USE_SDPA = False # Use production FMHA kernel (better residual, no sinks)
|
||||
if USE_SDPA:
|
||||
# Expand K/V for GQA: (1, seq_len, hd) → (n_h, seq_len, hd)
|
||||
k_expanded = k_full.expand(n_h, -1, -1).contiguous() # (n_h, seq_len, hd)
|
||||
|
||||
Reference in New Issue
Block a user