From fc4172937cb721f2d17cfeddfe067c4b88d527ba Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 27 May 2026 06:51:14 +0000 Subject: [PATCH] Clean production wrapper: always normalize=False + KV merge --- dsv4/kernels/attention/production.py | 164 ++++----------------------- 1 file changed, 21 insertions(+), 143 deletions(-) diff --git a/dsv4/kernels/attention/production.py b/dsv4/kernels/attention/production.py index 81d2a282..6dea0070 100644 --- a/dsv4/kernels/attention/production.py +++ b/dsv4/kernels/attention/production.py @@ -3,14 +3,14 @@ Wraps the CuTeDSL FMHA kernel with Python KV merge for multi-KV-tile. Replaces vLLM's broken FlashMLA Blackwell implementation. -Supported: +Supported attention types: - CSA (Classical Self-Attention): 32 KV entries/block, full KV cache - HCA (Hash-based Cross-Attention): 1 KV entry/block, compressed cache - SWA (Sliding Window Attention): ring buffer, no KV cache -Unsupported (future): -- SMEM accumulator for in-kernel multi-KV-tile (currently uses Python KV merge) -- head_dim > 256 (MLIR compilation hang at hd=512) +Limitations: +- head_dim > 256: MLIR compilation hang (known CuTeDSL issue) +- In-kernel multi-KV-tile: blocked on TMA layout matching (uses Python KV merge) """ import torch import math @@ -18,9 +18,8 @@ import cutlass.cute as cute import cutlass.torch as ct import cuda.bindings.driver as cuda from dsv4.kernels.attention.fmha import FmhaKernel -from dsv4.model.config import DSV4Config -# Kernel cache: compiled kernels keyed by (head_dim, s_k, use_smem_p, normalize, ...) +# Kernel cache: compiled kernels keyed by config tuple _kernel_cache: dict = {} @@ -28,11 +27,7 @@ def _get_or_compile_kernel(head_dim: int, s_k: int, use_smem_p: bool = False, normalize: bool = False, apply_swa_mask: bool = False, is_causal: bool = False, n_comp: int = 0, apply_sink_bias: bool = False) -> tuple: - """Get or compile a kernel for the given configuration. - - Returns (compiled_kernel, FmhaKernel instance). - Kernel compilation is expensive, so we cache by config. - """ + """Get or compile a kernel for the given configuration.""" key = (head_dim, s_k, use_smem_p, normalize, apply_swa_mask, is_causal, n_comp, apply_sink_bias) if key in _kernel_cache: return _kernel_cache[key] @@ -43,7 +38,6 @@ def _get_or_compile_kernel(head_dim: int, s_k: int, use_smem_p: bool = False, n_comp=n_comp, apply_sink_bias=apply_sink_bias, ) - # Create dummy tensors for compilation m = 128 pv_n_tile = kernel.pv_n_tile q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda') @@ -51,7 +45,6 @@ def _get_or_compile_kernel(head_dim: int, s_k: int, use_smem_p: bool = False, v = torch.randn(s_k, pv_n_tile, dtype=torch.bfloat16, device='cuda') c = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') lse = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') - row_sums = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) @@ -61,7 +54,6 @@ def _get_or_compile_kernel(head_dim: int, s_k: int, use_smem_p: bool = False, mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile)) mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) mLSE = ct.from_dlpack(lse).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse)) - mRS = ct.from_dlpack(row_sums).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums)) compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, lse=mLSE) _kernel_cache[key] = (compiled, kernel) @@ -72,13 +64,13 @@ def dsv4_attention( q: torch.Tensor, # (num_heads, seq_len, head_dim) k: torch.Tensor, # (seq_len, head_dim) or (num_heads, seq_len, head_dim) for MQA v: torch.Tensor, # (seq_len, head_dim) or (num_heads, seq_len, head_dim) for MQA - scale: float = None, # 1/sqrt(head_dim) if None - swa_len: int = None, # sliding window length (None = no SWA mask) + scale: float = None, + swa_len: int = None, is_causal: bool = False, - n_comp: int = 0, # compressed KV length for D5c sink bias - sink_bias: torch.Tensor = None, # per-head logit bias for D5c + n_comp: int = 0, + sink_bias: torch.Tensor = None, ) -> torch.Tensor: - """Production DSV4 attention using CuTeDSL FMHA kernel. + """Production DSV4 attention using CuTeDSL FMHA kernel + Python KV merge. Args: q: Query tensor (num_heads, seq_len, head_dim) BF16 @@ -94,23 +86,18 @@ def dsv4_attention( Output tensor (num_heads, seq_len, head_dim) BF16 """ n_h, T, hd = q.shape - N = k.shape[-2] # KV sequence length + N = k.shape[-2] scale = scale or (1.0 / math.sqrt(hd)) use_smem_p = hd > 64 apply_swa_mask = swa_len is not None apply_sink_bias = sink_bias is not None - # Per-head launch: one kernel call per head - output = torch.zeros_like(q) + output = torch.zeros(n_h, T, hd, dtype=torch.bfloat16, device='cuda') for h in range(n_h): q_h = q[h:h+1] # (1, T, hd) - k_h = k if k.dim() == 2 else k[h] # shared K (MQA) or per-head - if k_h.dim() == 2: - k_h = k_h.unsqueeze(0) # (1, N, hd) - v_h = v if v.dim() == 2 else v[h] - if v_h.dim() == 2: - v_h = v_h.unsqueeze(0) + k_h = k if k.dim() == 2 else k[h:h+1] # (1, N, hd) + v_h = v if v.dim() == 2 else v[h:h+1] o_h = _attention_single_head( q_h, k_h, v_h, scale=scale, @@ -123,95 +110,6 @@ def dsv4_attention( return output -def _attention_single_head_normalized( - q: torch.Tensor, # (1, T, hd) - k: torch.Tensor, # (1, N, hd) - v: torch.Tensor, # (1, N, hd) - scale: float, - swa_len: int = None, - is_causal: bool = False, - n_comp: int = 0, - sink_bias: torch.Tensor = None, - use_smem_p: bool = False, -) -> torch.Tensor: - """Run FMHA for a single head with Python normalization (single KV tile).""" - _, T, hd = q.shape - N = k.shape[1] - apply_swa_mask = swa_len is not None - apply_sink_bias = sink_bias is not None - - compiled, kernel = _get_or_compile_kernel( - head_dim=hd, s_k=N, use_smem_p=use_smem_p, - normalize=False, apply_swa_mask=apply_swa_mask, - is_causal=is_causal, n_comp=n_comp if n_comp > 0 else None, - apply_sink_bias=apply_sink_bias, - ) - pv_n_tile = kernel.pv_n_tile - n_pv_tiles = kernel.n_pv_tiles - - output_unnorm = torch.zeros(T, hd, dtype=torch.float32, device='cuda') - lse_val = None - - for nt in range(n_pv_tiles): - v_start = nt * pv_n_tile - v_end = v_start + pv_n_tile - v_tile = v[0, :, v_start:v_end].contiguous() - v_kernel = v_tile.unsqueeze(-1) - c_tile = torch.zeros(T, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') - lse_tensor = torch.zeros(T, 1, 1, dtype=torch.float32, device='cuda') - row_sums_tensor = torch.zeros(T, 1, 1, dtype=torch.float32, device='cuda') - - q_input = q[0].contiguous().unsqueeze(-1) - k_input = k[0].contiguous().unsqueeze(-1) - stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - - mQ = ct.from_dlpack(q_input).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q_input)) - mK = ct.from_dlpack(k_input).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_input)) - mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) - mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) - mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) - mRS = ct.from_dlpack(row_sums_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums_tensor)) - - compiled(mQ, mK, mV, mC, stream, lse=mLSE, row_sums=mRS) - torch.cuda.synchronize() - - output_unnorm[:, v_start:v_end] = c_tile[:, :, 0].float() - if nt == 0: - lse_val = lse_tensor[0, 0, 0].item() - - # Normalize: O_norm = O_unnorm / row_sum - # row_sum is computed from lse: row_sum = exp(lse - row_max * ln2) ... complex - # But we have row_sums from the kernel! - # row_sums_tensor[0,0,0] has the row_sum for row 0 only. - # For per-row normalization, we'd need per-row row_sums. - # The kernel outputs row_sums[sfw_idx, 0, 0] for each of 128 rows. - # But we can also compute from the un-normalized O and the reference. - # - # Simpler: compute row_sum from the un-normalized O and the LSE. - # lse = ln(row_sum) + row_max * ln(2) - # exp(lse) = row_sum * exp(row_max * ln(2)) = row_sum * 2^row_max - # O_unnorm = P @ V where P = softmax * row_sum - # Wait, the kernel's P = exp2(S*scale - row_max) which is NOT softmax. - # softmax = P / row_sum - # O_unnorm = P @ V = softmax * row_sum @ V = O_norm * row_sum - # So: O_norm = O_unnorm / row_sum - # - # We need row_sum per row. The kernel outputs it in row_sums_tensor. - # But row_sums_tensor only has values for 128 rows, and we need all T rows. - # Actually T=128, so row_sums_tensor[0:128, 0, 0] has all rows. - # - # Let me extract per-row row_sums. - # But the kernel only writes row_sums[sfw_idx, 0, 0] for sfw_idx 0..127. - # And row_sums_tensor is (T, 1, 1) = (128, 1, 1). - # So row_sums_tensor[:, 0, 0] should have all 128 rows. - - # Re-run to get row_sums (we already have them from the last call) - row_sums_per_row = row_sums_tensor[:, 0, 0].float().unsqueeze(1) # (T, 1) - row_sums_per_row = row_sums_per_row.clamp(min=1e-30) - output_norm = output_unnorm / row_sums_per_row - return output_norm.to(torch.bfloat16).unsqueeze(0) # (1, T, hd) - - def _attention_single_head( q: torch.Tensor, # (1, T, hd) k: torch.Tensor, # (1, N, hd) @@ -229,24 +127,10 @@ def _attention_single_head( apply_swa_mask = swa_len is not None apply_sink_bias = sink_bias is not None - # Reference output - qf = q[0].float() # (T, hd) - kf = k[0].float() # (N, hd) - vf = v[0].float() # (N, hd) - attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0] - attn_exp = torch.exp(qf @ kf.T * scale - attn_max) - attn_sum = attn_exp.sum(dim=-1, keepdim=True) - ref_norm = (attn_exp / attn_sum) @ vf - - # Each segment is s_k=128 (one KV tile) + # Segment the KV sequence into 128-entry tiles s_k_per_seg = 128 n_segments = (N + s_k_per_seg - 1) // s_k_per_seg - if n_segments == 1: - # Single segment: use kernel with normalize=True (in-kernel normalization) - return _attention_single_head_normalized(q, k, v, scale, swa_len, is_causal, n_comp, sink_bias, use_smem_p) - - # Multi-segment: use un-normalized output + Python KV merge compiled, kernel = _get_or_compile_kernel( head_dim=hd, s_k=s_k_per_seg, use_smem_p=use_smem_p, normalize=False, apply_swa_mask=apply_swa_mask, @@ -266,7 +150,7 @@ def _attention_single_head( k_seg = k[:, k_start:k_end] v_seg = v[:, k_start:k_end] - # If last segment is shorter than s_k_per_seg, pad + # Pad last segment if shorter than s_k_per_seg if k_end - k_start < s_k_per_seg: pad_len = s_k_per_seg - (k_end - k_start) k_seg = torch.cat([k_seg, torch.zeros(1, pad_len, hd, dtype=k.dtype, device='cuda')], dim=1) @@ -282,11 +166,9 @@ def _attention_single_head( v_kernel = v_tile.unsqueeze(-1) c_tile = torch.zeros(T, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') lse_tensor = torch.zeros(T, 1, 1, dtype=torch.float32, device='cuda') - row_sums_tensor = torch.zeros(T, 1, 1, dtype=torch.float32, device='cuda') - # Prepare CuTe tensors - q_input = q[0].contiguous().unsqueeze(-1) # (T, hd, 1) - k_input = k_seg[0].contiguous().unsqueeze(-1) # (s_k_per_seg, hd, 1) + q_input = q[0].contiguous().unsqueeze(-1) + k_input = k_seg[0].contiguous().unsqueeze(-1) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) mQ = ct.from_dlpack(q_input).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q_input)) @@ -294,7 +176,6 @@ def _attention_single_head( mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) - mRS = ct.from_dlpack(row_sums_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums_tensor)) compiled(mQ, mK, mV, mC, stream, lse=mLSE) torch.cuda.synchronize() @@ -303,10 +184,7 @@ def _attention_single_head( if nt == 0: seg_lse[:, 0] = lse_tensor[:, 0, 0].float() - # Merge with accumulator using log-sum-exp (same as test_d1_kv_merge.py) - # Formula: O = (exp(lse_old) * O_old + exp(lse_new) * O_new) / (exp(lse_old) + exp(lse_new)) - # Both O_old and O_new are un-normalized outputs from the kernel. - # This produces the correct normalized result after merge. + # Merge with accumulator using log-sum-exp e_old = torch.exp(lse_accum) e_new = torch.exp(seg_lse) e_sum = e_old + e_new @@ -314,5 +192,5 @@ def _attention_single_head( o_accum = (e_old * o_accum + e_new * seg_o) / e_sum lse_accum = torch.log(e_sum) - output = o_accum.to(torch.bfloat16).unsqueeze(0) # (1, T, hd) + output = o_accum.to(torch.bfloat16).unsqueeze(0) return output