diff --git a/single_shot_inference.py b/single_shot_inference.py
index 453fc083..41852c18 100644
--- a/single_shot_inference.py
+++ b/single_shot_inference.py
@@ -486,10 +486,24 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
     k_full, v_full = kv_cache.get()  # (1, seq_len, hd) each — RoPE'd, K=V
 
     # -- FMHA: (n_h, T, hd) × (1, seq_len, hd) → (n_h, T, hd) --
-    from dsv4.kernels.attention.production import dsv4_attention
-    q_input = q_heads.permute(1, 0, 2)  # (n_h, T, hd)
-    # k_full, v_full are (1, seq_len, hd) — already in (n_kv, N, hd) format
-    attn_out = dsv4_attention(q_input, k_full, v_full)  # (n_h, T, hd)
+    # Use PyTorch SDPA for correctness verification (production FMHA has no sink bias)
+    # dsv4_attention can be swapped back once sink bias is integrated into the kernel
+    USE_SDPA = True  # Set False to use production FMHA kernel
+    q_input = q_heads.permute(1, 0, 2)  # (n_h, T, hd) — read by both branches below
+    if USE_SDPA:
+        # PyTorch scaled_dot_product_attention
+        # q: (n_h, T, hd), k: (1, seq_len, hd), v: (1, seq_len, hd)
+        # Need to expand K/V for GQA: (1, seq_len, hd) → (n_h, seq_len, hd)
+        k_expanded = k_full.expand(n_h, -1, -1)  # (n_h, seq_len, hd)
+        v_expanded = v_full.expand(n_h, -1, -1)
+        scale = 1.0 / math.sqrt(hd)
+        # For decode (T=1), use SDPA with is_causal=False (no causal mask needed)
+        attn_out = torch.nn.functional.scaled_dot_product_attention(
+            q_input, k_expanded, v_expanded,
+            scale=scale, is_causal=False)  # (n_h, T, hd)
+    else:
+        from dsv4.kernels.attention.production import dsv4_attention
+        attn_out = dsv4_attention(q_input, k_full, v_full)  # (n_h, T, hd)
     attn_out = attn_out.permute(1, 0, 2)  # (T, n_h, hd)
 
     # -- Inverse RoPE on attention output (paper §2.3.3) --