"""Native CuTeDSL sparse SWA decode attention for DeepSeek-V4 on Blackwell (SM100).

Handles CSA (C4A, compress_ratio=4) and HCA (C128A, compress_ratio=128).
Attends to BOTH the SWA window AND top-k compressed KV, merged with sink
weights.

Sink weight merge (FlashMLA formula)::

    o = (exp(lse_sparse) * o_sparse + exp(attn_sink) * exp(lse_swa) * o_swa)
        / (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa))

where::

    o_sparse   = sum(exp(s) * v) / sum(exp(s))   from compressed KV
    o_swa      = sum(exp(s) * v) / sum(exp(s))   from SWA KV
    lse_sparse = log(sum(exp(s)))                from compressed KV
    lse_swa    = log(sum(exp(s)))                from SWA KV
    attn_sink  = per-head learnable parameter (NH,)

Expanding the definitions, this is a single softmax over the union of both KV
sets in which ``attn_sink[h]`` is added to the logits of the SWA entries.  The
GPU kernel below uses that single-pass form; the PyTorch fallback keeps the
two-branch form.
"""
import torch
import torch.nn.functional as F
from typing import Optional

try:
    import cutlass
    import cutlass.cute as cute
    import cutlass.torch as cutlass_torch
    import cutlass.utils as utils
    import cuda.bindings.driver as cuda
    HAS_CUTEDSL = True
except ImportError:
    HAS_CUTEDSL = False

# Compiled kernels keyed by the static problem configuration.
_compiled_sparse_kernel_cache = {}

HEAD_GROUP = 16     # query heads handled by one thread block
KV_TILE = 16        # KV rows staged per tile
HEAD_DIM = 512      # default head dimension of the kernel class
NUM_THREADS = 128   # threads per block used for the launch


def _gather_dequant(cache, inv_scale, indices, block_size):
    """Gather KV rows by flat slot index and dequantize them to bfloat16.

    ``indices`` holds flat slot ids; a slot maps to
    ``cache[slot // block_size, slot % block_size]`` and is scaled by
    ``inv_scale[slot]``.  Negative ids are clamped to slot 0 so the gather is
    always in bounds; callers are expected to mask those positions.
    """
    safe = indices.clamp(min=0)
    raw = cache[safe // block_size, safe % block_size]
    if cache.dtype == torch.uint8:
        # uint8 storage is reinterpreted as FP8 (e4m3) payload.
        raw = raw.view(torch.float8_e4m3fn)
    return (raw.to(torch.bfloat16) * inv_scale[safe]).to(torch.bfloat16)


def _valid_mask(indices, lens):
    """Return a (T, n) bool mask: position is below ``lens`` and index is >= 0."""
    pos = torch.arange(indices.shape[1], device=indices.device).unsqueeze(0)
    return (pos < lens.unsqueeze(1)) & (indices >= 0)


def _prepare_branches(
    q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
    compressed_kv_cache, compressed_inv_scale,
    topk_indices, topk_lens, block_size, window_size,
):
    """Gather/dequantize both KV branches and build their validity masks.

    Returns ``(swa_kv, swa_valid, comp_kv, comp_valid)`` where the KV tensors
    are bfloat16 ``(T, n, HD)`` and the masks are bool ``(T, n)``.
    """
    num_tokens = q.shape[0]

    if swa_indices.dim() == 3:
        swa_indices = swa_indices.squeeze(0)
    if topk_indices.dim() == 1:
        # One compressed index per token.
        topk_indices = topk_indices.unsqueeze(-1)

    # Never look past the window, whatever width the index tensor has.
    swa_idx = swa_indices[:num_tokens, :window_size]
    topk_idx = topk_indices[:num_tokens]

    swa_kv = _gather_dequant(swa_kv_cache, swa_inv_scale, swa_idx, block_size)
    swa_valid = _valid_mask(swa_idx, swa_lens[:num_tokens])

    comp_kv = _gather_dequant(
        compressed_kv_cache, compressed_inv_scale, topk_idx,
        compressed_kv_cache.shape[1],
    )
    comp_valid = _valid_mask(topk_idx, topk_lens[:num_tokens])
    return swa_kv, swa_valid, comp_kv, comp_valid


def _masked_attention(q, kv, valid, scale):
    """Softmax attention of ``q`` over ``kv`` (keys double as values).

    Args:
        q: ``(T, NH, HD)`` queries.
        kv: ``(T, n, HD)`` per-token keys/values.
        valid: ``(T, n)`` bool mask of attendable positions.
        scale: softmax scale applied to the dot products.

    Returns:
        ``(out, lse)``: float32 ``(T, NH, HD)`` outputs and float32 ``(T, NH)``
        log-sum-exp.  Rows without any valid position yield ``out == 0`` and
        ``lse == -inf`` instead of NaN.
    """
    num_tokens, num_heads, head_dim = q.shape
    if kv.shape[1] == 0:
        out = torch.zeros(num_tokens, num_heads, head_dim,
                          dtype=torch.float32, device=q.device)
        lse = torch.full((num_tokens, num_heads), float('-inf'),
                         dtype=torch.float32, device=q.device)
        return out, lse

    q32 = q.float()
    kv32 = kv.float()
    scores = torch.einsum('thd,tnd->thn', q32, kv32) * scale
    scores = scores.masked_fill(~valid.unsqueeze(1), float('-inf'))

    row_max = scores.amax(dim=-1, keepdim=True)
    # A fully masked row has max == -inf; use 0 so (scores - max) stays -inf.
    row_max = torch.where(torch.isfinite(row_max), row_max,
                          torch.zeros_like(row_max))
    weights = (scores - row_max).exp()
    denom = weights.sum(dim=-1, keepdim=True)
    lse = (row_max + denom.log()).squeeze(-1)
    out = torch.einsum('thn,tnd->thd', weights, kv32) / denom.clamp(min=1e-30)
    return out, lse


def _to_cute(t):
    """Wrap a torch tensor as a CuTe tensor with a dynamic layout."""
    ct = cutlass_torch.from_dlpack(t)
    return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))


def native_sparse_decode_attention(
    q,
    swa_kv_cache,
    swa_inv_scale,
    swa_indices,
    swa_lens,
    compressed_kv_cache,
    compressed_inv_scale,
    topk_indices,
    topk_lens,
    attn_sink,
    block_size,
    scale,
    window_size=128,
    compress_ratio=4,
):
    """Sparse decode attention over the SWA window plus top-k compressed KV.

    Args:
        q: ``(num_tokens, NH, HD)`` queries.
        swa_kv_cache / swa_inv_scale: paged SWA cache and per-slot scale;
            a slot ``s`` lives at ``swa_kv_cache[s // block_size, s % block_size]``.
        swa_indices: ``(T, W)`` (or ``(1, T, W)``) flat SWA slot ids; negative
            ids are treated as padding.
        swa_lens: ``(T,)`` number of leading SWA positions in use.
        compressed_kv_cache / compressed_inv_scale: paged compressed cache and
            per-slot scale (block size is ``compressed_kv_cache.shape[1]``).
        topk_indices: ``(T, K)`` flat compressed slot ids; negative = padding.
        topk_lens: ``(T,)`` number of leading top-k positions in use.
        attn_sink: ``(NH,)`` per-head sink logit applied to the SWA branch.
        block_size: SWA cache block size.
        scale: softmax scale.
        window_size: maximum number of SWA positions considered.
        compress_ratio: only used as part of the kernel cache key.

    Returns:
        ``(num_tokens, NH, HD)`` bfloat16 output.  Tokens with no valid KV
        entry produce zeros.

    The PyTorch fallback is used when CuTeDSL is unavailable, when ``q`` is not
    on a CUDA device, or when ``NH`` is not a multiple of ``HEAD_GROUP`` (the
    kernel launches ``NH // HEAD_GROUP`` head groups and would skip the rest).
    """
    num_tokens, NH, HD = q.shape
    device = q.device

    if not HAS_CUTEDSL or not q.is_cuda or NH % HEAD_GROUP != 0:
        return _fallback_sparse_sdp(
            q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
            compressed_kv_cache, compressed_inv_scale,
            topk_indices, topk_lens, attn_sink, block_size, scale,
            window_size,
        )

    swa_kv, swa_valid, comp_kv, comp_valid = _prepare_branches(
        q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
        compressed_kv_cache, compressed_inv_scale,
        topk_indices, topk_lens, block_size, window_size,
    )

    total_len = swa_kv.shape[1] + comp_kv.shape[1]
    output = torch.zeros(num_tokens, NH, HD, dtype=torch.bfloat16, device=device)
    if num_tokens == 0 or total_len == 0:
        return output

    # The kernel treats mLens as a *prefix* length, so pack every valid row to
    # the front of each token's KV: valid SWA rows first, then valid compressed
    # rows, then padding.  A stable sort on "is invalid" does exactly that.
    kv_all = torch.cat([swa_kv, comp_kv], dim=1)
    valid = torch.cat([swa_valid, comp_valid], dim=1)
    _, order = torch.sort((~valid).to(torch.uint8), dim=1, stable=True)
    kv_packed = torch.gather(
        kv_all, 1, order.unsqueeze(-1).expand(-1, -1, HD)
    ).contiguous()
    total_lens = valid.sum(dim=1).to(torch.int32).contiguous()
    swa_counts = swa_valid.sum(dim=1).to(torch.int32).contiguous()

    q_bf16 = q.to(torch.bfloat16).contiguous()
    sink = attn_sink.to(device=device, dtype=torch.float32).contiguous()
    scale_tensor = torch.tensor([scale], dtype=torch.float32, device=device)

    q_c = _to_cute(q_bf16)
    kv_c = _to_cute(kv_packed)
    len_c = _to_cute(total_lens)
    out_c = _to_cute(output)
    scale_c = _to_cute(scale_tensor)
    swa_len_c = _to_cute(swa_counts)
    sink_c = _to_cute(sink)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    cache_key = (num_tokens, NH, HD, window_size, total_len, compress_ratio,
                 str(device))
    entry = _compiled_sparse_kernel_cache.get(cache_key)
    if entry is None:
        kernel = BlackwellSparseDecodeKernel(
            head_dim=HD,
            head_group=HEAD_GROUP,
            kv_tile=KV_TILE,
            total_len=total_len,
        )
        compiled = cute.compile(
            kernel, q_c, kv_c, len_c, out_c, scale_c, stream, swa_len_c, sink_c
        )
        entry = {'compiled': compiled}
        _compiled_sparse_kernel_cache[cache_key] = entry

    # Single launch per call (compilation above does not need its own run).
    entry['compiled'](q_c, kv_c, len_c, out_c, scale_c, stream, swa_len_c, sink_c)
    return output


def _fallback_sparse_sdp(
    q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
    compressed_kv_cache, compressed_inv_scale,
    topk_indices, topk_lens, attn_sink, block_size, scale, window_size,
):
    """Pure-PyTorch reference for ``native_sparse_decode_attention``.

    Computes the SWA and compressed branches separately in float32, then merges
    them with the sink-weighted formula from the module docstring.  Positions
    beyond the per-token length and positions with a negative index are masked.
    Returns a ``(num_tokens, NH, HD)`` bfloat16 tensor; tokens without any
    valid KV entry produce zeros.
    """
    swa_kv, swa_valid, comp_kv, comp_valid = _prepare_branches(
        q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
        compressed_kv_cache, compressed_inv_scale,
        topk_indices, topk_lens, block_size, window_size,
    )

    o_swa, lse_swa = _masked_attention(q, swa_kv, swa_valid, scale)
    o_sparse, lse_sparse = _masked_attention(q, comp_kv, comp_valid, scale)

    # Merge in log space: subtracting the shared maximum keeps both weights in
    # [0, 1], so neither overflow nor a denominator clamp can distort the
    # result.
    sink = attn_sink.to(torch.float32).unsqueeze(0)      # (1, NH)
    logit_swa = lse_swa + sink                           # (T, NH)
    shift = torch.maximum(lse_sparse, logit_swa)
    # Both branches empty -> shift is -inf; use 0 so the weights become 0.
    shift = torch.where(torch.isfinite(shift), shift, torch.zeros_like(shift))
    w_sparse = (lse_sparse - shift).exp().unsqueeze(-1)  # (T, NH, 1)
    w_swa = (logit_swa - shift).exp().unsqueeze(-1)

    numerator = w_sparse * o_sparse + w_swa * o_swa
    denominator = (w_sparse + w_swa).clamp(min=1e-30)
    return (numerator / denominator).to(torch.bfloat16)


if HAS_CUTEDSL:

    class BlackwellSparseDecodeKernel:
        """Online-softmax decode attention over a packed per-token KV buffer.

        Launch grid is ``(NH // head_group, num_tokens, 1)``.  For token ``t``
        the first ``mLens[t]`` rows of ``mKV[t]`` are attended to (keys double
        as values); of those, the first ``mSwaLens[t]`` rows get ``mSink[head]``
        added to their logit.  When ``mSwaLens``/``mSink`` are omitted no sink
        bias is applied.

        NOTE(review): ``tidx`` is never used, so every thread of a block runs
        the same loads, math and stores; the work is not partitioned across the
        ``NUM_THREADS`` threads.
        """

        def __init__(self, head_dim=HEAD_DIM, head_group=HEAD_GROUP,
                     kv_tile=KV_TILE, total_len=128):
            self._head_dim = head_dim
            self._head_group = head_group
            self._kv_tile = kv_tile
            self._total_len = total_len
            self._num_threads = NUM_THREADS

        @cute.jit
        def __call__(self, mQ, mKV, mLens, mO, mScale, stream,
                     mSwaLens=None, mSink=None):
            num_tokens = mQ.shape[0]
            num_head_groups = mQ.shape[1] // self._head_group
            self._kernel(mQ, mKV, mLens, mO, mScale, mSwaLens, mSink).launch(
                grid=(num_head_groups, num_tokens, 1),
                block=[self._num_threads, 1, 1],
                stream=stream,
            )

        @cute.kernel
        def _kernel(self, mQ, mKV, mLens, mO, mScale, mSwaLens, mSink):
            tidx, _, _ = cute.arch.thread_idx()
            hg_idx, tok_idx, _ = cute.arch.block_idx()

            HG = self._head_group
            HD = self._head_dim
            KT = self._kv_tile
            TL = self._total_len
            softmax_scale = mScale[0]

            @cute.struct
            class SharedStorage:
                kv_tile: cute.struct.MemRange[cutlass.BFloat16, KT * HD]

            smem = utils.SmemAllocator()
            storage = smem.allocate(SharedStorage)
            sKV = cute.make_tensor(
                storage.kv_tile.data_ptr(),
                cute.make_layout((KT, HD), stride=(HD, 1)),
            )

            kv_len = mLens[tok_idx]
            has_kv = kv_len > 0

            # Number of leading rows that belong to the SWA branch.
            swa_count = cutlass.Int32(0)
            if cutlass.const_expr(mSwaLens is not None):
                swa_count = mSwaLens[tok_idx]

            # Per-head sink logit for this head group (0 when no sink given).
            sink_reg = cute.make_rmem_tensor((HG,), cutlass.Float32)
            sink_reg.fill(0.0)
            if cutlass.const_expr(mSink is not None):
                for h in cutlass.range_constexpr(HG):
                    sink_reg[h] = mSink[hg_idx * HG + h]

            q_reg = cute.make_rmem_tensor((HG, HD), cutlass.BFloat16)
            for h in cutlass.range_constexpr(HG):
                qh = hg_idx * HG + h
                for d in range(HD):
                    q_reg[h, d] = mQ[tok_idx, qh, d]

            acc_O = cute.make_rmem_tensor((HG, HD), cutlass.Float32)
            acc_O.fill(0.0)
            row_max = cute.make_rmem_tensor((HG,), cutlass.Float32)
            row_sum = cute.make_rmem_tensor((HG,), cutlass.Float32)
            row_max.fill(-1e30)
            row_sum.fill(0.0)

            max_tiles = (TL + KT - 1) // KT
            for tile_idx in range(max_tiles):
                tile_start = tile_idx * KT

                # Stage one KV tile; rows at or past kv_len are zero-filled,
                # which also keeps the read inside mKV for the last tile.
                for kv_pos in range(KT):
                    global_kv = tile_start + kv_pos
                    for d in range(HD):
                        valid = global_kv < kv_len
                        val = cutlass.BFloat16(0.0)
                        if valid:
                            val = mKV[tok_idx, global_kv, d]
                        sKV[kv_pos, d] = val
                cute.arch.sync_threads()

                scores = cute.make_rmem_tensor((HG, KT), cutlass.Float32)
                scores.fill(0.0)
                for h in cutlass.range_constexpr(HG):
                    for kv_pos in range(KT):
                        global_kv = tile_start + kv_pos
                        dot = cutlass.Float32(0.0)
                        for d in range(HD):
                            q_val = q_reg[h, d].to(cutlass.Float32)
                            k_val = sKV[kv_pos, d].to(cutlass.Float32)
                            dot = dot + q_val * k_val
                        score = dot * softmax_scale
                        if global_kv < swa_count:
                            # Sink weight: exp(sink) * exp(s) == exp(s + sink).
                            score = score + sink_reg[h]
                        if global_kv >= kv_len:
                            # Padding rows must not take part in the softmax.
                            score = cutlass.Float32(-1e30)
                        scores[h, kv_pos] = score

                for h in cutlass.range_constexpr(HG):
                    tile_max = cutlass.Float32(-1e30)
                    for kv_pos in range(KT):
                        s = scores[h, kv_pos]
                        if s > tile_max:
                            tile_max = s

                    new_max = row_max[h]
                    if tile_max > new_max:
                        new_max = tile_max

                    # Rescale the running state; nothing accumulated yet while
                    # row_max still holds the -1e30 sentinel.
                    rescale = cutlass.Float32(0.0)
                    if row_max[h] > cutlass.Float32(-1e29):
                        rescale = cute.exp(row_max[h] - new_max)
                    for d in range(HD):
                        acc_O[h, d] = acc_O[h, d] * rescale
                    row_sum[h] = row_sum[h] * rescale

                    for kv_pos in range(KT):
                        # Masked rows carry the sentinel score and contribute
                        # exactly zero, even when the whole tile is masked.
                        exp_score = cutlass.Float32(0.0)
                        if scores[h, kv_pos] > cutlass.Float32(-1e29):
                            exp_score = cute.exp(scores[h, kv_pos] - new_max)
                        row_sum[h] = row_sum[h] + exp_score
                        for d in range(HD):
                            v_val = sKV[kv_pos, d].to(cutlass.Float32)
                            acc_O[h, d] = acc_O[h, d] + exp_score * v_val

                    row_max[h] = new_max

                # All reads of sKV are done before the next tile overwrites it.
                cute.arch.sync_threads()

            for h in cutlass.range_constexpr(HG):
                qh = hg_idx * HG + h
                for d in range(HD):
                    val_f32 = cutlass.Float32(0.0)
                    if has_kv and row_sum[h] > cutlass.Float32(1e-30):
                        val_f32 = acc_O[h, d] / row_sum[h]
                    mO[tok_idx, qh, d] = val_f32.to(cutlass.BFloat16)