use PyTorch SDPA for correctness (no sink bias in FMHA kernel yet)
This commit is contained in:
@@ -486,10 +486,23 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
k_full, v_full = kv_cache.get() # (1, seq_len, hd) each — RoPE'd, K=V
|
||||
|
||||
# -- FMHA: (n_h, T, hd) × (1, seq_len, hd) → (n_h, T, hd) --
|
||||
from dsv4.kernels.attention.production import dsv4_attention
|
||||
q_input = q_heads.permute(1, 0, 2) # (n_h, T, hd)
|
||||
# k_full, v_full are (1, seq_len, hd) — already in (n_kv, N, hd) format
|
||||
attn_out = dsv4_attention(q_input, k_full, v_full) # (n_h, T, hd)
|
||||
# Use PyTorch SDPA for correctness verification (production FMHA has no sink bias)
|
||||
# dsv4_attention can be swapped back once sink bias is integrated into the kernel
|
||||
USE_SDPA = True # Set False to use production FMHA kernel
|
||||
if USE_SDPA:
|
||||
# PyTorch scaled_dot_product_attention
|
||||
# q: (n_h, T, hd), k: (1, seq_len, hd), v: (1, seq_len, hd)
|
||||
# Need to expand K/V for GQA: (1, seq_len, hd) → (n_h, seq_len, hd)
|
||||
k_expanded = k_full.expand(n_h, -1, -1) # (n_h, seq_len, hd)
|
||||
v_expanded = v_full.expand(n_h, -1, -1)
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
# For decode (T=1), use SDPA with is_causal=False (no causal mask needed)
|
||||
attn_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_input, k_expanded, v_expanded,
|
||||
scale=scale, is_causal=False) # (n_h, T, hd)
|
||||
else:
|
||||
from dsv4.kernels.attention.production import dsv4_attention
|
||||
attn_out = dsv4_attention(q_input, k_full, v_full) # (n_h, T, hd)
|
||||
attn_out = attn_out.permute(1, 0, 2) # (T, n_h, hd)
|
||||
|
||||
# -- Inverse RoPE on attention output (paper §2.3.3) --
|
||||
|
||||
Reference in New Issue
Block a user