Clean production wrapper: always normalize=False + KV merge

This commit is contained in:
2026-05-27 06:51:14 +00:00
parent 8f87109f86
commit fc4172937c

View File

@@ -3,14 +3,14 @@
Wraps the CuTeDSL FMHA kernel with Python KV merge for multi-KV-tile.
Replaces vLLM's broken FlashMLA Blackwell implementation.
Supported:
Supported attention types:
- CSA (Classical Self-Attention): 32 KV entries/block, full KV cache
- HCA (Hash-based Cross-Attention): 1 KV entry/block, compressed cache
- SWA (Sliding Window Attention): ring buffer, no KV cache
Unsupported (future):
- SMEM accumulator for in-kernel multi-KV-tile (currently uses Python KV merge)
- head_dim > 256 (MLIR compilation hang at hd=512)
Limitations:
- head_dim > 256: MLIR compilation hang (known CuTeDSL issue)
- In-kernel multi-KV-tile: blocked on TMA layout matching (uses Python KV merge)
"""
import torch
import math
@@ -18,9 +18,8 @@ import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
from dsv4.model.config import DSV4Config
# Kernel cache: compiled kernels keyed by (head_dim, s_k, use_smem_p, normalize, ...)
# Kernel cache: compiled kernels keyed by config tuple
_kernel_cache: dict = {}
@@ -28,11 +27,7 @@ def _get_or_compile_kernel(head_dim: int, s_k: int, use_smem_p: bool = False,
normalize: bool = False, apply_swa_mask: bool = False,
is_causal: bool = False, n_comp: int = 0,
apply_sink_bias: bool = False) -> tuple:
"""Get or compile a kernel for the given configuration.
Returns (compiled_kernel, FmhaKernel instance).
Kernel compilation is expensive, so we cache by config.
"""
"""Get or compile a kernel for the given configuration."""
key = (head_dim, s_k, use_smem_p, normalize, apply_swa_mask, is_causal, n_comp, apply_sink_bias)
if key in _kernel_cache:
return _kernel_cache[key]
@@ -43,7 +38,6 @@ def _get_or_compile_kernel(head_dim: int, s_k: int, use_smem_p: bool = False,
n_comp=n_comp, apply_sink_bias=apply_sink_bias,
)
# Create dummy tensors for compilation
m = 128
pv_n_tile = kernel.pv_n_tile
q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
@@ -51,7 +45,6 @@ def _get_or_compile_kernel(head_dim: int, s_k: int, use_smem_p: bool = False,
v = torch.randn(s_k, pv_n_tile, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
lse = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
row_sums = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
@@ -61,7 +54,6 @@ def _get_or_compile_kernel(head_dim: int, s_k: int, use_smem_p: bool = False,
mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
mLSE = ct.from_dlpack(lse).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse))
mRS = ct.from_dlpack(row_sums).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums))
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, lse=mLSE)
_kernel_cache[key] = (compiled, kernel)
@@ -72,13 +64,13 @@ def dsv4_attention(
q: torch.Tensor, # (num_heads, seq_len, head_dim)
k: torch.Tensor, # (seq_len, head_dim) or (num_heads, seq_len, head_dim) for MQA
v: torch.Tensor, # (seq_len, head_dim) or (num_heads, seq_len, head_dim) for MQA
scale: float = None, # 1/sqrt(head_dim) if None
swa_len: int = None, # sliding window length (None = no SWA mask)
scale: float = None,
swa_len: int = None,
is_causal: bool = False,
n_comp: int = 0, # compressed KV length for D5c sink bias
sink_bias: torch.Tensor = None, # per-head logit bias for D5c
n_comp: int = 0,
sink_bias: torch.Tensor = None,
) -> torch.Tensor:
"""Production DSV4 attention using CuTeDSL FMHA kernel.
"""Production DSV4 attention using CuTeDSL FMHA kernel + Python KV merge.
Args:
q: Query tensor (num_heads, seq_len, head_dim) BF16
@@ -94,23 +86,18 @@ def dsv4_attention(
Output tensor (num_heads, seq_len, head_dim) BF16
"""
n_h, T, hd = q.shape
N = k.shape[-2] # KV sequence length
N = k.shape[-2]
scale = scale or (1.0 / math.sqrt(hd))
use_smem_p = hd > 64
apply_swa_mask = swa_len is not None
apply_sink_bias = sink_bias is not None
# Per-head launch: one kernel call per head
output = torch.zeros_like(q)
output = torch.zeros(n_h, T, hd, dtype=torch.bfloat16, device='cuda')
for h in range(n_h):
q_h = q[h:h+1] # (1, T, hd)
k_h = k if k.dim() == 2 else k[h] # shared K (MQA) or per-head
if k_h.dim() == 2:
k_h = k_h.unsqueeze(0) # (1, N, hd)
v_h = v if v.dim() == 2 else v[h]
if v_h.dim() == 2:
v_h = v_h.unsqueeze(0)
k_h = k if k.dim() == 2 else k[h:h+1] # (1, N, hd)
v_h = v if v.dim() == 2 else v[h:h+1]
o_h = _attention_single_head(
q_h, k_h, v_h, scale=scale,
@@ -123,95 +110,6 @@ def dsv4_attention(
return output
def _attention_single_head_normalized(
q: torch.Tensor, # (1, T, hd)
k: torch.Tensor, # (1, N, hd)
v: torch.Tensor, # (1, N, hd)
scale: float,
swa_len: int = None,
is_causal: bool = False,
n_comp: int = 0,
sink_bias: torch.Tensor = None,
use_smem_p: bool = False,
) -> torch.Tensor:
"""Run FMHA for a single head with Python normalization (single KV tile)."""
_, T, hd = q.shape
N = k.shape[1]
apply_swa_mask = swa_len is not None
apply_sink_bias = sink_bias is not None
compiled, kernel = _get_or_compile_kernel(
head_dim=hd, s_k=N, use_smem_p=use_smem_p,
normalize=False, apply_swa_mask=apply_swa_mask,
is_causal=is_causal, n_comp=n_comp if n_comp > 0 else None,
apply_sink_bias=apply_sink_bias,
)
pv_n_tile = kernel.pv_n_tile
n_pv_tiles = kernel.n_pv_tiles
output_unnorm = torch.zeros(T, hd, dtype=torch.float32, device='cuda')
lse_val = None
for nt in range(n_pv_tiles):
v_start = nt * pv_n_tile
v_end = v_start + pv_n_tile
v_tile = v[0, :, v_start:v_end].contiguous()
v_kernel = v_tile.unsqueeze(-1)
c_tile = torch.zeros(T, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
lse_tensor = torch.zeros(T, 1, 1, dtype=torch.float32, device='cuda')
row_sums_tensor = torch.zeros(T, 1, 1, dtype=torch.float32, device='cuda')
q_input = q[0].contiguous().unsqueeze(-1)
k_input = k[0].contiguous().unsqueeze(-1)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
mQ = ct.from_dlpack(q_input).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q_input))
mK = ct.from_dlpack(k_input).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_input))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
mRS = ct.from_dlpack(row_sums_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums_tensor))
compiled(mQ, mK, mV, mC, stream, lse=mLSE, row_sums=mRS)
torch.cuda.synchronize()
output_unnorm[:, v_start:v_end] = c_tile[:, :, 0].float()
if nt == 0:
lse_val = lse_tensor[0, 0, 0].item()
# Normalize: O_norm = O_unnorm / row_sum
# row_sum is computed from lse: row_sum = exp(lse - row_max * ln2) ... complex
# But we have row_sums from the kernel!
# row_sums_tensor[0,0,0] has the row_sum for row 0 only.
# For per-row normalization, we'd need per-row row_sums.
# The kernel outputs row_sums[sfw_idx, 0, 0] for each of 128 rows.
# But we can also compute from the un-normalized O and the reference.
#
# Simpler: compute row_sum from the un-normalized O and the LSE.
# lse = ln(row_sum) + row_max * ln(2)
# exp(lse) = row_sum * exp(row_max * ln(2)) = row_sum * 2^row_max
# O_unnorm = P @ V where P = softmax * row_sum
# Wait, the kernel's P = exp2(S*scale - row_max) which is NOT softmax.
# softmax = P / row_sum
# O_unnorm = P @ V = softmax * row_sum @ V = O_norm * row_sum
# So: O_norm = O_unnorm / row_sum
#
# We need row_sum per row. The kernel outputs it in row_sums_tensor.
# But row_sums_tensor only has values for 128 rows, and we need all T rows.
# Actually T=128, so row_sums_tensor[0:128, 0, 0] has all rows.
#
# Let me extract per-row row_sums.
# But the kernel only writes row_sums[sfw_idx, 0, 0] for sfw_idx 0..127.
# And row_sums_tensor is (T, 1, 1) = (128, 1, 1).
# So row_sums_tensor[:, 0, 0] should have all 128 rows.
# Re-run to get row_sums (we already have them from the last call)
row_sums_per_row = row_sums_tensor[:, 0, 0].float().unsqueeze(1) # (T, 1)
row_sums_per_row = row_sums_per_row.clamp(min=1e-30)
output_norm = output_unnorm / row_sums_per_row
return output_norm.to(torch.bfloat16).unsqueeze(0) # (1, T, hd)
def _attention_single_head(
q: torch.Tensor, # (1, T, hd)
k: torch.Tensor, # (1, N, hd)
@@ -229,24 +127,10 @@ def _attention_single_head(
apply_swa_mask = swa_len is not None
apply_sink_bias = sink_bias is not None
# Reference output
qf = q[0].float() # (T, hd)
kf = k[0].float() # (N, hd)
vf = v[0].float() # (N, hd)
attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0]
attn_exp = torch.exp(qf @ kf.T * scale - attn_max)
attn_sum = attn_exp.sum(dim=-1, keepdim=True)
ref_norm = (attn_exp / attn_sum) @ vf
# Each segment is s_k=128 (one KV tile)
# Segment the KV sequence into 128-entry tiles
s_k_per_seg = 128
n_segments = (N + s_k_per_seg - 1) // s_k_per_seg
if n_segments == 1:
# Single segment: use kernel with normalize=True (in-kernel normalization)
return _attention_single_head_normalized(q, k, v, scale, swa_len, is_causal, n_comp, sink_bias, use_smem_p)
# Multi-segment: use un-normalized output + Python KV merge
compiled, kernel = _get_or_compile_kernel(
head_dim=hd, s_k=s_k_per_seg, use_smem_p=use_smem_p,
normalize=False, apply_swa_mask=apply_swa_mask,
@@ -266,7 +150,7 @@ def _attention_single_head(
k_seg = k[:, k_start:k_end]
v_seg = v[:, k_start:k_end]
# If last segment is shorter than s_k_per_seg, pad
# Pad last segment if shorter than s_k_per_seg
if k_end - k_start < s_k_per_seg:
pad_len = s_k_per_seg - (k_end - k_start)
k_seg = torch.cat([k_seg, torch.zeros(1, pad_len, hd, dtype=k.dtype, device='cuda')], dim=1)
@@ -282,11 +166,9 @@ def _attention_single_head(
v_kernel = v_tile.unsqueeze(-1)
c_tile = torch.zeros(T, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
lse_tensor = torch.zeros(T, 1, 1, dtype=torch.float32, device='cuda')
row_sums_tensor = torch.zeros(T, 1, 1, dtype=torch.float32, device='cuda')
# Prepare CuTe tensors
q_input = q[0].contiguous().unsqueeze(-1) # (T, hd, 1)
k_input = k_seg[0].contiguous().unsqueeze(-1) # (s_k_per_seg, hd, 1)
q_input = q[0].contiguous().unsqueeze(-1)
k_input = k_seg[0].contiguous().unsqueeze(-1)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
mQ = ct.from_dlpack(q_input).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q_input))
@@ -294,7 +176,6 @@ def _attention_single_head(
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
mRS = ct.from_dlpack(row_sums_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums_tensor))
compiled(mQ, mK, mV, mC, stream, lse=mLSE)
torch.cuda.synchronize()
@@ -303,10 +184,7 @@ def _attention_single_head(
if nt == 0:
seg_lse[:, 0] = lse_tensor[:, 0, 0].float()
# Merge with accumulator using log-sum-exp (same as test_d1_kv_merge.py)
# Formula: O = (exp(lse_old) * O_old + exp(lse_new) * O_new) / (exp(lse_old) + exp(lse_new))
# Both O_old and O_new are un-normalized outputs from the kernel.
# This produces the correct normalized result after merge.
# Merge with accumulator using log-sum-exp
e_old = torch.exp(lse_accum)
e_new = torch.exp(seg_lse)
e_sum = e_old + e_new
@@ -314,5 +192,5 @@ def _attention_single_head(
o_accum = (e_old * o_accum + e_new * seg_o) / e_sum
lse_accum = torch.log(e_sum)
output = o_accum.to(torch.bfloat16).unsqueeze(0) # (1, T, hd)
output = o_accum.to(torch.bfloat16).unsqueeze(0)
return output