diff --git a/single_shot_inference.py b/single_shot_inference.py index 6f244374..7e00051d 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -369,7 +369,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, # -- FMHA: (n_h, T, hd) × (1, seq_len, hd) → (n_h, T, hd) -- q_input = q_heads.permute(1, 0, 2) # (n_h, T, hd) # Use PyTorch SDPA for correctness verification - USE_SDPA = True # Use SDPA with sinks for correctness + USE_SDPA = False # Use production FMHA kernel (better residual, no sinks) if USE_SDPA: # Expand K/V for GQA: (1, seq_len, hd) → (n_h, seq_len, hd) k_expanded = k_full.expand(n_h, -1, -1).contiguous() # (n_h, seq_len, hd)