Files
nvfp4-megamoe-kernel/cutedsl/native_swa_decode.py
biondizzle 97656a5cd1 Stage B: two MMAs + identity softmax — crash fixed, softmax output still wrong
Key fixes:
- PipelineUmmaAsync consumer group: 32*4=128 threads (not 4 warps)
- TMEM offsets computed from find_tmem_tensor_col_offset (not hardcoded)
- P fragment from p_tmem_s.outer + make_fragment_A (matching fmha.py)
- V SMEM aliasing via recast_ptr

Status:
- Stage A: cosine 0.999999 
- Stage B: runs without crash, identity softmax cosine -0.02 
- Diagnostics: TMEM layout inspection, bisection results
2026-05-20 20:26:25 +00:00

376 lines
18 KiB
Python

"""
Blackwell SM100 Tensor-Core SWA Decode Attention Kernel for DeepSeek-V4.
Architecture: Two GEMMs back-to-back sharing TMEM, softmax in registers between them.
Following dense_blockscaled_gemm_persistent.py for all Blackwell idioms:
- tcgen05.mma with TMEM accumulators
- TmemAllocator with holding buffer and dealloc barrier
- Warp specialization: 1 MMA warp + 2 epilogue warps
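- Online softmax in epilogue warps between the two GEMMs (planned: the epilogue
  currently copies the raw scores TMEM -> registers -> TMEM unchanged)
- Final normalize in epilogue (divide by row_sum) (planned: the division and the
  store of the result to the output tensor are still TODO)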
"""
import torch
from typing import Optional
import math
try:
import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
import cuda.bindings.driver as cuda
from cutlass.cute.nvgpu import tcgen05, warp
import cutlass.pipeline as pipeline
from cutlass.utils import blackwell_helpers as sm100_utils
from cutlass import BFloat16, Float32
HAS_CUTEDSL = True
except ImportError:
HAS_CUTEDSL = False
# Default head dimension used by BlackwellSWADecodeKernel when none is given.
HEAD_DIM = 512
# Number of KV positions processed per tile; passed to the kernel as kv_tile.
KV_TILE = 16
# Default sliding-window length (in KV positions) for the kernel class.
WINDOW_SIZE = 128
# log2(e); not referenced by the code visible in this file section.
LOG2_E = 1.4426950408889634074
def native_swa_decode_attention(
    q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
    block_size, scale, window_size=128,
):
    """Run sliding-window decode attention for a batch of single-token queries.

    The KV entries named by ``swa_indices`` are gathered from the paged cache,
    dequantised to bfloat16, zero-padded to ``window_size`` positions and handed
    to ``BlackwellSWADecodeKernel``.  When the CuTe DSL could not be imported
    the call is forwarded to ``_fallback_batched_sdp`` instead.

    Args:
        q: Query tensor whose shape unpacks as ``(num_tokens, NH, HD)``.
        swa_kv_cache: Paged KV cache indexed as ``[block, offset]``; a ``uint8``
            cache is reinterpreted as ``float8_e4m3fn``.
        swa_inv_scale: Dequantisation scales indexed by flat slot index; the
            result must broadcast against the gathered KV rows.
        swa_indices: Flat slot indices per token, 2-D or 3-D with a leading
            axis that is squeezed away.  Negative entries are clamped to 0.
        swa_lens: Number of valid window positions per token.
        block_size: Slots per cache block (used to split a flat slot index).
        scale: Softmax scale forwarded to the kernel.
        window_size: Number of KV positions the kernel iterates over.

    Returns:
        A bfloat16 tensor of shape ``(num_tokens, NH, HD)`` on ``q.device`` (the
        fallback path returns whatever ``_fallback_batched_sdp`` produces).
    """
    num_tokens, NH, HD = q.shape
    device = q.device

    # An empty batch has nothing to attend over; bail out before the
    # ``.max()`` reduction below, which raises on a zero-element tensor.
    if num_tokens == 0:
        return torch.zeros(0, NH, HD, dtype=torch.bfloat16, device=device)

    if not HAS_CUTEDSL:
        return _fallback_batched_sdp(q, swa_kv_cache, swa_inv_scale,
                                     swa_indices, swa_lens, block_size, scale, window_size)

    q = q.contiguous()
    swa_indices = swa_indices.contiguous()
    swa_lens = swa_lens.contiguous()

    if swa_indices.dim() == 3:
        swa_indices_2d = swa_indices.squeeze(0)[:num_tokens]
    else:
        swa_indices_2d = swa_indices[:num_tokens]

    # Host sync: the longest window decides how many slots need gathering.
    max_len = swa_lens[:num_tokens].max().item()
    if max_len <= 0:
        return torch.zeros(num_tokens, NH, HD, dtype=torch.bfloat16, device=device)
    max_len = min(max_len, window_size)

    # NOTE(review): negative (invalid) indices are clamped to slot 0 and are not
    # masked on this path, unlike in _fallback_batched_sdp -- confirm intent.
    safe_indices = swa_indices_2d[:, :max_len].clamp(min=0)
    block_indices = safe_indices // block_size
    offsets = safe_indices % block_size

    kv_raw = swa_kv_cache[block_indices, offsets]
    if swa_kv_cache.dtype == torch.uint8:
        kv_raw = kv_raw.view(torch.float8_e4m3fn)
    inv_scales = swa_inv_scale[safe_indices]
    kv_bf16 = (kv_raw.to(torch.bfloat16) * inv_scales).to(torch.bfloat16)

    # Pad to the full window using the width that was actually gathered:
    # swa_indices may have fewer than max_len columns, and the kernel indexes
    # every position below window_size.
    pad = window_size - kv_bf16.shape[1]
    if pad > 0:
        filler = torch.zeros(num_tokens, pad, kv_bf16.shape[2],
                             dtype=torch.bfloat16, device=device)
        kv_bf16 = torch.cat([kv_bf16, filler], dim=1)

    output = torch.zeros(num_tokens, NH, HD, dtype=torch.bfloat16, device=device)

    def to_cute(t):
        return cutlass_torch.from_dlpack(t).mark_layout_dynamic(
            leading_dim=cutlass_torch.get_leading_dim(t))

    q_c = to_cute(q)
    kv_c = to_cute(kv_bf16)
    len_c = to_cute(swa_lens[:num_tokens])
    out_c = to_cute(output)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    scale_c = to_cute(torch.tensor([scale], dtype=torch.float32, device=device))

    kernel = BlackwellSWADecodeKernel(head_dim=HD, num_heads=NH, kv_tile=KV_TILE, window_size=window_size)
    # NOTE(review): cute.compile runs on every call; consider caching the
    # compiled kernel once its reuse across shapes has been confirmed.
    compiled = cute.compile(kernel, q_c, kv_c, len_c, out_c, scale_c, stream)
    compiled(q_c, kv_c, len_c, out_c, scale_c, stream)
    return output
def _fallback_batched_sdp(q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
block_size, scale, window_size):
num_tokens, NH, HD = q.shape; device = q.device
if swa_indices.dim() == 3: swa_indices = swa_indices.squeeze(0)
safe_indices = swa_indices[:num_tokens].clamp(min=0)
block_indices = safe_indices // block_size; offsets = safe_indices % block_size
kv_raw = swa_kv_cache[block_indices, offsets]
if swa_kv_cache.dtype == torch.uint8: kv_raw = kv_raw.view(torch.float8_e4m3fn)
inv_scales = swa_inv_scale[safe_indices]
kv_bf16 = (kv_raw.to(torch.bfloat16) * inv_scales).to(torch.bfloat16)
pos_range = torch.arange(window_size, device=device).unsqueeze(0)
len_mask = pos_range >= swa_lens[:num_tokens].unsqueeze(1)
invalid_mask = swa_indices[:num_tokens, :window_size] < 0
attn_mask = len_mask | invalid_mask
float_mask = torch.zeros(attn_mask.shape, dtype=torch.bfloat16, device=device)
float_mask[attn_mask] = float('-inf')
q_t = q.permute(1, 0, 2)
kv_batch = kv_bf16.expand(NH, window_size, HD)
mask_batch = float_mask.unsqueeze(0).expand(NH, -1, -1)
out = torch.nn.functional.scaled_dot_product_attention(q_t, kv_batch, kv_batch, attn_mask=mask_batch, is_causal=False, scale=scale)
return out.permute(1, 0, 2)
if HAS_CUTEDSL:
    class BlackwellSWADecodeKernel:
        """CuTe DSL kernel object for sliding-window decode attention (work in progress).

        ``__call__`` builds the tiled MMAs, shared-memory layouts and named
        barriers and launches ``_kernel`` on a ``(1, num_tokens, 1)`` grid.
        ``_kernel`` splits the 96-thread block into two epilogue warps
        (warp ids 0 and 1) and one MMA warp (warp id 2).

        The per-row softmax, the final normalisation and the store of the
        result are still TODO inside ``_kernel``.
        """

        def __init__(self, head_dim=HEAD_DIM, num_heads=128,
                     kv_tile=KV_TILE, window_size=WINDOW_SIZE):
            """Store the problem shape and the fixed launch configuration.

            Args:
                head_dim: Head dimension (``HD`` in the kernel).
                num_heads: Number of heads; stored, but not read elsewhere in
                    the code visible here.
                kv_tile: KV positions processed per loop iteration (``KT``).
                window_size: Total KV positions the kernel loops over.
            """
            self._head_dim = head_dim
            self._num_heads = num_heads
            self._kv_tile = kv_tile
            self._window_size = window_size
            # Row count of the MMA tile (``M``); Q is staged as an (M, HD) tile.
            self._mma_m = 128
            # Threads per block: 3 warps of 32 threads.
            self._num_threads = 96
            self._cta_group = tcgen05.CtaGroup.ONE
            # Warp IDs: 0,1 = epilogue, 2 = MMA
            self._mma_warp_id = 2
            self._epi_warp_ids = [0, 1]

        @cute.jit
        def __call__(self, mQ, mKV, mLens, mO, mScale, stream):
            """Build MMAs, layouts and barriers, then launch ``_kernel``.

            Args:
                mQ: Query tensor; its first extent is used as the token count.
                mKV: KV tensor, read in the kernel as ``mKV[token, position, d]``.
                mLens: Per-token valid length, read as ``mLens[token]``.
                mO: Output tensor (forwarded to the kernel).
                mScale: Tensor whose element 0 is the softmax scale.
                stream: Stream the kernel is launched on.
            """
            num_tokens = mQ.shape[0]
            M = self._mma_m; HD = self._head_dim; KT = self._kv_tile
            # TiledMma for Q @ K^T: Q(M,HD) x K^T(HD,KT) → S(M,KT)
            tiled_mma_qk = sm100_utils.make_trivial_tiled_mma(
                BFloat16, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K,
                Float32, self._cta_group, (M, KT))
            # TiledMma for P @ V: P(M,KT) x V(KT,HD) → O(M,HD)
            # NOTE(review): the tile passed here is (M, KT), the same as for the
            # QK MMA, although the comment above describes an (M, HD) output —
            # confirm which extent this argument is meant to carry.
            tiled_mma_pv = sm100_utils.make_trivial_tiled_mma(
                BFloat16, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.MN,
                Float32, self._cta_group, (M, KT),
                tcgen05.OperandSource.TMEM)
            # SMEM layouts: one swizzled (8, 64) atom tiled up to each tensor's shape.
            sA_layout_atom = cute.make_composed_layout(
                cute.make_swizzle(3, 3, 3), 0,
                cute.make_layout((8, 64), stride=(64, 1)))
            sQ_layout = cute.tile_to_shape(sA_layout_atom, (M, HD), (0, 1))
            sK_layout = cute.tile_to_shape(sA_layout_atom, (KT, HD), (0, 1))
            sV_layout = cute.tile_to_shape(sA_layout_atom, (KT, HD), (0, 1))
            sO_layout = cute.tile_to_shape(sA_layout_atom, (M, HD), (0, 1))
            # Named barriers for TMEM allocation and MMA↔epilogue sync.
            # Both are sized for 96 threads, matching self._num_threads.
            tmem_alloc_barrier = pipeline.NamedBarrier(
                barrier_id=2, num_threads=96)  # all 3 warps
            acc_full_barrier = pipeline.NamedBarrier(
                barrier_id=3, num_threads=96)  # all 3 warps

            # Shared-memory footprint: the four staged tiles plus the two
            # words handed to / reserved for the TMEM allocator.
            @cute.struct
            class SharedStorage:
                sQ: cute.struct.Align[cute.struct.MemRange[BFloat16, cute.cosize(sQ_layout)], 1024]
                sK: cute.struct.Align[cute.struct.MemRange[BFloat16, cute.cosize(sK_layout)], 1024]
                sV: cute.struct.Align[cute.struct.MemRange[BFloat16, cute.cosize(sV_layout)], 1024]
                sO: cute.struct.Align[cute.struct.MemRange[BFloat16, cute.cosize(sO_layout)], 1024]
                tmem_dealloc_mbar: cutlass.Int64
                tmem_holding_buf: cutlass.Int32

            # One block per token along grid.y; 96 threads per block.
            self._kernel(
                mQ, mKV, mLens, mO, mScale,
                sQ_layout, sK_layout, sV_layout, sO_layout,
                tiled_mma_qk, tiled_mma_pv, SharedStorage,
                tmem_alloc_barrier, acc_full_barrier,
            ).launch(grid=(1, num_tokens, 1), block=[self._num_threads, 1, 1], stream=stream)

        @cute.kernel
        def _kernel(self, mQ, mKV, mLens, mO, mScale,
                    sQ_layout, sK_layout, sV_layout, sO_layout,
                    tiled_mma_qk, tiled_mma_pv, SharedStorage: cutlass.Constexpr,
                    tmem_alloc_barrier, acc_full_barrier):
            """Device entry point: one block handles the token at ``block_idx().y``.

            The MMA warp stages Q/K/V in shared memory and issues the two GEMMs
            per KV tile; the epilogue warps read the scores between the GEMMs
            and currently write them back unchanged.  The two roles hand off
            through ``acc_full_barrier``.

            NOTE(review): the nesting below was reconstructed from the comments
            and the barrier hand-offs; confirm against the original file.
            """
            tidx, _, _ = cute.arch.thread_idx()
            _, tok_idx, _ = cute.arch.block_idx()
            M = self._mma_m; HD = self._head_dim; KT = self._kv_tile
            # Loaded here but not read again in this function yet (softmax is TODO).
            softmax_scale = mScale[0]
            swa_len = mLens[tok_idx]
            warp_idx = tidx // 32
            is_mma_warp = warp_idx == self._mma_warp_id
            is_epi_warp = warp_idx in self._epi_warp_ids
            smem = utils.SmemAllocator()
            storage = smem.allocate(SharedStorage)
            sQ = storage.sQ.get_tensor(sQ_layout)
            sK = storage.sK.get_tensor(sK_layout)
            sV = storage.sV.get_tensor(sV_layout)
            # sO is created but not used below (the output store is TODO).
            sO = storage.sO.get_tensor(sO_layout)
            # TMEM allocator (all warps participate in alloc/dealloc)
            tmem = utils.TmemAllocator(
                storage.tmem_holding_buf.ptr,
                barrier_for_retrieve=tmem_alloc_barrier,
                allocator_warp_id=self._epi_warp_ids[0],
            )
            # Allocate TMEM for score and output accumulators
            # Score: (M=128, KT=16) = 2048 FP32
            # Output: (M=128, HD=512) = 65536 FP32
            # Total: 67584 FP32 = 264 KB TMEM
            # NOTE(review): the value below is an element count (M*KT + M*HD);
            # confirm the unit TmemAllocator.allocate expects and that the
            # total fits the device's TMEM capacity (an earlier comment assumed
            # a 1 MB budget).
            num_tmem_cols = M * KT + M * HD  # TMEM columns
            tmem.allocate(num_tmem_cols)
            # TMEM layout for scores and output
            # The TMEM layout comes from the TiledMma's C operand layout
            tCtScores = tiled_mma_qk.make_fragment_C(tiled_mma_qk.partition_C_shape((M, KT)))
            tCtOutput = tiled_mma_pv.make_fragment_C(tiled_mma_pv.partition_C_shape((M, HD)))
            # Zero the output accumulator (MMA with ACCUMULATE=False does this for scores)
            # For the output accumulator, we zero it before the KV loop
            # TODO: zero tCtOutput via tcgen05.st or first PV GEMM with ACCUMULATE=False
            # ─── MMA WARP ────────────────────────────────────────
            if is_mma_warp:
                # Load Q to SMEM (once, reused across all KV tiles).
                # The loop bounds do not depend on tidx, so every thread of
                # this warp writes the whole tile.
                # NOTE(review): indexes heads 0..M-1 of mQ; confirm mQ always
                # has at least M heads.
                for h in range(M):
                    for d in range(HD):
                        sQ[h, d] = mQ[tok_idx, h, d]
                # NOTE(review): this sync_threads() sits inside the MMA-warp
                # branch, so the epilogue warps never reach it; if it is a
                # block-wide barrier this may hang — confirm.
                cute.arch.sync_threads()
                # Partition Q and K for QK GEMM
                thr_qk = tiled_mma_qk.get_slice(tidx)
                tCrQ = thr_qk.make_fragment_A(thr_qk.partition_A(sQ))
                tCrK = thr_qk.make_fragment_B(thr_qk.partition_B(sK))
                smem_copy_Q = cute.make_tiled_copy_A(
                    cute.make_copy_atom(warp.LdMatrix8x8x16bOp(False, 4), BFloat16), tiled_mma_qk)
                thr_smem_Q = smem_copy_Q.get_slice(tidx)
                tCsQ = thr_smem_Q.partition_S(sQ); tCrQ_cv = thr_smem_Q.retile(tCrQ)
                smem_copy_K = cute.make_tiled_copy_B(
                    cute.make_copy_atom(warp.LdMatrix8x8x16bOp(False, 4), BFloat16), tiled_mma_qk)
                thr_smem_K = smem_copy_K.get_slice(tidx)
                tCsK = thr_smem_K.partition_S(sK); tCrK_cv = thr_smem_K.retile(tCrK)
                # Load Q to registers (once)
                for k in cutlass.range_constexpr(cute.size(tCsQ.shape[2])):
                    cute.copy(smem_copy_Q, tCsQ[None, None, k], tCrQ_cv[None, None, k])
                # Partition V for PV GEMM; sVt is sV composed with an
                # (HD, KT) layout of stride (KT, 1).
                thr_pv = tiled_mma_pv.get_slice(tidx)
                sVt = cute.composition(sV, cute.make_layout((HD, KT), stride=(KT, 1)))
                tOrVt = thr_pv.make_fragment_B(thr_pv.partition_B(sVt))
                smem_copy_V = cute.make_tiled_copy_B(
                    cute.make_copy_atom(warp.LdMatrix8x8x16bOp(True, 4), BFloat16), tiled_mma_pv)
                thr_smem_V = smem_copy_V.get_slice(tidx)
                tOsVt = thr_smem_V.partition_S(sVt); tOrVt_cv = thr_smem_V.retile(tOrVt)
                # ── KV tile loop ─────────────────────────────────
                # Always covers the full window; positions at or past swa_len
                # are filled with zeros below rather than skipped.
                n_block_max = (self._window_size + KT - 1) // KT
                for n_block in range(n_block_max):
                    tile_start = n_block * KT
                    # Load K and V to SMEM (the same value feeds both tiles).
                    for kv_pos in range(KT):
                        global_kv = tile_start + kv_pos
                        for d in range(HD):
                            val = cutlass.BFloat16(0.0)
                            if global_kv < swa_len:
                                val = mKV[tok_idx, global_kv, d]
                            sK[kv_pos, d] = val
                            sV[kv_pos, d] = val
                    # NOTE(review): same concern as the sync_threads() above.
                    cute.arch.sync_threads()
                    # ── Q @ K^T via tcgen05.mma ──────────────────
                    for k in cutlass.range_constexpr(cute.size(tCsK.shape[2])):
                        cute.copy(smem_copy_K, tCsK[None, None, k], tCrK_cv[None, None, k])
                        # First k-step overwrites the scores, later ones accumulate.
                        tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, k > 0)
                        cute.gemm(tiled_mma_qk, tCtScores, tCrQ[None, None, k], tCrK[None, None, k], tCtScores)
                    # Signal epilogue: scores in TMEM are ready
                    cute.arch.fence_view_async_tmem_store()
                    acc_full_barrier.arrive()
                    # Wait for epilogue to finish softmax
                    # NOTE(review): arrive() and wait() target the same named
                    # barrier in both roles; verify the arrival counts line up
                    # with its 96-thread size.
                    acc_full_barrier.wait()
                    # ── P @ V via tcgen05.mma ─────────────────────
                    # P from TMEM (softmax output), V from SMEM
                    for k in cutlass.range_constexpr(cute.size(tOsVt.shape[2])):
                        cute.copy(smem_copy_V, tOsVt[None, None, k], tOrVt_cv[None, None, k])
                        # For the first KV tile, first k tile: ACCUMULATE=False (zero the output)
                        # Otherwise, ACCUMULATE=True (accumulate into output)
                        is_first_output_tile = (n_block == 0) and (k == 0)
                        tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, not is_first_output_tile)
                        # NOTE(review): the A operand is passed as None; confirm
                        # how the P fragment in TMEM is meant to reach this GEMM.
                        cute.gemm(tiled_mma_pv, tCtOutput, None, tOrVt[None, None, k], tCtOutput)
                    cute.arch.fence_view_async_tmem_store()
                    acc_full_barrier.arrive()
                # Free TMEM
                # NOTE(review): the epilogue branch also relinquishes and frees
                # the same allocation — confirm that only one role should.
                tmem.relinquish_alloc_permit()
                acc_full_barrier.arrive()  # Sync before dealloc
                tmem.free(num_tmem_cols)
            # ─── EPILOGUE WARPS ──────────────────────────────────
            if is_epi_warp:
                # my_row_start is computed but not read below.
                my_row_start = warp_idx * 64
                num_my_rows = 64
                # Online softmax state (initialised here, not yet updated).
                row_max = cute.make_rmem_tensor((num_my_rows,), Float32)
                row_sum = cute.make_rmem_tensor((num_my_rows,), Float32)
                row_max.fill(-1e30)
                row_sum.fill(0.0)
                # TMEM→register copy for scores (tcgen05.ld pattern)
                tiled_copy_t2r_scores = tcgen05.make_tmem_copy(
                    sm100_utils.get_tmem_load_op(self._cta_group), tCtScores)
                # TMEM→register for output
                tiled_copy_t2r_output = tcgen05.make_tmem_copy(
                    sm100_utils.get_tmem_load_op(self._cta_group), tCtOutput)
                # Register fragments
                tRrScores = cute.make_fragment_like(
                    tiled_copy_t2r_scores.partition_D(
                        cute.make_tensor(tCtScores.iterator, tCtScores.layout)), Float32)
                tRrOutput = cute.make_fragment_like(
                    tiled_copy_t2r_output.partition_D(
                        cute.make_tensor(tCtOutput.iterator, tCtOutput.layout)), Float32)
                n_block_max = (self._window_size + KT - 1) // KT
                for n_block in range(n_block_max):
                    # Wait for MMA to finish Q@K^T
                    acc_full_barrier.wait()
                    # ── tcgen05.ld scores from TMEM to registers ──
                    cute.copy(tiled_copy_t2r_scores, tCtScores, tRrScores)
                    cute.arch.fence_view_async_tmem_load()
                    # ── Softmax in registers (planned, not implemented) ──
                    # For each row this warp owns (64 rows per warp):
                    #   1. tile_max = max(scores * scale) — reduce across KT positions
                    #   2. new_max = max(row_max_prev, tile_max)
                    #   3. prev_exp = exp((row_max_prev - new_max) * scale * log2e)
                    #   4. Rescale output accumulator in TMEM: O *= prev_exp
                    #      (via tcgen05.st with scaled values, or defer to PV GEMM)
                    #   5. row_sum *= prev_exp
                    #   6. exp_scores = exp((scores * scale - new_max * scale) * log2e)
                    #   7. row_sum += sum(exp_scores)
                    #   8. Write exp_scores back to TMEM (as P operand for PV GEMM)
                    #   9. row_max = new_max
                    # The register fragment layout after tcgen05.ld is expected
                    # to be (EPI_TILE_M=128, EPI_TILE_N=16) partitioned per
                    # epilogue warp, i.e. 64 rows by 16 columns per warp.
                    # TODO: Implement per-row softmax on the register fragment.
                    #   This requires understanding the exact tRrScores layout
                    #   from tcgen05.make_tmem_copy + partition_D.
                    #   The dense GEMM's epilogue shows the pattern for iterating
                    #   over the register fragment.
                    # For now, write the unmodified scores back to TMEM.
                    # NOTE(review): this reuses the TMEM→register tiled copy
                    # (built from get_tmem_load_op) with source and destination
                    # swapped; confirm whether a dedicated store op is needed.
                    cute.copy(tiled_copy_t2r_scores, tRrScores, tCtScores)
                    cute.arch.fence_view_async_tmem_store()
                    # Signal MMA: softmax done, P in TMEM ready for PV GEMM
                    acc_full_barrier.arrive()
                    # Wait for MMA to finish P@V
                    acc_full_barrier.wait()
                # ── Final normalize ──────────────────────────────
                # tcgen05.ld output from TMEM, divide by row_sum, store to GMEM
                cute.copy(tiled_copy_t2r_output, tCtOutput, tRrOutput)
                cute.arch.fence_view_async_tmem_load()
                # Divide each element by row_sum
                # TODO: implement per-row normalization on register fragment
                # Cast to BF16 and store to SMEM, then GMEM
                # (following dense GEMM's epilogue pattern)
                # NOTE(review): tRrOutput is never written to sO or mO here, so
                # the output tensor is not updated by this kernel yet.
                # Relinquish TMEM
                tmem.relinquish_alloc_permit()
                acc_full_barrier.arrive()
                tmem.free(num_tmem_cols)