FMHA kernel fix: N_orig vs N_padded — correct softmax masking for seq_len < 128

ROOT CAUSE: fmha_multitile_op.py padded N to 128 for TMA alignment
but then passed the PADDED N to the kernel as s_k (logical KV length).
This told the kernel all 128 entries were valid, so softmax ran over
zeros, diluting the result (e.g. 1 valid entry → softmax weight 1/128).

FIX: Pass N_orig (true sequence length) as s_k for softmax masking,
and N_padded (physical size) only for TMA descriptor creation.
The kernel's existing col < kv_len guard correctly excludes padded
entries from row_max and exp_sum calculations.

Files changed:
- fmha_multitile_capi.cu: accept N_orig + N_padded, use N_orig for
  params.s_k and N_padded for TMA descriptors
- fmha_multitile_op.py: pass N_orig and N_padded separately
- single_shot_inference.py: removed SDPA fallback (kernel now correct)
This commit is contained in:
2026-05-31 22:52:39 +00:00
parent d40821c843
commit 92200367f3
3 changed files with 34 additions and 49 deletions

View File

@@ -26,7 +26,7 @@ int fmha_multitile_decode_launch(
const void* v_ptr,
void* o_ptr,
void* lse_ptr,
int batch, int n_h, int T, int N, int hd,
int batch, int n_h, int T, int N_orig, int N_padded, int hd,
int q_head_stride, int q_batch_stride,
int k_head_stride, int k_batch_stride,
int v_head_stride, int v_batch_stride,
@@ -34,6 +34,10 @@ int fmha_multitile_decode_launch(
int lse_head_stride, int lse_batch_stride,
float scale
) {
// N_orig: logical KV length (used for softmax masking in kernel)
// N_padded: physical KV length (used for TMA descriptor creation)
// When N_orig < N_padded, the extra rows are zero-padded and
// correctly excluded from softmax by the kernel's col < kv_len guard.
size_t desc_count = n_h * batch;
CUtensorMap* d_tma_k;
@@ -47,16 +51,16 @@ int fmha_multitile_decode_launch(
const bf16_t* v_head = (const bf16_t*)v_ptr + h * v_head_stride + b * v_batch_stride;
int idx = b * n_h + h;
// K: (N, hd), TMA tile (128, 16)
// K: (N_padded, hd), TMA tile (128, 16) — use physical size for TMA
CUtensorMap h_desc;
if (!create_tma_desc_2d_bf16(&h_desc, k_head, N, hd, 128, 16)) {
if (!create_tma_desc_2d_bf16(&h_desc, k_head, N_padded, hd, 128, 16)) {
cudaFree(d_tma_k); cudaFree(d_tma_v);
return -1;
}
cudaMemcpy(d_tma_k + idx, &h_desc, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
// V: (hd, N), TMA tile (16, 16)
if (!create_tma_desc_2d_bf16(&h_desc, v_head, hd, N, 16, 16)) {
// V: (hd, N_padded), TMA tile (16, 16) — use physical size for TMA
if (!create_tma_desc_2d_bf16(&h_desc, v_head, hd, N_padded, 16, 16)) {
cudaFree(d_tma_k); cudaFree(d_tma_v);
return -1;
}
@@ -70,7 +74,7 @@ int fmha_multitile_decode_launch(
params.tma_v = d_tma_v;
params.o = (bf16_t*)o_ptr;
params.lse = (float*)lse_ptr;
params.s_k = N;
params.s_k = N_orig; // Logical KV length — kernel uses this for softmax masking
params.T = T;
params.n_h = n_h;
params.scale = scale;

View File

@@ -100,13 +100,17 @@ def fmha_multitile_decode_raw(
k = k.repeat_interleave(q_per_kv, dim=1)
v = v.repeat_interleave(q_per_kv, dim=1)
# Pad N to multiple of 128
# Pad N to multiple of 128 (TMA descriptor alignment)
# CRITICAL: We track the ORIGINAL N (N_orig) separately from N_padded.
# The kernel uses s_k=N_orig as the logical KV length for softmax masking.
# Only the K/V tensors are padded (with zeros) for TMA alignment.
N_orig = N
N_padded = ((N + 127) // 128) * 128
if N < N_padded:
pad = N_padded - N
k = torch.cat([k, torch.zeros(B, k.shape[1], pad, hd, dtype=torch.bfloat16, device=k.device)], dim=2)
v = torch.cat([v, torch.zeros(v.shape[0], v.shape[1], hd, pad, dtype=torch.bfloat16, device=v.device)], dim=3)
N = N_padded
N = N_padded # N is now the physical size (padded)
k = k.contiguous()
v = v.contiguous()
@@ -121,7 +125,10 @@ def fmha_multitile_decode_raw(
ctypes.c_void_p(v.data_ptr()),
ctypes.c_void_p(o.data_ptr()),
ctypes.c_void_p(lse.data_ptr()),
ctypes.c_int(B), ctypes.c_int(n_h), ctypes.c_int(T), ctypes.c_int(N), ctypes.c_int(hd),
ctypes.c_int(B), ctypes.c_int(n_h), ctypes.c_int(T),
ctypes.c_int(N_orig), # s_k: logical KV length (for softmax masking)
ctypes.c_int(N_padded), # N_padded: physical KV length (for TMA descriptors)
ctypes.c_int(hd),
ctypes.c_int(q.stride(1)), ctypes.c_int(q.stride(0)),
ctypes.c_int(k.stride(1)), ctypes.c_int(k.stride(0)),
ctypes.c_int(v.stride(1)), ctypes.c_int(v.stride(0)),

View File

@@ -372,53 +372,27 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w
q_heads: (T, n_h, hd), all_kv: (seq_len, hd)
Returns: (T, n_h, hd) BF16
KERNEL LIMITATION: The 6-warp TMA FMHA kernel pads N to 128.
When seq_len < 128, the zero-padded entries dilute the softmax
(e.g. seq_len=1 gives softmax over 128, 127 of which are zero,
reducing max attention weight from 1.0 to 1/128). This must be
fixed in the kernel (skip zero-padded entries in softmax). Until
then, we use PyTorch scaled_dot_product_attention for short
sequences where the padding would dominate.
TODO: Fix FMHA kernel to handle N < 128 correctly (mask padded
entries from softmax). This is a kernel bug, not a design choice.
The 6-warp TMA FMHA kernel correctly handles N < 128:
K/V are padded to 128 for TMA alignment, but the kernel receives
the true s_k and masks padded entries in softmax (col < kv_len guard).
Fixed in fmha_multitile_capi.cu: N_orig (logical) vs N_padded (physical).
"""
FMHA_MIN_SEQ = 128 # Minimum seq_len for correct FMHA kernel output
from dsv4.kernels.attention.production import dsv4_attention
q = q_heads.permute(1, 0, 2).contiguous() # (n_h, T, hd)
k = all_kv.unsqueeze(0).contiguous() # (1, seq_len, hd) — MQA
v = all_kv.unsqueeze(0).contiguous()
# Sinks: per-head logit bias
sinks = w.get(f"{pfx}.sinks")
sink_bias = None
if sinks is not None:
sink_bias = sinks.to(device=dev).float().reshape(n_h)
if seq_len >= FMHA_MIN_SEQ:
# Production path: 6-warp TMA multi-tile kernel
from dsv4.kernels.attention.production import dsv4_attention
q = q_heads.permute(1, 0, 2).contiguous() # (n_h, T, hd)
k = all_kv.unsqueeze(0).contiguous() # (1, seq_len, hd) — MQA
v = all_kv.unsqueeze(0).contiguous()
attn_out = dsv4_attention(
q=q, k=k, v=v, scale=scale,
n_comp=0, sink_bias=sink_bias,
) # (n_h, T, hd)
return attn_out.permute(1, 0, 2) # (T, n_h, hd)
else:
# Short-sequence path: PyTorch SDPA
# TODO: Replace with fixed FMHA kernel once softmax padding is handled
k_exp = all_kv.unsqueeze(0).expand(n_h, -1, -1).contiguous() # (n_h, seq_len, hd)
v_exp = k_exp.clone()
q_in = q_heads.permute(1, 0, 2) # (n_h, T, hd)
scores = torch.matmul(q_in, k_exp.transpose(-1, -2)) * scale
if sink_bias is not None:
sink_logits = sink_bias.reshape(n_h, 1, 1).expand(-1, T, 1)
combined = torch.cat([scores, sink_logits], dim=-1)
combined = combined - combined.max(-1, keepdim=True).values
probs = torch.softmax(combined.float(), -1).bfloat16()
attn_w = probs[..., :-1]
else:
attn_w = torch.softmax(scores.float(), -1).bfloat16()
attn_out = torch.matmul(attn_w, v_exp).permute(1, 0, 2) # (T, n_h, hd)
return attn_out
attn_out = dsv4_attention(
q=q, k=k, v=v, scale=scale,
n_comp=0, sink_bias=sink_bias,
) # (n_h, T, hd)
return attn_out.permute(1, 0, 2) # (T, n_h, hd)
# =====================================================================