FMHA kernel fix: N_orig vs N_padded — correct softmax masking for seq_len < 128
ROOT CAUSE: fmha_multitile_op.py padded N to 128 for TMA alignment but then passed the PADDED N to the kernel as s_k (logical KV length). This told the kernel all 128 entries were valid, so softmax ran over zeros, diluting the result (e.g. 1 valid entry → softmax weight 1/128). FIX: Pass N_orig (true sequence length) as s_k for softmax masking, and N_padded (physical size) only for TMA descriptor creation. The kernel's existing col < kv_len guard correctly excludes padded entries from row_max and exp_sum calculations. Files changed: - fmha_multitile_capi.cu: accept N_orig + N_padded, use N_orig for params.s_k and N_padded for TMA descriptors - fmha_multitile_op.py: pass N_orig and N_padded separately - single_shot_inference.py: removed SDPA fallback (kernel now correct)
This commit is contained in:
@@ -26,7 +26,7 @@ int fmha_multitile_decode_launch(
|
||||
const void* v_ptr,
|
||||
void* o_ptr,
|
||||
void* lse_ptr,
|
||||
int batch, int n_h, int T, int N, int hd,
|
||||
int batch, int n_h, int T, int N_orig, int N_padded, int hd,
|
||||
int q_head_stride, int q_batch_stride,
|
||||
int k_head_stride, int k_batch_stride,
|
||||
int v_head_stride, int v_batch_stride,
|
||||
@@ -34,6 +34,10 @@ int fmha_multitile_decode_launch(
|
||||
int lse_head_stride, int lse_batch_stride,
|
||||
float scale
|
||||
) {
|
||||
// N_orig: logical KV length (used for softmax masking in kernel)
|
||||
// N_padded: physical KV length (used for TMA descriptor creation)
|
||||
// When N_orig < N_padded, the extra rows are zero-padded and
|
||||
// correctly excluded from softmax by the kernel's col < kv_len guard.
|
||||
size_t desc_count = n_h * batch;
|
||||
|
||||
CUtensorMap* d_tma_k;
|
||||
@@ -47,16 +51,16 @@ int fmha_multitile_decode_launch(
|
||||
const bf16_t* v_head = (const bf16_t*)v_ptr + h * v_head_stride + b * v_batch_stride;
|
||||
int idx = b * n_h + h;
|
||||
|
||||
// K: (N, hd), TMA tile (128, 16)
|
||||
// K: (N_padded, hd), TMA tile (128, 16) — use physical size for TMA
|
||||
CUtensorMap h_desc;
|
||||
if (!create_tma_desc_2d_bf16(&h_desc, k_head, N, hd, 128, 16)) {
|
||||
if (!create_tma_desc_2d_bf16(&h_desc, k_head, N_padded, hd, 128, 16)) {
|
||||
cudaFree(d_tma_k); cudaFree(d_tma_v);
|
||||
return -1;
|
||||
}
|
||||
cudaMemcpy(d_tma_k + idx, &h_desc, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
|
||||
|
||||
// V: (hd, N), TMA tile (16, 16)
|
||||
if (!create_tma_desc_2d_bf16(&h_desc, v_head, hd, N, 16, 16)) {
|
||||
// V: (hd, N_padded), TMA tile (16, 16) — use physical size for TMA
|
||||
if (!create_tma_desc_2d_bf16(&h_desc, v_head, hd, N_padded, 16, 16)) {
|
||||
cudaFree(d_tma_k); cudaFree(d_tma_v);
|
||||
return -1;
|
||||
}
|
||||
@@ -70,7 +74,7 @@ int fmha_multitile_decode_launch(
|
||||
params.tma_v = d_tma_v;
|
||||
params.o = (bf16_t*)o_ptr;
|
||||
params.lse = (float*)lse_ptr;
|
||||
params.s_k = N;
|
||||
params.s_k = N_orig; // Logical KV length — kernel uses this for softmax masking
|
||||
params.T = T;
|
||||
params.n_h = n_h;
|
||||
params.scale = scale;
|
||||
|
||||
@@ -100,13 +100,17 @@ def fmha_multitile_decode_raw(
|
||||
k = k.repeat_interleave(q_per_kv, dim=1)
|
||||
v = v.repeat_interleave(q_per_kv, dim=1)
|
||||
|
||||
# Pad N to multiple of 128
|
||||
# Pad N to multiple of 128 (TMA descriptor alignment)
|
||||
# CRITICAL: We track the ORIGINAL N (N_orig) separately from N_padded.
|
||||
# The kernel uses s_k=N_orig as the logical KV length for softmax masking.
|
||||
# Only the K/V tensors are padded (with zeros) for TMA alignment.
|
||||
N_orig = N
|
||||
N_padded = ((N + 127) // 128) * 128
|
||||
if N < N_padded:
|
||||
pad = N_padded - N
|
||||
k = torch.cat([k, torch.zeros(B, k.shape[1], pad, hd, dtype=torch.bfloat16, device=k.device)], dim=2)
|
||||
v = torch.cat([v, torch.zeros(v.shape[0], v.shape[1], hd, pad, dtype=torch.bfloat16, device=v.device)], dim=3)
|
||||
N = N_padded
|
||||
N = N_padded # N is now the physical size (padded)
|
||||
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
@@ -121,7 +125,10 @@ def fmha_multitile_decode_raw(
|
||||
ctypes.c_void_p(v.data_ptr()),
|
||||
ctypes.c_void_p(o.data_ptr()),
|
||||
ctypes.c_void_p(lse.data_ptr()),
|
||||
ctypes.c_int(B), ctypes.c_int(n_h), ctypes.c_int(T), ctypes.c_int(N), ctypes.c_int(hd),
|
||||
ctypes.c_int(B), ctypes.c_int(n_h), ctypes.c_int(T),
|
||||
ctypes.c_int(N_orig), # s_k: logical KV length (for softmax masking)
|
||||
ctypes.c_int(N_padded), # N_padded: physical KV length (for TMA descriptors)
|
||||
ctypes.c_int(hd),
|
||||
ctypes.c_int(q.stride(1)), ctypes.c_int(q.stride(0)),
|
||||
ctypes.c_int(k.stride(1)), ctypes.c_int(k.stride(0)),
|
||||
ctypes.c_int(v.stride(1)), ctypes.c_int(v.stride(0)),
|
||||
|
||||
@@ -372,53 +372,27 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w
|
||||
q_heads: (T, n_h, hd), all_kv: (seq_len, hd)
|
||||
Returns: (T, n_h, hd) BF16
|
||||
|
||||
KERNEL LIMITATION: The 6-warp TMA FMHA kernel pads N to 128.
|
||||
When seq_len < 128, the zero-padded entries dilute the softmax
|
||||
(e.g. seq_len=1 gives softmax over 128, 127 of which are zero,
|
||||
reducing max attention weight from 1.0 to 1/128). This must be
|
||||
fixed in the kernel (skip zero-padded entries in softmax). Until
|
||||
then, we use PyTorch scaled_dot_product_attention for short
|
||||
sequences where the padding would dominate.
|
||||
|
||||
TODO: Fix FMHA kernel to handle N < 128 correctly (mask padded
|
||||
entries from softmax). This is a kernel bug, not a design choice.
|
||||
The 6-warp TMA FMHA kernel correctly handles N < 128:
|
||||
K/V are padded to 128 for TMA alignment, but the kernel receives
|
||||
the true s_k and masks padded entries in softmax (col < kv_len guard).
|
||||
Fixed in fmha_multitile_capi.cu: N_orig (logical) vs N_padded (physical).
|
||||
"""
|
||||
FMHA_MIN_SEQ = 128 # Minimum seq_len for correct FMHA kernel output
|
||||
from dsv4.kernels.attention.production import dsv4_attention
|
||||
|
||||
q = q_heads.permute(1, 0, 2).contiguous() # (n_h, T, hd)
|
||||
k = all_kv.unsqueeze(0).contiguous() # (1, seq_len, hd) — MQA
|
||||
v = all_kv.unsqueeze(0).contiguous()
|
||||
|
||||
# Sinks: per-head logit bias
|
||||
sinks = w.get(f"{pfx}.sinks")
|
||||
sink_bias = None
|
||||
if sinks is not None:
|
||||
sink_bias = sinks.to(device=dev).float().reshape(n_h)
|
||||
|
||||
if seq_len >= FMHA_MIN_SEQ:
|
||||
# Production path: 6-warp TMA multi-tile kernel
|
||||
from dsv4.kernels.attention.production import dsv4_attention
|
||||
q = q_heads.permute(1, 0, 2).contiguous() # (n_h, T, hd)
|
||||
k = all_kv.unsqueeze(0).contiguous() # (1, seq_len, hd) — MQA
|
||||
v = all_kv.unsqueeze(0).contiguous()
|
||||
attn_out = dsv4_attention(
|
||||
q=q, k=k, v=v, scale=scale,
|
||||
n_comp=0, sink_bias=sink_bias,
|
||||
) # (n_h, T, hd)
|
||||
return attn_out.permute(1, 0, 2) # (T, n_h, hd)
|
||||
else:
|
||||
# Short-sequence path: PyTorch SDPA
|
||||
# TODO: Replace with fixed FMHA kernel once softmax padding is handled
|
||||
k_exp = all_kv.unsqueeze(0).expand(n_h, -1, -1).contiguous() # (n_h, seq_len, hd)
|
||||
v_exp = k_exp.clone()
|
||||
q_in = q_heads.permute(1, 0, 2) # (n_h, T, hd)
|
||||
scores = torch.matmul(q_in, k_exp.transpose(-1, -2)) * scale
|
||||
if sink_bias is not None:
|
||||
sink_logits = sink_bias.reshape(n_h, 1, 1).expand(-1, T, 1)
|
||||
combined = torch.cat([scores, sink_logits], dim=-1)
|
||||
combined = combined - combined.max(-1, keepdim=True).values
|
||||
probs = torch.softmax(combined.float(), -1).bfloat16()
|
||||
attn_w = probs[..., :-1]
|
||||
else:
|
||||
attn_w = torch.softmax(scores.float(), -1).bfloat16()
|
||||
attn_out = torch.matmul(attn_w, v_exp).permute(1, 0, 2) # (T, n_h, hd)
|
||||
return attn_out
|
||||
attn_out = dsv4_attention(
|
||||
q=q, k=k, v=v, scale=scale,
|
||||
n_comp=0, sink_bias=sink_bias,
|
||||
) # (n_h, T, hd)
|
||||
return attn_out.permute(1, 0, 2) # (T, n_h, hd)
|
||||
|
||||
|
||||
# =====================================================================
|
||||
|
||||
Reference in New Issue
Block a user