fix: define q_input before USE_SDPA branch

This commit is contained in:
2026-05-31 03:45:09 +00:00
parent cd073ad867
commit 1905f19b8d

View File

@@ -486,6 +486,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
k_full, v_full = kv_cache.get() # (1, seq_len, hd) each — RoPE'd, K=V
# -- FMHA: (n_h, T, hd) × (1, seq_len, hd) → (n_h, T, hd) --
q_input = q_heads.permute(1, 0, 2) # (n_h, T, hd)
# Use PyTorch SDPA for correctness verification (production FMHA has no sink bias)
# dsv4_attention can be swapped back once sink bias is integrated into the kernel
USE_SDPA = True # Set False to use production FMHA kernel
@@ -496,7 +497,6 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
k_expanded = k_full.expand(n_h, -1, -1) # (n_h, seq_len, hd)
v_expanded = v_full.expand(n_h, -1, -1)
scale = 1.0 / math.sqrt(hd)
# For decode (T=1), use SDPA with is_causal=False (no causal mask needed)
attn_out = torch.nn.functional.scaled_dot_product_attention(
q_input, k_expanded, v_expanded,
scale=scale, is_causal=False) # (n_h, T, hd)