""" Blackwell SM100 Tensor-Core SWA Decode Attention Kernel for DeepSeek-V4. Architecture: Two GEMMs back-to-back sharing TMEM, softmax in registers between them. Following dense_blockscaled_gemm_persistent.py for all Blackwell idioms: - tcgen05.mma with TMEM accumulators - TmemAllocator with holding buffer and dealloc barrier - Warp specialization: 1 MMA warp + 2 epilogue warps - Online softmax in epilogue warps between the two GEMMs - Final normalize in epilogue (divide by row_sum) """ import torch from typing import Optional import math try: import cutlass import cutlass.cute as cute import cutlass.torch as cutlass_torch import cutlass.utils as utils import cuda.bindings.driver as cuda from cutlass.cute.nvgpu import tcgen05, warp import cutlass.pipeline as pipeline from cutlass.utils import blackwell_helpers as sm100_utils from cutlass import BFloat16, Float32 HAS_CUTEDSL = True except ImportError: HAS_CUTEDSL = False HEAD_DIM = 512 KV_TILE = 16 WINDOW_SIZE = 128 LOG2_E = 1.4426950408889634074 def native_swa_decode_attention( q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens, block_size, scale, window_size=128, ): num_tokens, NH, HD = q.shape device = q.device if not HAS_CUTEDSL: return _fallback_batched_sdp(q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens, block_size, scale, window_size) q = q.contiguous(); swa_indices = swa_indices.contiguous(); swa_lens = swa_lens.contiguous() if swa_indices.dim() == 3: swa_indices_2d = swa_indices.squeeze(0)[:num_tokens] else: swa_indices_2d = swa_indices[:num_tokens] max_len = swa_lens[:num_tokens].max().item() if max_len <= 0: return torch.zeros(num_tokens, NH, HD, dtype=torch.bfloat16, device=device) max_len = min(max_len, window_size) safe_indices = swa_indices_2d[:, :max_len].clamp(min=0) block_indices = safe_indices // block_size; offsets = safe_indices % block_size kv_raw = swa_kv_cache[block_indices, offsets] if swa_kv_cache.dtype == torch.uint8: kv_raw = kv_raw.view(torch.float8_e4m3fn) inv_scales = swa_inv_scale[safe_indices] kv_bf16 = (kv_raw.to(torch.bfloat16) * inv_scales).to(torch.bfloat16) if max_len < window_size: kv_bf16 = torch.cat([kv_bf16, torch.zeros(num_tokens, window_size-max_len, HD, dtype=torch.bfloat16, device=device)], dim=1) output = torch.zeros(num_tokens, NH, HD, dtype=torch.bfloat16, device=device) def to_cute(t): return cutlass_torch.from_dlpack(t).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) q_c, kv_c, len_c, out_c = to_cute(q), to_cute(kv_bf16), to_cute(swa_lens[:num_tokens]), to_cute(output) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) scale_c = to_cute(torch.tensor([scale], dtype=torch.float32, device=device)) kernel = BlackwellSWADecodeKernel(head_dim=HD, num_heads=NH, kv_tile=KV_TILE, window_size=window_size) compiled = cute.compile(kernel, q_c, kv_c, len_c, out_c, scale_c, stream) compiled(q_c, kv_c, len_c, out_c, scale_c, stream) return output def _fallback_batched_sdp(q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens, block_size, scale, window_size): num_tokens, NH, HD = q.shape; device = q.device if swa_indices.dim() == 3: swa_indices = swa_indices.squeeze(0) safe_indices = swa_indices[:num_tokens].clamp(min=0) block_indices = safe_indices // block_size; offsets = safe_indices % block_size kv_raw = swa_kv_cache[block_indices, offsets] if swa_kv_cache.dtype == torch.uint8: kv_raw = kv_raw.view(torch.float8_e4m3fn) inv_scales = swa_inv_scale[safe_indices] kv_bf16 = (kv_raw.to(torch.bfloat16) * inv_scales).to(torch.bfloat16) pos_range = torch.arange(window_size, device=device).unsqueeze(0) len_mask = pos_range >= swa_lens[:num_tokens].unsqueeze(1) invalid_mask = swa_indices[:num_tokens, :window_size] < 0 attn_mask = len_mask | invalid_mask float_mask = torch.zeros(attn_mask.shape, dtype=torch.bfloat16, device=device) float_mask[attn_mask] = float('-inf') q_t = q.permute(1, 0, 2) kv_batch = kv_bf16.expand(NH, window_size, HD) mask_batch = float_mask.unsqueeze(0).expand(NH, -1, -1) out = torch.nn.functional.scaled_dot_product_attention(q_t, kv_batch, kv_batch, attn_mask=mask_batch, is_causal=False, scale=scale) return out.permute(1, 0, 2) if HAS_CUTEDSL: class BlackwellSWADecodeKernel: def __init__(self, head_dim=HEAD_DIM, num_heads=128, kv_tile=KV_TILE, window_size=WINDOW_SIZE): self._head_dim = head_dim self._num_heads = num_heads self._kv_tile = kv_tile self._window_size = window_size self._mma_m = 128 self._num_threads = 96 self._cta_group = tcgen05.CtaGroup.ONE # Warp IDs: 0,1 = epilogue, 2 = MMA self._mma_warp_id = 2 self._epi_warp_ids = [0, 1] @cute.jit def __call__(self, mQ, mKV, mLens, mO, mScale, stream): num_tokens = mQ.shape[0] M = self._mma_m; HD = self._head_dim; KT = self._kv_tile # TiledMma for Q @ K^T: Q(M,HD) x K^T(HD,KT) → S(M,KT) tiled_mma_qk = sm100_utils.make_trivial_tiled_mma( BFloat16, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K, Float32, self._cta_group, (M, KT)) # TiledMma for P @ V: P(M,KT) x V(KT,HD) → O(M,HD) tiled_mma_pv = sm100_utils.make_trivial_tiled_mma( BFloat16, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.MN, Float32, self._cta_group, (M, KT), tcgen05.OperandSource.TMEM) # SMEM layouts sA_layout_atom = cute.make_composed_layout( cute.make_swizzle(3, 3, 3), 0, cute.make_layout((8, 64), stride=(64, 1))) sQ_layout = cute.tile_to_shape(sA_layout_atom, (M, HD), (0, 1)) sK_layout = cute.tile_to_shape(sA_layout_atom, (KT, HD), (0, 1)) sV_layout = cute.tile_to_shape(sA_layout_atom, (KT, HD), (0, 1)) sO_layout = cute.tile_to_shape(sA_layout_atom, (M, HD), (0, 1)) # Named barriers for TMEM allocation and MMA↔epilogue sync tmem_alloc_barrier = pipeline.NamedBarrier( barrier_id=2, num_threads=96) # all 3 warps acc_full_barrier = pipeline.NamedBarrier( barrier_id=3, num_threads=96) # all 3 warps @cute.struct class SharedStorage: sQ: cute.struct.Align[cute.struct.MemRange[BFloat16, cute.cosize(sQ_layout)], 1024] sK: cute.struct.Align[cute.struct.MemRange[BFloat16, cute.cosize(sK_layout)], 1024] sV: cute.struct.Align[cute.struct.MemRange[BFloat16, cute.cosize(sV_layout)], 1024] sO: cute.struct.Align[cute.struct.MemRange[BFloat16, cute.cosize(sO_layout)], 1024] tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 self._kernel( mQ, mKV, mLens, mO, mScale, sQ_layout, sK_layout, sV_layout, sO_layout, tiled_mma_qk, tiled_mma_pv, SharedStorage, tmem_alloc_barrier, acc_full_barrier, ).launch(grid=(1, num_tokens, 1), block=[self._num_threads, 1, 1], stream=stream) @cute.kernel def _kernel(self, mQ, mKV, mLens, mO, mScale, sQ_layout, sK_layout, sV_layout, sO_layout, tiled_mma_qk, tiled_mma_pv, SharedStorage: cutlass.Constexpr, tmem_alloc_barrier, acc_full_barrier): tidx, _, _ = cute.arch.thread_idx() _, tok_idx, _ = cute.arch.block_idx() M = self._mma_m; HD = self._head_dim; KT = self._kv_tile softmax_scale = mScale[0] swa_len = mLens[tok_idx] warp_idx = tidx // 32 is_mma_warp = warp_idx == self._mma_warp_id is_epi_warp = warp_idx in self._epi_warp_ids smem = utils.SmemAllocator() storage = smem.allocate(SharedStorage) sQ = storage.sQ.get_tensor(sQ_layout) sK = storage.sK.get_tensor(sK_layout) sV = storage.sV.get_tensor(sV_layout) sO = storage.sO.get_tensor(sO_layout) # TMEM allocator (all warps participate in alloc/dealloc) tmem = utils.TmemAllocator( storage.tmem_holding_buf.ptr, barrier_for_retrieve=tmem_alloc_barrier, allocator_warp_id=self._epi_warp_ids[0], ) # Allocate TMEM for score and output accumulators # Score: (M=128, KT=16) = 2048 FP32 # Output: (M=128, HD=512) = 65536 FP32 # Total: 67584 FP32 = 264 KB TMEM (within 1 MB budget) num_tmem_cols = M * KT + M * HD # TMEM columns tmem.allocate(num_tmem_cols) # TMEM layout for scores and output # The TMEM layout comes from the TiledMma's C operand layout tCtScores = tiled_mma_qk.make_fragment_C(tiled_mma_qk.partition_C_shape((M, KT))) tCtOutput = tiled_mma_pv.make_fragment_C(tiled_mma_pv.partition_C_shape((M, HD))) # Zero the output accumulator (MMA with ACCUMULATE=False does this for scores) # For the output accumulator, we zero it before the KV loop # TODO: zero tCtOutput via tcgen05.st or first PV GEMM with ACCUMULATE=False # ─── MMA WARP ──────────────────────────────────────── if is_mma_warp: # Load Q to SMEM (once, reused across all KV tiles) for h in range(M): for d in range(HD): sQ[h, d] = mQ[tok_idx, h, d] cute.arch.sync_threads() # Partition Q and K for QK GEMM thr_qk = tiled_mma_qk.get_slice(tidx) tCrQ = thr_qk.make_fragment_A(thr_qk.partition_A(sQ)) tCrK = thr_qk.make_fragment_B(thr_qk.partition_B(sK)) smem_copy_Q = cute.make_tiled_copy_A( cute.make_copy_atom(warp.LdMatrix8x8x16bOp(False, 4), BFloat16), tiled_mma_qk) thr_smem_Q = smem_copy_Q.get_slice(tidx) tCsQ = thr_smem_Q.partition_S(sQ); tCrQ_cv = thr_smem_Q.retile(tCrQ) smem_copy_K = cute.make_tiled_copy_B( cute.make_copy_atom(warp.LdMatrix8x8x16bOp(False, 4), BFloat16), tiled_mma_qk) thr_smem_K = smem_copy_K.get_slice(tidx) tCsK = thr_smem_K.partition_S(sK); tCrK_cv = thr_smem_K.retile(tCrK) # Load Q to registers (once) for k in cutlass.range_constexpr(cute.size(tCsQ.shape[2])): cute.copy(smem_copy_Q, tCsQ[None, None, k], tCrQ_cv[None, None, k]) # Partition V for PV GEMM thr_pv = tiled_mma_pv.get_slice(tidx) sVt = cute.composition(sV, cute.make_layout((HD, KT), stride=(KT, 1))) tOrVt = thr_pv.make_fragment_B(thr_pv.partition_B(sVt)) smem_copy_V = cute.make_tiled_copy_B( cute.make_copy_atom(warp.LdMatrix8x8x16bOp(True, 4), BFloat16), tiled_mma_pv) thr_smem_V = smem_copy_V.get_slice(tidx) tOsVt = thr_smem_V.partition_S(sVt); tOrVt_cv = thr_smem_V.retile(tOrVt) # ── KV tile loop ───────────────────────────────── n_block_max = (self._window_size + KT - 1) // KT for n_block in range(n_block_max): tile_start = n_block * KT # Load K and V to SMEM for kv_pos in range(KT): global_kv = tile_start + kv_pos for d in range(HD): val = cutlass.BFloat16(0.0) if global_kv < swa_len: val = mKV[tok_idx, global_kv, d] sK[kv_pos, d] = val sV[kv_pos, d] = val cute.arch.sync_threads() # ── Q @ K^T via tcgen05.mma ────────────────── for k in cutlass.range_constexpr(cute.size(tCsK.shape[2])): cute.copy(smem_copy_K, tCsK[None, None, k], tCrK_cv[None, None, k]) tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, k > 0) cute.gemm(tiled_mma_qk, tCtScores, tCrQ[None, None, k], tCrK[None, None, k], tCtScores) # Signal epilogue: scores in TMEM are ready cute.arch.fence_view_async_tmem_store() acc_full_barrier.arrive() # Wait for epilogue to finish softmax acc_full_barrier.wait() # ── P @ V via tcgen05.mma ───────────────────── # P from TMEM (softmax output), V from SMEM for k in cutlass.range_constexpr(cute.size(tOsVt.shape[2])): cute.copy(smem_copy_V, tOsVt[None, None, k], tOrVt_cv[None, None, k]) # For the first KV tile, first k tile: ACCUMULATE=False (zero the output) # Otherwise, ACCUMULATE=True (accumulate into output) is_first_output_tile = (n_block == 0) and (k == 0) tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, not is_first_output_tile) cute.gemm(tiled_mma_pv, tCtOutput, None, tOrVt[None, None, k], tCtOutput) cute.arch.fence_view_async_tmem_store() acc_full_barrier.arrive() # Free TMEM tmem.relinquish_alloc_permit() acc_full_barrier.arrive() # Sync before dealloc tmem.free(num_tmem_cols) # ─── EPILOGUE WARPS ────────────────────────────────── if is_epi_warp: my_row_start = warp_idx * 64 num_my_rows = 64 # Online softmax state row_max = cute.make_rmem_tensor((num_my_rows,), Float32) row_sum = cute.make_rmem_tensor((num_my_rows,), Float32) row_max.fill(-1e30) row_sum.fill(0.0) # TMEM→register copy for scores (tcgen05.ld pattern) tiled_copy_t2r_scores = tcgen05.make_tmem_copy( sm100_utils.get_tmem_load_op(self._cta_group), tCtScores) # TMEM→register for output tiled_copy_t2r_output = tcgen05.make_tmem_copy( sm100_utils.get_tmem_load_op(self._cta_group), tCtOutput) # Register fragments tRrScores = cute.make_fragment_like( tiled_copy_t2r_scores.partition_D( cute.make_tensor(tCtScores.iterator, tCtScores.layout)), Float32) tRrOutput = cute.make_fragment_like( tiled_copy_t2r_output.partition_D( cute.make_tensor(tCtOutput.iterator, tCtOutput.layout)), Float32) n_block_max = (self._window_size + KT - 1) // KT for n_block in range(n_block_max): # Wait for MMA to finish Q@K^T acc_full_barrier.wait() # ── tcgen05.ld scores from TMEM to registers ── cute.copy(tiled_copy_t2r_scores, tCtScores, tRrScores) cute.arch.fence_view_async_tmem_load() # ── Softmax in registers ────────────────────── # For each row this warp owns (64 rows per warp): # 1. tile_max = max(scores * scale) — reduce across KT positions # 2. new_max = max(row_max_prev, tile_max) # 3. prev_exp = exp((row_max_prev - new_max) * scale * log2e) # 4. Rescale output accumulator in TMEM: O *= prev_exp # (via tcgen05.st with scaled values, or defer to PV GEMM) # 5. row_sum *= prev_exp # 6. exp_scores = exp((scores * scale - new_max * scale) * log2e) # 7. row_sum += sum(exp_scores) # 8. Write exp_scores back to TMEM (as P operand for PV GEMM) # 9. row_max = new_max # The register fragment layout after tcgen05.ld is # (EPI_TILE_M=128, EPI_TILE_N=16) partitioned per epilogue warp. # Each epilogue warp's fragment covers 64 rows and 16 columns. # We can iterate over rows and compute softmax per row. # TODO: Implement per-row softmax on the register fragment. # This requires understanding the exact tRrScores layout # from tcgen05.make_tmem_copy + partition_D. # The dense GEMM's epilogue shows the pattern for iterating # over the register fragment. # For now, write the P values (softmax output) back to TMEM # via tcgen05.st (register → TMEM copy) # tcgen05.st: tRrScores → tCtScores cute.copy(tiled_copy_t2r_scores, tRrScores, tCtScores) cute.arch.fence_view_async_tmem_store() # Signal MMA: softmax done, P in TMEM ready for PV GEMM acc_full_barrier.arrive() # Wait for MMA to finish P@V acc_full_barrier.wait() # ── Final normalize ────────────────────────────── # tcgen05.ld output from TMEM, divide by row_sum, store to GMEM cute.copy(tiled_copy_t2r_output, tCtOutput, tRrOutput) cute.arch.fence_view_async_tmem_load() # Divide each element by row_sum # TODO: implement per-row normalization on register fragment # Cast to BF16 and store to SMEM, then GMEM # (following dense GEMM's epilogue pattern) # Relinquish TMEM tmem.relinquish_alloc_permit() acc_full_barrier.arrive() tmem.free(num_tmem_cols)