fix: define q_input before USE_SDPA branch
This commit is contained in:
@@ -486,6 +486,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
k_full, v_full = kv_cache.get() # (1, seq_len, hd) each — RoPE'd, K=V
|
||||
|
||||
# -- FMHA: (n_h, T, hd) × (1, seq_len, hd) → (n_h, T, hd) --
|
||||
q_input = q_heads.permute(1, 0, 2) # (n_h, T, hd)
|
||||
# Use PyTorch SDPA for correctness verification (production FMHA has no sink bias)
|
||||
# dsv4_attention can be swapped back once sink bias is integrated into the kernel
|
||||
USE_SDPA = True # Set False to use production FMHA kernel
|
||||
@@ -496,7 +497,6 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
k_expanded = k_full.expand(n_h, -1, -1) # (n_h, seq_len, hd)
|
||||
v_expanded = v_full.expand(n_h, -1, -1)
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
# For decode (T=1), use SDPA with is_causal=False (no causal mask needed)
|
||||
attn_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_input, k_expanded, v_expanded,
|
||||
scale=scale, is_causal=False) # (n_h, T, hd)
|
||||
|
||||
Reference in New Issue
Block a user