diff --git a/dsv4/kernels/attention/fmha_multitile_capi.cu b/dsv4/kernels/attention/fmha_multitile_capi.cu index cc3fdeed..f78b0b7e 100644 --- a/dsv4/kernels/attention/fmha_multitile_capi.cu +++ b/dsv4/kernels/attention/fmha_multitile_capi.cu @@ -26,7 +26,7 @@ int fmha_multitile_decode_launch( const void* v_ptr, void* o_ptr, void* lse_ptr, - int batch, int n_h, int T, int N, int hd, + int batch, int n_h, int T, int N_orig, int N_padded, int hd, int q_head_stride, int q_batch_stride, int k_head_stride, int k_batch_stride, int v_head_stride, int v_batch_stride, @@ -34,6 +34,10 @@ int fmha_multitile_decode_launch( int lse_head_stride, int lse_batch_stride, float scale ) { + // N_orig: logical KV length (used for softmax masking in kernel) + // N_padded: physical KV length (used for TMA descriptor creation) + // When N_orig < N_padded, the extra rows are zero-padded and + // correctly excluded from softmax by the kernel's col < kv_len guard. size_t desc_count = n_h * batch; CUtensorMap* d_tma_k; @@ -47,16 +51,16 @@ int fmha_multitile_decode_launch( const bf16_t* v_head = (const bf16_t*)v_ptr + h * v_head_stride + b * v_batch_stride; int idx = b * n_h + h; - // K: (N, hd), TMA tile (128, 16) + // K: (N_padded, hd), TMA tile (128, 16) — use physical size for TMA CUtensorMap h_desc; - if (!create_tma_desc_2d_bf16(&h_desc, k_head, N, hd, 128, 16)) { + if (!create_tma_desc_2d_bf16(&h_desc, k_head, N_padded, hd, 128, 16)) { cudaFree(d_tma_k); cudaFree(d_tma_v); return -1; } cudaMemcpy(d_tma_k + idx, &h_desc, sizeof(CUtensorMap), cudaMemcpyHostToDevice); - // V: (hd, N), TMA tile (16, 16) - if (!create_tma_desc_2d_bf16(&h_desc, v_head, hd, N, 16, 16)) { + // V: (hd, N_padded), TMA tile (16, 16) — use physical size for TMA + if (!create_tma_desc_2d_bf16(&h_desc, v_head, hd, N_padded, 16, 16)) { cudaFree(d_tma_k); cudaFree(d_tma_v); return -1; } @@ -70,7 +74,7 @@ int fmha_multitile_decode_launch( params.tma_v = d_tma_v; params.o = (bf16_t*)o_ptr; params.lse = (float*)lse_ptr; - params.s_k = N; + params.s_k = N_orig; // Logical KV length — kernel uses this for softmax masking params.T = T; params.n_h = n_h; params.scale = scale; diff --git a/dsv4/kernels/attention/fmha_multitile_op.py b/dsv4/kernels/attention/fmha_multitile_op.py index f0e6c9b5..f67d319d 100644 --- a/dsv4/kernels/attention/fmha_multitile_op.py +++ b/dsv4/kernels/attention/fmha_multitile_op.py @@ -100,13 +100,17 @@ def fmha_multitile_decode_raw( k = k.repeat_interleave(q_per_kv, dim=1) v = v.repeat_interleave(q_per_kv, dim=1) - # Pad N to multiple of 128 + # Pad N to multiple of 128 (TMA descriptor alignment) + # CRITICAL: We track the ORIGINAL N (N_orig) separately from N_padded. + # The kernel uses s_k=N_orig as the logical KV length for softmax masking. + # Only the K/V tensors are padded (with zeros) for TMA alignment. + N_orig = N N_padded = ((N + 127) // 128) * 128 if N < N_padded: pad = N_padded - N k = torch.cat([k, torch.zeros(B, k.shape[1], pad, hd, dtype=torch.bfloat16, device=k.device)], dim=2) v = torch.cat([v, torch.zeros(v.shape[0], v.shape[1], hd, pad, dtype=torch.bfloat16, device=v.device)], dim=3) - N = N_padded + N = N_padded # N is now the physical size (padded) k = k.contiguous() v = v.contiguous() @@ -121,7 +125,10 @@ def fmha_multitile_decode_raw( ctypes.c_void_p(v.data_ptr()), ctypes.c_void_p(o.data_ptr()), ctypes.c_void_p(lse.data_ptr()), - ctypes.c_int(B), ctypes.c_int(n_h), ctypes.c_int(T), ctypes.c_int(N), ctypes.c_int(hd), + ctypes.c_int(B), ctypes.c_int(n_h), ctypes.c_int(T), + ctypes.c_int(N_orig), # s_k: logical KV length (for softmax masking) + ctypes.c_int(N_padded), # N_padded: physical KV length (for TMA descriptors) + ctypes.c_int(hd), ctypes.c_int(q.stride(1)), ctypes.c_int(q.stride(0)), ctypes.c_int(k.stride(1)), ctypes.c_int(k.stride(0)), ctypes.c_int(v.stride(1)), ctypes.c_int(v.stride(0)), diff --git a/single_shot_inference.py b/single_shot_inference.py index db593cb9..1e056759 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -372,53 +372,27 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w q_heads: (T, n_h, hd), all_kv: (seq_len, hd) Returns: (T, n_h, hd) BF16 - KERNEL LIMITATION: The 6-warp TMA FMHA kernel pads N to 128. - When seq_len < 128, the zero-padded entries dilute the softmax - (e.g. seq_len=1 gives softmax over 128, 127 of which are zero, - reducing max attention weight from 1.0 to 1/128). This must be - fixed in the kernel (skip zero-padded entries in softmax). Until - then, we use PyTorch scaled_dot_product_attention for short - sequences where the padding would dominate. - - TODO: Fix FMHA kernel to handle N < 128 correctly (mask padded - entries from softmax). This is a kernel bug, not a design choice. + The 6-warp TMA FMHA kernel correctly handles N < 128: + K/V are padded to 128 for TMA alignment, but the kernel receives + the true s_k and masks padded entries in softmax (col < kv_len guard). + Fixed in fmha_multitile_capi.cu: N_orig (logical) vs N_padded (physical). """ - FMHA_MIN_SEQ = 128 # Minimum seq_len for correct FMHA kernel output + from dsv4.kernels.attention.production import dsv4_attention + + q = q_heads.permute(1, 0, 2).contiguous() # (n_h, T, hd) + k = all_kv.unsqueeze(0).contiguous() # (1, seq_len, hd) — MQA + v = all_kv.unsqueeze(0).contiguous() - # Sinks: per-head logit bias sinks = w.get(f"{pfx}.sinks") sink_bias = None if sinks is not None: sink_bias = sinks.to(device=dev).float().reshape(n_h) - if seq_len >= FMHA_MIN_SEQ: - # Production path: 6-warp TMA multi-tile kernel - from dsv4.kernels.attention.production import dsv4_attention - q = q_heads.permute(1, 0, 2).contiguous() # (n_h, T, hd) - k = all_kv.unsqueeze(0).contiguous() # (1, seq_len, hd) — MQA - v = all_kv.unsqueeze(0).contiguous() - attn_out = dsv4_attention( - q=q, k=k, v=v, scale=scale, - n_comp=0, sink_bias=sink_bias, - ) # (n_h, T, hd) - return attn_out.permute(1, 0, 2) # (T, n_h, hd) - else: - # Short-sequence path: PyTorch SDPA - # TODO: Replace with fixed FMHA kernel once softmax padding is handled - k_exp = all_kv.unsqueeze(0).expand(n_h, -1, -1).contiguous() # (n_h, seq_len, hd) - v_exp = k_exp.clone() - q_in = q_heads.permute(1, 0, 2) # (n_h, T, hd) - scores = torch.matmul(q_in, k_exp.transpose(-1, -2)) * scale - if sink_bias is not None: - sink_logits = sink_bias.reshape(n_h, 1, 1).expand(-1, T, 1) - combined = torch.cat([scores, sink_logits], dim=-1) - combined = combined - combined.max(-1, keepdim=True).values - probs = torch.softmax(combined.float(), -1).bfloat16() - attn_w = probs[..., :-1] - else: - attn_w = torch.softmax(scores.float(), -1).bfloat16() - attn_out = torch.matmul(attn_w, v_exp).permute(1, 0, 2) # (T, n_h, hd) - return attn_out + attn_out = dsv4_attention( + q=q, k=k, v=v, scale=scale, + n_comp=0, sink_bias=sink_bias, + ) # (n_h, T, hd) + return attn_out.permute(1, 0, 2) # (T, n_h, hd) # =====================================================================