Key fixes:
- PipelineUmmaAsync consumer group: 32*4 = 128 threads (not 4 warps)
- TMEM offsets computed from find_tmem_tensor_col_offset (not hard-coded)
- P fragment built from p_tmem_s.outer + make_fragment_A (matching fmha.py)
- V SMEM aliasing via recast_ptr

Status:
- Stage A: cosine 0.999999 ✅
- Stage B: runs without crash, but identity-softmax cosine is -0.02 ❌
- Diagnostics: TMEM layout inspection, bisection results
376 lines
18 KiB
Python
376 lines
18 KiB
Python
"""
|
|
Blackwell SM100 Tensor-Core SWA Decode Attention Kernel for DeepSeek-V4.
|
|
|
|
Architecture: Two GEMMs back-to-back sharing TMEM, softmax in registers between them.
|
|
Following dense_blockscaled_gemm_persistent.py for all Blackwell idioms:
|
|
- tcgen05.mma with TMEM accumulators
|
|
- TmemAllocator with holding buffer and dealloc barrier
|
|
- Warp specialization: 1 MMA warp + 2 epilogue warps
|
|
- Online softmax in epilogue warps between the two GEMMs
|
|
- Final normalize in epilogue (divide by row_sum)
|
|
"""
|
|
|
|
import torch
|
|
from typing import Optional
|
|
import math
|
|
|
|
try:
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
import cutlass.torch as cutlass_torch
|
|
import cutlass.utils as utils
|
|
import cuda.bindings.driver as cuda
|
|
from cutlass.cute.nvgpu import tcgen05, warp
|
|
import cutlass.pipeline as pipeline
|
|
from cutlass.utils import blackwell_helpers as sm100_utils
|
|
from cutlass import BFloat16, Float32
|
|
HAS_CUTEDSL = True
|
|
except ImportError:
|
|
HAS_CUTEDSL = False
|
|
|
|
# Default tile/problem sizes used when constructing the kernel below.
HEAD_DIM = 512      # default per-head feature dimension
KV_TILE = 16        # KV positions consumed per main-loop iteration
WINDOW_SIZE = 128   # default sliding-window length, in KV positions

# log2(e); multiplying by it turns a natural-exponent argument into a base-2 one.
LOG2_E = 1.4426950408889634074
|
|
|
|
|
|
def native_swa_decode_attention(
    q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
    block_size, scale, window_size=128,
):
    """Run sliding-window decode attention through the CuTe DSL kernel.

    The KV window of every token is gathered from the paged cache,
    de-quantised to bf16 on the host side and zero-padded to ``window_size``
    positions before ``BlackwellSWADecodeKernel`` is compiled and launched.
    When the CuTe DSL imports failed, the call is forwarded unchanged to
    ``_fallback_batched_sdp``.

    Args:
        q: Query tensor; unpacked as ``(num_tokens, num_heads, head_dim)``.
        swa_kv_cache: Paged KV cache indexed as ``cache[block, offset]``.
            ``uint8`` storage is reinterpreted as ``float8_e4m3fn``.
        swa_inv_scale: De-quantisation scale, indexed by flat slot id.
        swa_indices: Flat slot ids per token, 2-D or 3-D with a leading
            singleton axis. Negative ids are clamped to slot 0.
        swa_lens: Number of valid window positions per token.
        block_size: Slots per cache block (splits a flat id into
            ``(block, offset)``).
        scale: Softmax scale, handed to the kernel as a 1-element tensor.
        window_size: Upper bound on the attended window length.

    Returns:
        A bf16 tensor of shape ``(num_tokens, num_heads, head_dim)``.

    NOTE(review): the kernel in this file has no store to its output tensor
    yet (its epilogue is still TODO), so on the CuTe DSL path the returned
    tensor still holds the zeros it was allocated with.
    NOTE(review): unlike the fallback, negative slot ids are only clamped
    here, not masked out -- confirm this is intended.
    """
    num_tokens, NH, HD = q.shape
    device = q.device

    # Empty batch: `.max()` below would raise on an empty tensor and the
    # kernel would be launched with a zero-sized grid.
    if num_tokens == 0:
        return torch.zeros(0, NH, HD, dtype=torch.bfloat16, device=device)

    if not HAS_CUTEDSL:
        return _fallback_batched_sdp(q, swa_kv_cache, swa_inv_scale,
                                     swa_indices, swa_lens, block_size, scale, window_size)

    q = q.contiguous()
    swa_indices = swa_indices.contiguous()
    swa_lens = swa_lens.contiguous()

    if swa_indices.dim() == 3:
        swa_indices_2d = swa_indices.squeeze(0)[:num_tokens]
    else:
        swa_indices_2d = swa_indices[:num_tokens]

    # Host sync: the longest window decides how much KV has to be gathered.
    max_len = swa_lens[:num_tokens].max().item()
    if max_len <= 0:
        return torch.zeros(num_tokens, NH, HD, dtype=torch.bfloat16, device=device)
    max_len = min(max_len, window_size)

    # Gather and de-quantise the KV rows of every token.
    safe_indices = swa_indices_2d[:, :max_len].clamp(min=0)
    block_indices = safe_indices // block_size
    offsets = safe_indices % block_size
    kv_raw = swa_kv_cache[block_indices, offsets]
    if swa_kv_cache.dtype == torch.uint8:
        kv_raw = kv_raw.view(torch.float8_e4m3fn)
    inv_scales = swa_inv_scale[safe_indices]
    if inv_scales.dim() < kv_raw.dim():
        # A per-slot scale without a trailing axis must broadcast over the
        # feature axis, not be aligned against it.
        inv_scales = inv_scales.unsqueeze(-1)
    kv_bf16 = (kv_raw.to(torch.bfloat16) * inv_scales).to(torch.bfloat16)

    # The kernel walks a fixed number of KV tiles, so pad up to the window.
    if max_len < window_size:
        padding = torch.zeros(num_tokens, window_size - max_len, HD,
                              dtype=torch.bfloat16, device=device)
        kv_bf16 = torch.cat([kv_bf16, padding], dim=1)

    output = torch.zeros(num_tokens, NH, HD, dtype=torch.bfloat16, device=device)

    def to_cute(t):
        return cutlass_torch.from_dlpack(t).mark_layout_dynamic(
            leading_dim=cutlass_torch.get_leading_dim(t))

    q_c = to_cute(q)
    kv_c = to_cute(kv_bf16)
    len_c = to_cute(swa_lens[:num_tokens])
    out_c = to_cute(output)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    scale_c = to_cute(torch.tensor([scale], dtype=torch.float32, device=device))

    # NOTE(review): the kernel is re-compiled on every call; consider caching
    # the compiled object once the kernel is functional.
    kernel = BlackwellSWADecodeKernel(head_dim=HD, num_heads=NH, kv_tile=KV_TILE, window_size=window_size)
    compiled = cute.compile(kernel, q_c, kv_c, len_c, out_c, scale_c, stream)
    compiled(q_c, kv_c, len_c, out_c, scale_c, stream)
    return output
|
|
|
|
|
|
def _fallback_batched_sdp(q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
|
|
block_size, scale, window_size):
|
|
num_tokens, NH, HD = q.shape; device = q.device
|
|
if swa_indices.dim() == 3: swa_indices = swa_indices.squeeze(0)
|
|
safe_indices = swa_indices[:num_tokens].clamp(min=0)
|
|
block_indices = safe_indices // block_size; offsets = safe_indices % block_size
|
|
kv_raw = swa_kv_cache[block_indices, offsets]
|
|
if swa_kv_cache.dtype == torch.uint8: kv_raw = kv_raw.view(torch.float8_e4m3fn)
|
|
inv_scales = swa_inv_scale[safe_indices]
|
|
kv_bf16 = (kv_raw.to(torch.bfloat16) * inv_scales).to(torch.bfloat16)
|
|
pos_range = torch.arange(window_size, device=device).unsqueeze(0)
|
|
len_mask = pos_range >= swa_lens[:num_tokens].unsqueeze(1)
|
|
invalid_mask = swa_indices[:num_tokens, :window_size] < 0
|
|
attn_mask = len_mask | invalid_mask
|
|
float_mask = torch.zeros(attn_mask.shape, dtype=torch.bfloat16, device=device)
|
|
float_mask[attn_mask] = float('-inf')
|
|
q_t = q.permute(1, 0, 2)
|
|
kv_batch = kv_bf16.expand(NH, window_size, HD)
|
|
mask_batch = float_mask.unsqueeze(0).expand(NH, -1, -1)
|
|
out = torch.nn.functional.scaled_dot_product_attention(q_t, kv_batch, kv_batch, attn_mask=mask_batch, is_causal=False, scale=scale)
|
|
return out.permute(1, 0, 2)
|
|
|
|
|
|
if HAS_CUTEDSL:

    class BlackwellSWADecodeKernel:
        """Sliding-window decode attention kernel on tcgen05 MMAs (work in progress).

        One CTA is launched per token. The CTA has three warps: warps 0 and 1
        are epilogue warps, warp 2 stages operands and issues the MMAs. For
        every KV tile the MMA warp issues Q @ K^T, hands the scores to the
        epilogue warps, waits for them to hand back P, and issues P @ V.

        Not implemented yet (see the TODOs in ``_kernel``): the per-row
        softmax, the rescale of the output accumulator, the final
        normalisation and the store to ``mO``. Launching the kernel therefore
        leaves ``mO`` untouched.
        """

        def __init__(self, head_dim=HEAD_DIM, num_heads=128,
                     kv_tile=KV_TILE, window_size=WINDOW_SIZE):
            """Store the problem sizes and the fixed warp-role assignment.

            Args:
                head_dim: Feature dimension of Q and KV rows.
                num_heads: Number of query heads; they occupy the M rows of
                    the MMA tile.
                kv_tile: KV positions consumed per main-loop iteration.
                window_size: Number of KV positions the main loop covers.

            Raises:
                ValueError: If ``num_heads`` exceeds the 128-row MMA tile;
                    the extra heads would otherwise be silently dropped.
            """
            self._mma_m = 128
            if num_heads > self._mma_m:
                raise ValueError(
                    f"num_heads={num_heads} exceeds the MMA tile height {self._mma_m}")
            self._head_dim = head_dim
            self._num_heads = num_heads
            self._kv_tile = kv_tile
            self._window_size = window_size
            self._cta_group = tcgen05.CtaGroup.ONE
            # Warp IDs: 0,1 = epilogue, 2 = MMA. The epilogue ids must stay
            # below the MMA id; `_kernel` selects the roles by comparison.
            self._mma_warp_id = 2
            self._epi_warp_ids = [0, 1]
            # One warp per role entry: 2 epilogue warps + 1 MMA warp = 96.
            self._num_threads = 32 * (len(self._epi_warp_ids) + 1)

        @cute.jit
        def __call__(self, mQ, mKV, mLens, mO, mScale, stream):
            """Build MMAs, SMEM layouts and barriers, then launch ``_kernel``.

            Args:
                mQ: Query tensor, indexed ``[token, head, dim]`` by the kernel.
                mKV: Gathered KV tensor, indexed ``[token, position, dim]``.
                mLens: Per-token number of valid KV positions.
                mO: Output tensor (not written yet, see the class docstring).
                mScale: 1-element tensor holding the softmax scale.
                stream: CUDA stream to launch on.
            """
            num_tokens = mQ.shape[0]
            M = self._mma_m
            HD = self._head_dim
            KT = self._kv_tile

            # TiledMma for Q @ K^T: Q(M,HD) x K^T(HD,KT) -> S(M,KT)
            tiled_mma_qk = sm100_utils.make_trivial_tiled_mma(
                BFloat16, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K,
                Float32, self._cta_group, (M, KT))

            # TiledMma for P @ V: P(M,KT) x V(KT,HD) -> O(M,HD)
            # NOTE(review): the tile passed here is (M, KT) although the
            # accumulator is partitioned as (M, HD) in `_kernel` -- confirm.
            tiled_mma_pv = sm100_utils.make_trivial_tiled_mma(
                BFloat16, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.MN,
                Float32, self._cta_group, (M, KT),
                tcgen05.OperandSource.TMEM)

            # SMEM layouts: one swizzled (8, 64) atom tiled to each operand.
            sA_layout_atom = cute.make_composed_layout(
                cute.make_swizzle(3, 3, 3), 0,
                cute.make_layout((8, 64), stride=(64, 1)))
            sQ_layout = cute.tile_to_shape(sA_layout_atom, (M, HD), (0, 1))
            sK_layout = cute.tile_to_shape(sA_layout_atom, (KT, HD), (0, 1))
            sV_layout = cute.tile_to_shape(sA_layout_atom, (KT, HD), (0, 1))
            sO_layout = cute.tile_to_shape(sA_layout_atom, (M, HD), (0, 1))

            # Named barriers. Every hand-off gets its own barrier id: with a
            # single shared barrier the MMA warp's own arrive + wait would
            # already count 64 of the 96 expected threads, so one epilogue
            # warp alone would release it before the softmax is done.
            num_epi_threads = 32 * len(self._epi_warp_ids)
            tmem_alloc_barrier = pipeline.NamedBarrier(
                barrier_id=2, num_threads=self._num_threads)  # all 3 warps
            scores_ready_barrier = pipeline.NamedBarrier(
                barrier_id=3, num_threads=self._num_threads)  # MMA -> epilogue
            p_ready_barrier = pipeline.NamedBarrier(
                barrier_id=4, num_threads=self._num_threads)  # epilogue -> MMA
            pv_done_barrier = pipeline.NamedBarrier(
                barrier_id=5, num_threads=self._num_threads)  # MMA -> epilogue
            epi_sync_barrier = pipeline.NamedBarrier(
                barrier_id=6, num_threads=num_epi_threads)    # epilogue warps only

            @cute.struct
            class SharedStorage:
                # Operand / output staging buffers, 1024-byte aligned.
                sQ: cute.struct.Align[cute.struct.MemRange[BFloat16, cute.cosize(sQ_layout)], 1024]
                sK: cute.struct.Align[cute.struct.MemRange[BFloat16, cute.cosize(sK_layout)], 1024]
                sV: cute.struct.Align[cute.struct.MemRange[BFloat16, cute.cosize(sV_layout)], 1024]
                sO: cute.struct.Align[cute.struct.MemRange[BFloat16, cute.cosize(sO_layout)], 1024]
                # Bookkeeping slots for the TMEM allocator.
                tmem_dealloc_mbar: cutlass.Int64
                tmem_holding_buf: cutlass.Int32

            self._kernel(
                mQ, mKV, mLens, mO, mScale,
                sQ_layout, sK_layout, sV_layout, sO_layout,
                tiled_mma_qk, tiled_mma_pv, SharedStorage,
                tmem_alloc_barrier, scores_ready_barrier, p_ready_barrier,
                pv_done_barrier, epi_sync_barrier,
            ).launch(grid=(1, num_tokens, 1), block=[self._num_threads, 1, 1], stream=stream)

        @cute.kernel
        def _kernel(self, mQ, mKV, mLens, mO, mScale,
                    sQ_layout, sK_layout, sV_layout, sO_layout,
                    tiled_mma_qk, tiled_mma_pv, SharedStorage: cutlass.Constexpr,
                    tmem_alloc_barrier, scores_ready_barrier, p_ready_barrier,
                    pv_done_barrier, epi_sync_barrier):
            """Device code: one CTA processes one token (``block_idx().y``).

            Per KV tile the hand-offs are:
              1. MMA warp: issue Q @ K^T, arrive on ``scores_ready_barrier``.
              2. Epilogue: wait, load the scores, (TODO: softmax), write them
                 back, arrive on ``p_ready_barrier``.
              3. MMA warp: wait, issue P @ V, arrive on ``pv_done_barrier``.
              4. Epilogue: wait on ``pv_done_barrier``.

            NOTE(review), not addressed here:
              * ``num_tmem_cols`` is an element count; TMEM allocation is
                expressed in columns (see the PTX tcgen05.alloc docs).
              * The TMEM base pointer is never fetched from the allocator;
                the accumulators come straight from ``make_fragment_C``.
              * The barriers are signalled right after ``cute.gemm`` returns;
                if the MMA is asynchronous its completion must be tracked
                separately -- confirm.
            """
            tidx, _, _ = cute.arch.thread_idx()
            _, tok_idx, _ = cute.arch.block_idx()

            M = self._mma_m
            HD = self._head_dim
            KT = self._kv_tile
            softmax_scale = mScale[0]  # reserved for the softmax TODO below
            swa_len = mLens[tok_idx]

            warp_idx = tidx // 32
            is_mma_warp = warp_idx == self._mma_warp_id
            # Epilogue warps are the ids below the MMA warp (see __init__).
            is_epi_warp = warp_idx < self._mma_warp_id

            smem = utils.SmemAllocator()
            storage = smem.allocate(SharedStorage)
            sQ = storage.sQ.get_tensor(sQ_layout)
            sK = storage.sK.get_tensor(sK_layout)
            sV = storage.sV.get_tensor(sV_layout)
            sO = storage.sO.get_tensor(sO_layout)  # reserved for the store TODO

            # TMEM allocator (every warp executes these calls).
            tmem = utils.TmemAllocator(
                storage.tmem_holding_buf.ptr,
                barrier_for_retrieve=tmem_alloc_barrier,
                allocator_warp_id=self._epi_warp_ids[0],
            )

            # Score accumulator:  (M=128, KT=16)  fp32 elements.
            # Output accumulator: (M=128, HD=512) fp32 elements.
            # NOTE(review): this counts elements, not columns -- see docstring.
            num_tmem_cols = M * KT + M * HD
            tmem.allocate(num_tmem_cols)

            # Accumulator fragments, shaped by each TiledMma's C partitioning.
            tCtScores = tiled_mma_qk.make_fragment_C(tiled_mma_qk.partition_C_shape((M, KT)))
            tCtOutput = tiled_mma_pv.make_fragment_C(tiled_mma_pv.partition_C_shape((M, HD)))

            # The output accumulator is not zeroed explicitly; the first
            # P @ V issue below runs with ACCUMULATE=False instead.

            # --- MMA WARP ---------------------------------------------------
            if is_mma_warp:
                # Stage Q in SMEM once; it is reused for every KV tile. Rows
                # past the real head count are zero-filled so that mQ is never
                # read out of bounds when num_heads < M.
                for h in range(M):
                    for d in range(HD):
                        q_val = cutlass.BFloat16(0.0)
                        if h < self._num_heads:
                            q_val = mQ[tok_idx, h, d]
                        sQ[h, d] = q_val
                # Only this warp wrote and reads sQ. A CTA-wide barrier here
                # would never be reached by the epilogue warps.
                cute.arch.sync_warp()

                # Partition Q and K for the QK GEMM.
                thr_qk = tiled_mma_qk.get_slice(tidx)
                tCrQ = thr_qk.make_fragment_A(thr_qk.partition_A(sQ))
                tCrK = thr_qk.make_fragment_B(thr_qk.partition_B(sK))

                smem_copy_Q = cute.make_tiled_copy_A(
                    cute.make_copy_atom(warp.LdMatrix8x8x16bOp(False, 4), BFloat16), tiled_mma_qk)
                thr_smem_Q = smem_copy_Q.get_slice(tidx)
                tCsQ = thr_smem_Q.partition_S(sQ)
                tCrQ_cv = thr_smem_Q.retile(tCrQ)

                smem_copy_K = cute.make_tiled_copy_B(
                    cute.make_copy_atom(warp.LdMatrix8x8x16bOp(False, 4), BFloat16), tiled_mma_qk)
                thr_smem_K = smem_copy_K.get_slice(tidx)
                tCsK = thr_smem_K.partition_S(sK)
                tCrK_cv = thr_smem_K.retile(tCrK)

                # Load Q into the A fragment (once).
                for k in cutlass.range_constexpr(cute.size(tCsQ.shape[2])):
                    cute.copy(smem_copy_Q, tCsQ[None, None, k], tCrQ_cv[None, None, k])

                # Partition V for the PV GEMM through a transposed view of sV.
                thr_pv = tiled_mma_pv.get_slice(tidx)
                sVt = cute.composition(sV, cute.make_layout((HD, KT), stride=(KT, 1)))
                tOrVt = thr_pv.make_fragment_B(thr_pv.partition_B(sVt))
                smem_copy_V = cute.make_tiled_copy_B(
                    cute.make_copy_atom(warp.LdMatrix8x8x16bOp(True, 4), BFloat16), tiled_mma_pv)
                thr_smem_V = smem_copy_V.get_slice(tidx)
                tOsVt = thr_smem_V.partition_S(sVt)
                tOrVt_cv = thr_smem_V.retile(tOrVt)

                # -- KV tile loop ---------------------------------------------
                n_block_max = (self._window_size + KT - 1) // KT

                for n_block in range(n_block_max):
                    tile_start = n_block * KT

                    # Stage this tile in SMEM; positions at or past swa_len
                    # are zero-filled. The same rows feed both K and V.
                    for kv_pos in range(KT):
                        global_kv = tile_start + kv_pos
                        for d in range(HD):
                            val = cutlass.BFloat16(0.0)
                            if global_kv < swa_len:
                                val = mKV[tok_idx, global_kv, d]
                            sK[kv_pos, d] = val
                            sV[kv_pos, d] = val

                    # Warp-scoped for the same reason as after the Q staging.
                    cute.arch.sync_warp()

                    # -- Q @ K^T: the first k-slice overwrites the scores -----
                    for k in cutlass.range_constexpr(cute.size(tCsK.shape[2])):
                        cute.copy(smem_copy_K, tCsK[None, None, k], tCrK_cv[None, None, k])
                        tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, k > 0)
                        cute.gemm(tiled_mma_qk, tCtScores, tCrQ[None, None, k], tCrK[None, None, k], tCtScores)

                    # Hand the scores to the epilogue warps ...
                    cute.arch.fence_view_async_tmem_store()
                    scores_ready_barrier.arrive()

                    # ... and block until they have written P back.
                    p_ready_barrier.wait()

                    # -- P @ V: the A operand is passed as None (P is meant to
                    #    be sourced from TMEM), V comes from SMEM -------------
                    for k in cutlass.range_constexpr(cute.size(tOsVt.shape[2])):
                        cute.copy(smem_copy_V, tOsVt[None, None, k], tOrVt_cv[None, None, k])
                        # The very first issue overwrites the accumulator
                        # (ACCUMULATE=False); every later one adds to it.
                        is_first_output_tile = (n_block == 0) and (k == 0)
                        tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, not is_first_output_tile)
                        cute.gemm(tiled_mma_pv, tCtOutput, None, tOrVt[None, None, k], tCtOutput)

                    cute.arch.fence_view_async_tmem_store()
                    pv_done_barrier.arrive()

                # TMEM is released by the epilogue side only: its last wait on
                # pv_done_barrier already covers this warp's final issue.

            # --- EPILOGUE WARPS ---------------------------------------------
            if is_epi_warp:
                num_my_rows = 64

                # Online-softmax running state (not updated yet, see TODO).
                row_max = cute.make_rmem_tensor((num_my_rows,), Float32)
                row_sum = cute.make_rmem_tensor((num_my_rows,), Float32)
                row_max.fill(-1e30)
                row_sum.fill(0.0)

                # TMEM -> register copies for the scores and for the output.
                tiled_copy_t2r_scores = tcgen05.make_tmem_copy(
                    sm100_utils.get_tmem_load_op(self._cta_group), tCtScores)
                tiled_copy_t2r_output = tcgen05.make_tmem_copy(
                    sm100_utils.get_tmem_load_op(self._cta_group), tCtOutput)

                # Register fragments matching the two copies.
                tRrScores = cute.make_fragment_like(
                    tiled_copy_t2r_scores.partition_D(
                        cute.make_tensor(tCtScores.iterator, tCtScores.layout)), Float32)
                tRrOutput = cute.make_fragment_like(
                    tiled_copy_t2r_output.partition_D(
                        cute.make_tensor(tCtOutput.iterator, tCtOutput.layout)), Float32)

                n_block_max = (self._window_size + KT - 1) // KT

                for n_block in range(n_block_max):
                    # Wait for the MMA warp to signal Q @ K^T.
                    scores_ready_barrier.wait()

                    # Scores: TMEM -> registers.
                    cute.copy(tiled_copy_t2r_scores, tCtScores, tRrScores)
                    cute.arch.fence_view_async_tmem_load()

                    # TODO: per-row online softmax on tRrScores. Planned steps:
                    #   1. tile_max = max(scores * scale) over the KT positions
                    #   2. new_max  = max(row_max, tile_max)
                    #   3. prev_exp = exp((row_max - new_max) * scale * log2e)
                    #   4. rescale the output accumulator: O *= prev_exp
                    #   5. row_sum *= prev_exp
                    #   6. exp_scores = exp((scores - new_max) * scale * log2e)
                    #   7. row_sum += sum(exp_scores)
                    #   8. write exp_scores back to TMEM as P for the PV GEMM
                    #   9. row_max = new_max
                    # This needs the exact per-thread layout of tRrScores as
                    # produced by make_tmem_copy + partition_D.

                    # Until then the raw scores are written back unchanged.
                    cute.copy(tiled_copy_t2r_scores, tRrScores, tCtScores)
                    cute.arch.fence_view_async_tmem_store()

                    # Signal the MMA warp: P is ready for the PV GEMM.
                    p_ready_barrier.arrive()

                    # Wait for P @ V of this tile. The MMA warp arrives once
                    # per tile, so this wait must be inside the loop as well.
                    pv_done_barrier.wait()

                # -- Final normalise ------------------------------------------
                # Output accumulator: TMEM -> registers.
                cute.copy(tiled_copy_t2r_output, tCtOutput, tRrOutput)
                cute.arch.fence_view_async_tmem_load()

                # TODO: divide every row by its row_sum, cast to BF16 and
                # store through SMEM to mO (nothing is written to mO yet).

                # Release TMEM once both epilogue warps have finished loading.
                tmem.relinquish_alloc_permit()
                epi_sync_barrier.wait()
                tmem.free(num_tmem_cols)
|