Clean production wrapper: always normalize=False + KV merge
This commit is contained in:
@@ -3,14 +3,14 @@
|
||||
Wraps the CuTeDSL FMHA kernel with Python KV merge for multi-KV-tile.
|
||||
Replaces vLLM's broken FlashMLA Blackwell implementation.
|
||||
|
||||
Supported:
|
||||
Supported attention types:
|
||||
- CSA (Classical Self-Attention): 32 KV entries/block, full KV cache
|
||||
- HCA (Hash-based Cross-Attention): 1 KV entry/block, compressed cache
|
||||
- SWA (Sliding Window Attention): ring buffer, no KV cache
|
||||
|
||||
Unsupported (future):
|
||||
- SMEM accumulator for in-kernel multi-KV-tile (currently uses Python KV merge)
|
||||
- head_dim > 256 (MLIR compilation hang at hd=512)
|
||||
Limitations:
|
||||
- head_dim > 256: MLIR compilation hang (known CuTeDSL issue)
|
||||
- In-kernel multi-KV-tile: blocked on TMA layout matching (uses Python KV merge)
|
||||
"""
|
||||
import torch
|
||||
import math
|
||||
@@ -18,9 +18,8 @@ import cutlass.cute as cute
|
||||
import cutlass.torch as ct
|
||||
import cuda.bindings.driver as cuda
|
||||
from dsv4.kernels.attention.fmha import FmhaKernel
|
||||
from dsv4.model.config import DSV4Config
|
||||
|
||||
# Kernel cache: compiled kernels keyed by (head_dim, s_k, use_smem_p, normalize, ...)
|
||||
# Kernel cache: compiled kernels keyed by config tuple
|
||||
_kernel_cache: dict = {}
|
||||
|
||||
|
||||
@@ -28,11 +27,7 @@ def _get_or_compile_kernel(head_dim: int, s_k: int, use_smem_p: bool = False,
|
||||
normalize: bool = False, apply_swa_mask: bool = False,
|
||||
is_causal: bool = False, n_comp: int = 0,
|
||||
apply_sink_bias: bool = False) -> tuple:
|
||||
"""Get or compile a kernel for the given configuration.
|
||||
|
||||
Returns (compiled_kernel, FmhaKernel instance).
|
||||
Kernel compilation is expensive, so we cache by config.
|
||||
"""
|
||||
"""Get or compile a kernel for the given configuration."""
|
||||
key = (head_dim, s_k, use_smem_p, normalize, apply_swa_mask, is_causal, n_comp, apply_sink_bias)
|
||||
if key in _kernel_cache:
|
||||
return _kernel_cache[key]
|
||||
@@ -43,7 +38,6 @@ def _get_or_compile_kernel(head_dim: int, s_k: int, use_smem_p: bool = False,
|
||||
n_comp=n_comp, apply_sink_bias=apply_sink_bias,
|
||||
)
|
||||
|
||||
# Create dummy tensors for compilation
|
||||
m = 128
|
||||
pv_n_tile = kernel.pv_n_tile
|
||||
q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
|
||||
@@ -51,7 +45,6 @@ def _get_or_compile_kernel(head_dim: int, s_k: int, use_smem_p: bool = False,
|
||||
v = torch.randn(s_k, pv_n_tile, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
|
||||
lse = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
|
||||
row_sums = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
|
||||
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
@@ -61,7 +54,6 @@ def _get_or_compile_kernel(head_dim: int, s_k: int, use_smem_p: bool = False,
|
||||
mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
mLSE = ct.from_dlpack(lse).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse))
|
||||
mRS = ct.from_dlpack(row_sums).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums))
|
||||
|
||||
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, lse=mLSE)
|
||||
_kernel_cache[key] = (compiled, kernel)
|
||||
@@ -72,13 +64,13 @@ def dsv4_attention(
|
||||
q: torch.Tensor, # (num_heads, seq_len, head_dim)
|
||||
k: torch.Tensor, # (seq_len, head_dim) or (num_heads, seq_len, head_dim) for MQA
|
||||
v: torch.Tensor, # (seq_len, head_dim) or (num_heads, seq_len, head_dim) for MQA
|
||||
scale: float = None, # 1/sqrt(head_dim) if None
|
||||
swa_len: int = None, # sliding window length (None = no SWA mask)
|
||||
scale: float = None,
|
||||
swa_len: int = None,
|
||||
is_causal: bool = False,
|
||||
n_comp: int = 0, # compressed KV length for D5c sink bias
|
||||
sink_bias: torch.Tensor = None, # per-head logit bias for D5c
|
||||
n_comp: int = 0,
|
||||
sink_bias: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
"""Production DSV4 attention using CuTeDSL FMHA kernel.
|
||||
"""Production DSV4 attention using CuTeDSL FMHA kernel + Python KV merge.
|
||||
|
||||
Args:
|
||||
q: Query tensor (num_heads, seq_len, head_dim) BF16
|
||||
@@ -94,23 +86,18 @@ def dsv4_attention(
|
||||
Output tensor (num_heads, seq_len, head_dim) BF16
|
||||
"""
|
||||
n_h, T, hd = q.shape
|
||||
N = k.shape[-2] # KV sequence length
|
||||
N = k.shape[-2]
|
||||
scale = scale or (1.0 / math.sqrt(hd))
|
||||
use_smem_p = hd > 64
|
||||
apply_swa_mask = swa_len is not None
|
||||
apply_sink_bias = sink_bias is not None
|
||||
|
||||
# Per-head launch: one kernel call per head
|
||||
output = torch.zeros_like(q)
|
||||
output = torch.zeros(n_h, T, hd, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
for h in range(n_h):
|
||||
q_h = q[h:h+1] # (1, T, hd)
|
||||
k_h = k if k.dim() == 2 else k[h] # shared K (MQA) or per-head
|
||||
if k_h.dim() == 2:
|
||||
k_h = k_h.unsqueeze(0) # (1, N, hd)
|
||||
v_h = v if v.dim() == 2 else v[h]
|
||||
if v_h.dim() == 2:
|
||||
v_h = v_h.unsqueeze(0)
|
||||
k_h = k if k.dim() == 2 else k[h:h+1] # (1, N, hd)
|
||||
v_h = v if v.dim() == 2 else v[h:h+1]
|
||||
|
||||
o_h = _attention_single_head(
|
||||
q_h, k_h, v_h, scale=scale,
|
||||
@@ -123,95 +110,6 @@ def dsv4_attention(
|
||||
return output
|
||||
|
||||
|
||||
def _attention_single_head_normalized(
|
||||
q: torch.Tensor, # (1, T, hd)
|
||||
k: torch.Tensor, # (1, N, hd)
|
||||
v: torch.Tensor, # (1, N, hd)
|
||||
scale: float,
|
||||
swa_len: int = None,
|
||||
is_causal: bool = False,
|
||||
n_comp: int = 0,
|
||||
sink_bias: torch.Tensor = None,
|
||||
use_smem_p: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Run FMHA for a single head with Python normalization (single KV tile)."""
|
||||
_, T, hd = q.shape
|
||||
N = k.shape[1]
|
||||
apply_swa_mask = swa_len is not None
|
||||
apply_sink_bias = sink_bias is not None
|
||||
|
||||
compiled, kernel = _get_or_compile_kernel(
|
||||
head_dim=hd, s_k=N, use_smem_p=use_smem_p,
|
||||
normalize=False, apply_swa_mask=apply_swa_mask,
|
||||
is_causal=is_causal, n_comp=n_comp if n_comp > 0 else None,
|
||||
apply_sink_bias=apply_sink_bias,
|
||||
)
|
||||
pv_n_tile = kernel.pv_n_tile
|
||||
n_pv_tiles = kernel.n_pv_tiles
|
||||
|
||||
output_unnorm = torch.zeros(T, hd, dtype=torch.float32, device='cuda')
|
||||
lse_val = None
|
||||
|
||||
for nt in range(n_pv_tiles):
|
||||
v_start = nt * pv_n_tile
|
||||
v_end = v_start + pv_n_tile
|
||||
v_tile = v[0, :, v_start:v_end].contiguous()
|
||||
v_kernel = v_tile.unsqueeze(-1)
|
||||
c_tile = torch.zeros(T, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
|
||||
lse_tensor = torch.zeros(T, 1, 1, dtype=torch.float32, device='cuda')
|
||||
row_sums_tensor = torch.zeros(T, 1, 1, dtype=torch.float32, device='cuda')
|
||||
|
||||
q_input = q[0].contiguous().unsqueeze(-1)
|
||||
k_input = k[0].contiguous().unsqueeze(-1)
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
mQ = ct.from_dlpack(q_input).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q_input))
|
||||
mK = ct.from_dlpack(k_input).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_input))
|
||||
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
|
||||
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
|
||||
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
|
||||
mRS = ct.from_dlpack(row_sums_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums_tensor))
|
||||
|
||||
compiled(mQ, mK, mV, mC, stream, lse=mLSE, row_sums=mRS)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
output_unnorm[:, v_start:v_end] = c_tile[:, :, 0].float()
|
||||
if nt == 0:
|
||||
lse_val = lse_tensor[0, 0, 0].item()
|
||||
|
||||
# Normalize: O_norm = O_unnorm / row_sum
|
||||
# row_sum is computed from lse: row_sum = exp(lse - row_max * ln2) ... complex
|
||||
# But we have row_sums from the kernel!
|
||||
# row_sums_tensor[0,0,0] has the row_sum for row 0 only.
|
||||
# For per-row normalization, we'd need per-row row_sums.
|
||||
# The kernel outputs row_sums[sfw_idx, 0, 0] for each of 128 rows.
|
||||
# But we can also compute from the un-normalized O and the reference.
|
||||
#
|
||||
# Simpler: compute row_sum from the un-normalized O and the LSE.
|
||||
# lse = ln(row_sum) + row_max * ln(2)
|
||||
# exp(lse) = row_sum * exp(row_max * ln(2)) = row_sum * 2^row_max
|
||||
# O_unnorm = P @ V where P = softmax * row_sum
|
||||
# Wait, the kernel's P = exp2(S*scale - row_max) which is NOT softmax.
|
||||
# softmax = P / row_sum
|
||||
# O_unnorm = P @ V = softmax * row_sum @ V = O_norm * row_sum
|
||||
# So: O_norm = O_unnorm / row_sum
|
||||
#
|
||||
# We need row_sum per row. The kernel outputs it in row_sums_tensor.
|
||||
# But row_sums_tensor only has values for 128 rows, and we need all T rows.
|
||||
# Actually T=128, so row_sums_tensor[0:128, 0, 0] has all rows.
|
||||
#
|
||||
# Let me extract per-row row_sums.
|
||||
# But the kernel only writes row_sums[sfw_idx, 0, 0] for sfw_idx 0..127.
|
||||
# And row_sums_tensor is (T, 1, 1) = (128, 1, 1).
|
||||
# So row_sums_tensor[:, 0, 0] should have all 128 rows.
|
||||
|
||||
# Re-run to get row_sums (we already have them from the last call)
|
||||
row_sums_per_row = row_sums_tensor[:, 0, 0].float().unsqueeze(1) # (T, 1)
|
||||
row_sums_per_row = row_sums_per_row.clamp(min=1e-30)
|
||||
output_norm = output_unnorm / row_sums_per_row
|
||||
return output_norm.to(torch.bfloat16).unsqueeze(0) # (1, T, hd)
|
||||
|
||||
|
||||
def _attention_single_head(
|
||||
q: torch.Tensor, # (1, T, hd)
|
||||
k: torch.Tensor, # (1, N, hd)
|
||||
@@ -229,24 +127,10 @@ def _attention_single_head(
|
||||
apply_swa_mask = swa_len is not None
|
||||
apply_sink_bias = sink_bias is not None
|
||||
|
||||
# Reference output
|
||||
qf = q[0].float() # (T, hd)
|
||||
kf = k[0].float() # (N, hd)
|
||||
vf = v[0].float() # (N, hd)
|
||||
attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0]
|
||||
attn_exp = torch.exp(qf @ kf.T * scale - attn_max)
|
||||
attn_sum = attn_exp.sum(dim=-1, keepdim=True)
|
||||
ref_norm = (attn_exp / attn_sum) @ vf
|
||||
|
||||
# Each segment is s_k=128 (one KV tile)
|
||||
# Segment the KV sequence into 128-entry tiles
|
||||
s_k_per_seg = 128
|
||||
n_segments = (N + s_k_per_seg - 1) // s_k_per_seg
|
||||
|
||||
if n_segments == 1:
|
||||
# Single segment: use kernel with normalize=True (in-kernel normalization)
|
||||
return _attention_single_head_normalized(q, k, v, scale, swa_len, is_causal, n_comp, sink_bias, use_smem_p)
|
||||
|
||||
# Multi-segment: use un-normalized output + Python KV merge
|
||||
compiled, kernel = _get_or_compile_kernel(
|
||||
head_dim=hd, s_k=s_k_per_seg, use_smem_p=use_smem_p,
|
||||
normalize=False, apply_swa_mask=apply_swa_mask,
|
||||
@@ -266,7 +150,7 @@ def _attention_single_head(
|
||||
k_seg = k[:, k_start:k_end]
|
||||
v_seg = v[:, k_start:k_end]
|
||||
|
||||
# If last segment is shorter than s_k_per_seg, pad
|
||||
# Pad last segment if shorter than s_k_per_seg
|
||||
if k_end - k_start < s_k_per_seg:
|
||||
pad_len = s_k_per_seg - (k_end - k_start)
|
||||
k_seg = torch.cat([k_seg, torch.zeros(1, pad_len, hd, dtype=k.dtype, device='cuda')], dim=1)
|
||||
@@ -282,11 +166,9 @@ def _attention_single_head(
|
||||
v_kernel = v_tile.unsqueeze(-1)
|
||||
c_tile = torch.zeros(T, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
|
||||
lse_tensor = torch.zeros(T, 1, 1, dtype=torch.float32, device='cuda')
|
||||
row_sums_tensor = torch.zeros(T, 1, 1, dtype=torch.float32, device='cuda')
|
||||
|
||||
# Prepare CuTe tensors
|
||||
q_input = q[0].contiguous().unsqueeze(-1) # (T, hd, 1)
|
||||
k_input = k_seg[0].contiguous().unsqueeze(-1) # (s_k_per_seg, hd, 1)
|
||||
q_input = q[0].contiguous().unsqueeze(-1)
|
||||
k_input = k_seg[0].contiguous().unsqueeze(-1)
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
mQ = ct.from_dlpack(q_input).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q_input))
|
||||
@@ -294,7 +176,6 @@ def _attention_single_head(
|
||||
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
|
||||
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
|
||||
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
|
||||
mRS = ct.from_dlpack(row_sums_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums_tensor))
|
||||
|
||||
compiled(mQ, mK, mV, mC, stream, lse=mLSE)
|
||||
torch.cuda.synchronize()
|
||||
@@ -303,10 +184,7 @@ def _attention_single_head(
|
||||
if nt == 0:
|
||||
seg_lse[:, 0] = lse_tensor[:, 0, 0].float()
|
||||
|
||||
# Merge with accumulator using log-sum-exp (same as test_d1_kv_merge.py)
|
||||
# Formula: O = (exp(lse_old) * O_old + exp(lse_new) * O_new) / (exp(lse_old) + exp(lse_new))
|
||||
# Both O_old and O_new are un-normalized outputs from the kernel.
|
||||
# This produces the correct normalized result after merge.
|
||||
# Merge with accumulator using log-sum-exp
|
||||
e_old = torch.exp(lse_accum)
|
||||
e_new = torch.exp(seg_lse)
|
||||
e_sum = e_old + e_new
|
||||
@@ -314,5 +192,5 @@ def _attention_single_head(
|
||||
o_accum = (e_old * o_accum + e_new * seg_o) / e_sum
|
||||
lse_accum = torch.log(e_sum)
|
||||
|
||||
output = o_accum.to(torch.bfloat16).unsqueeze(0) # (1, T, hd)
|
||||
output = o_accum.to(torch.bfloat16).unsqueeze(0)
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user