feat: add native CuTeDSL SWA decode attention kernel stub + batched SDPA fallback
This commit is contained in:
271
cutedsl/native_swa_decode.py
Normal file
271
cutedsl/native_swa_decode.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""
|
||||
Native CuTeDSL SWA Decode Attention Kernel for DeepSeek-V4 on Blackwell (SM100).
|
||||
|
||||
This is a FUSED kernel that replaces the Python for-loop decode path.
|
||||
|
||||
Decode attention: each token has 1 query (1, NH, HD) attending to up to window_size
|
||||
KV entries from the paged fp8 cache. Since K=V in MLA, this is a batched GEMV.
|
||||
|
||||
The kernel fuses:
|
||||
1. Paged KV read (using pre-computed swa_indices)
|
||||
2. fp8 dequantize (fp8 * inv_scale → bf16)
|
||||
3. Q×K^T (GEMV, not GEMM — 1 query vs N KVs)
|
||||
4. Online softmax (max + exp + sum)
|
||||
5. Weighted V accumulation (softmax_weights × V)
|
||||
6. Output write
|
||||
|
||||
CTA mapping: one CTA per (decode_token, q_head_group).
|
||||
- With 128 Q heads and 16 heads per group, that's 8 groups per token.
|
||||
- Each CTA handles 16 Q heads sharing the same KV.
|
||||
- Grid: (num_head_groups, num_decode_tokens, 1)
|
||||
|
||||
Tiling:
|
||||
- Q: (HEAD_GROUP, HD) per CTA — loaded once into registers
|
||||
- K/V: streamed in tiles of KV_TILE (e.g., 16) tokens from smem
|
||||
- smem holds one K tile: (KV_TILE, HD) in bf16 = 16 * 512 * 2 = 16 KB
|
||||
- fp8 → bf16 dequantize happens during smem load
|
||||
"""
|
||||
|
||||
import torch
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
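# Host-side entry point for SWA decode attention. The CuTeDSL kernel launch is
# not wired up yet, so the wrapper currently always uses the PyTorch fallback.
# The kernel skeleton is defined further below and is meant to be compiled
# with cute.compile once its tile loop is implemented.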
|
||||
|
||||
try:
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.nvgpu import tcgen05, warp
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
import cutlass.pipeline as pipeline
|
||||
import cutlass.utils as utils
|
||||
import cuda.bindings.driver as cuda
|
||||
HAS_CUTEDSL = True
|
||||
except ImportError:
|
||||
HAS_CUTEDSL = False
|
||||
|
||||
|
||||
# ── Host-side wrapper ─────────────────────────────────────────────────
|
||||
|
||||
def native_swa_decode_attention(
    q: torch.Tensor,              # (T, NH, HD) bf16, with RoPE
    swa_kv_cache: torch.Tensor,   # (num_blocks, block_size, HD) fp8 (uint8) paged cache
    swa_inv_scale: torch.Tensor,  # (max_slots, 1) bf16 per-token inv scale
    swa_indices: torch.Tensor,    # (T, window_size) int64 slot indices
    swa_lens: torch.Tensor,       # (T,) int64 valid lengths
    block_size: int,              # tokens per block (256)
    scale: float,                 # 1/sqrt(HD)
    window_size: int = 128,       # sliding window size
) -> torch.Tensor:
    """Run sliding-window decode attention for a batch of decode tokens.

    The CuTeDSL kernel launch is not implemented yet, so every call is
    served by the batched PyTorch SDPA fallback, whether or not CuTeDSL
    could be imported.

    Args:
        q: Query tensor, expected shape (T, NH, HD).
        swa_kv_cache: Paged KV cache, expected (num_blocks, block_size, HD).
        swa_inv_scale: Per-slot dequantization scale, expected (max_slots, 1).
        swa_indices: Flat slot index of every window position, expected
            (T, window_size); negative entries mark unused positions.
        swa_lens: Number of valid window positions per token, expected (T,).
        block_size: Number of slots per cache block.
        scale: Softmax scale applied to the attention scores.
        window_size: Number of window positions per token.

    Returns:
        Attention output with the same (T, NH, HD) layout as ``q``.

    Raises:
        ValueError: If ``q`` is not 3-dimensional.
    """
    # Explicit rank check; replaces the tuple-unpack of q.shape that used to
    # perform this validation implicitly.
    if len(q.shape) != 3:
        raise ValueError(
            f"q must have shape (T, NH, HD), got {tuple(q.shape)}"
        )

    # TODO: Implement CuTeDSL kernel launch (only when HAS_CUTEDSL is true).
    # Until then the optimized PyTorch fallback is the single code path.
    return _fallback_batched_sdp(
        q, swa_kv_cache, swa_inv_scale,
        swa_indices, swa_lens, block_size,
        scale, window_size,
    )
|
||||
|
||||
|
||||
def _fallback_batched_sdp(
|
||||
q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
|
||||
block_size, scale, window_size,
|
||||
):
|
||||
"""Optimized PyTorch batched SDPA — no Python for-loop.
|
||||
|
||||
This is the fallback when the CuTeDSL kernel isn't compiled yet.
|
||||
All decode tokens are processed in a single batched SDPA call:
|
||||
1. Gather ALL KV entries for ALL decode tokens at once
|
||||
2. Dequantize fp8 → bf16 in one batch
|
||||
3. Run batched SDPA with proper masking
|
||||
"""
|
||||
num_tokens, NH, HD = q.shape
|
||||
device = q.device
|
||||
|
||||
safe_indices = swa_indices[:num_tokens].clamp(min=0)
|
||||
block_indices = safe_indices // block_size
|
||||
offsets = safe_indices % block_size
|
||||
|
||||
# Batched KV gather + dequant
|
||||
kv_raw = swa_kv_cache[block_indices, offsets]
|
||||
if swa_kv_cache.dtype == torch.uint8:
|
||||
kv_raw = kv_raw.view(torch.float8_e4m3fn)
|
||||
inv_scales = swa_inv_scale[safe_indices]
|
||||
kv_bf16 = (kv_raw.to(torch.bfloat16) * inv_scales).to(torch.bfloat16)
|
||||
|
||||
# Attention mask
|
||||
pos_range = torch.arange(window_size, device=device).unsqueeze(0)
|
||||
len_mask = pos_range >= swa_lens[:num_tokens].unsqueeze(1)
|
||||
invalid_mask = swa_indices[:num_tokens] < 0
|
||||
attn_mask = len_mask | invalid_mask
|
||||
float_mask = torch.zeros(attn_mask.shape, dtype=torch.bfloat16, device=device)
|
||||
float_mask[attn_mask] = float('-inf')
|
||||
|
||||
# Batched SDPA
|
||||
q_t = q.permute(1, 0, 2)
|
||||
q_batch = q_t.reshape(NH * num_tokens, 1, HD)
|
||||
kv_expanded = kv_bf16.unsqueeze(0).expand(NH, -1, -1, -1)
|
||||
k_batch = kv_expanded.reshape(NH * num_tokens, window_size, HD)
|
||||
v_batch = k_batch
|
||||
|
||||
mask_batch = float_mask.unsqueeze(0).unsqueeze(2).expand(
|
||||
NH, num_tokens, 1, window_size
|
||||
).reshape(NH * num_tokens, 1, window_size)
|
||||
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_batch, k_batch, v_batch,
|
||||
attn_mask=mask_batch,
|
||||
is_causal=False,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
return out.reshape(NH, num_tokens, HD).permute(1, 0, 2)
|
||||
|
||||
|
||||
# ── CuTeDSL Kernel (Blackwell SM100) ──────────────────────────────────
|
||||
|
||||
# The kernel below is the skeleton of the fused decode attention on Blackwell;
# the per-tile KV processing is still a TODO.
# The intent is to use tcgen05 (Blackwell tensor core) for the GEMV operations.
|
||||
#
|
||||
# Architecture per CTA:
|
||||
# - 1 CTA = 1 warp group (128 threads) handling (1 token, HEAD_GROUP heads)
|
||||
# - Q: (HEAD_GROUP, HD) loaded once into registers
|
||||
# - K/V: streamed in tiles of KV_TILE tokens
|
||||
# - fp8 dequant: fused during gmem→smem copy
|
||||
# - Online softmax: row_max, row_exp_sum tracked in registers
|
||||
# - Output: (HEAD_GROUP, HD) accumulated in registers, written to gmem at end
|
||||
|
||||
# Default tiling parameters; they are the default constructor arguments of
# BlackwellSWADecodeAttention below.
HEAD_GROUP = 16       # Q heads per CTA (128 total heads / 8 CTAs)
KV_TILE = 16          # KV tokens per smem tile
HEAD_DIM = 512        # KV latent dimension
|
||||
|
||||
|
||||
if HAS_CUTEDSL:

    class BlackwellSWADecodeAttention:
        """CuTeDSL SWA Decode Attention Kernel for Blackwell (skeleton).

        Each CTA handles one (decode_token, q_head_group) pair.
        The kernel is intended to stream KV tiles from the paged fp8 cache,
        dequantize, compute Q×K^T, online softmax, and V accumulation.

        NOTE(review): the per-tile loop body is still a ``pass`` placeholder,
        so as written the accumulator stays at zero and the kernel stores
        zeros to ``mO``.
        """

        def __init__(
            self,
            head_dim: int = HEAD_DIM,
            head_group: int = HEAD_GROUP,
            kv_tile: int = KV_TILE,
            num_threads: int = 128,
        ):
            """Store the tiling configuration.

            Args:
                head_dim: Size of the last dimension of Q / KV.
                head_group: Number of Q heads handled by one CTA.
                kv_tile: Number of KV tokens per streamed tile.
                num_threads: Threads per CTA (x-dimension of the launch block).
            """
            self._head_dim = head_dim
            self._head_group = head_group
            self._kv_tile = kv_tile
            self._num_threads = num_threads
            # head_dim rounded up to a multiple of 32.
            # NOTE(review): not read anywhere in the visible code.
            self._head_dim_padded = (head_dim + 31) // 32 * 32

        @cute.jit
        def __call__(
            self,
            mQ: cute.Tensor,           # (T, NH, HD)
            mKV_cache: cute.Tensor,    # (num_blocks, block_size, HD) uint8
            mInv_scale: cute.Tensor,   # (max_slots, 1) bf16
            mSwa_indices: cute.Tensor, # (T, W) int64
            mSwa_lens: cute.Tensor,    # (T,) int64
            mO: cute.Tensor,           # (T, NH, HD)
            softmax_scale: cutlass.Float32,
            window_size: int,
            block_size: int,
            stream: cuda.CUstream,
        ):
            """Launch ``kernel`` with one CTA per (head group, decode token).

            The grid is derived from ``mQ``: its x-extent is
            ``mQ.shape[1] // head_group`` and its y-extent is ``mQ.shape[0]``.
            NOTE(review): the integer division drops any remainder heads when
            NH is not a multiple of ``head_group`` — confirm callers guarantee
            divisibility.
            """
            # Grid: (num_head_groups, num_decode_tokens, 1)
            num_head_groups = mQ.shape[1] // self._head_group
            num_decode_tokens = mQ.shape[0]
            grid_dim = (num_head_groups, num_decode_tokens, 1)

            self.kernel(
                mQ, mKV_cache, mInv_scale, mSwa_indices, mSwa_lens, mO,
                softmax_scale, window_size, block_size,
            ).launch(
                grid=grid_dim,
                block=[self._num_threads, 1, 1],
                stream=stream,
            )

        @cute.kernel
        def kernel(
            self,
            mQ: cute.Tensor,
            mKV_cache: cute.Tensor,
            mInv_scale: cute.Tensor,
            mSwa_indices: cute.Tensor,
            mSwa_lens: cute.Tensor,
            mO: cute.Tensor,
            softmax_scale: cutlass.Float32,
            window_size: int,
            block_size: int,
        ):
            """Device-side body for one (token, head group) pair.

            Loads the Q rows of this head group, initializes the output
            accumulator and online-softmax state, iterates over the KV tiles
            (placeholder), normalizes the accumulator by ``row_sum`` where it
            is non-zero, and stores the result to ``mO`` as bf16.
            """
            # NOTE(review): tidx is read but not used below, so the code does
            # not divide work between the threads of the CTA yet.
            tidx, _, _ = cute.arch.thread_idx()
            head_group_idx, token_idx, _ = cute.arch.block_idx()

            # This CTA handles Q heads [head_group_idx * HEAD_GROUP : (head_group_idx + 1) * HEAD_GROUP]
            # for decode token token_idx

            # Read swa_len for this token
            swa_len = mSwa_lens[token_idx]

            # Load Q for this (token, head_group)
            # Q shape: (HEAD_GROUP, HD) from mQ[token_idx, head_group_idx*HEAD_GROUP:(+HEAD_GROUP), :]
            q_local = cute.make_rmem_tensor(
                (self._head_group, self._head_dim), cutlass.BFloat16
            )
            for h in cutlass.range_constexpr(self._head_group):
                for d in range(self._head_dim):
                    q_local[h, d] = mQ[token_idx, head_group_idx * self._head_group + h, d]

            # Accumulator for output: (HEAD_GROUP, HD) in float32
            acc_O = cute.make_rmem_tensor(
                (self._head_group, self._head_dim), cutlass.Float32
            )
            acc_O.fill(0.0)

            # Online softmax state: (HEAD_GROUP,) max and sum
            row_max = cute.make_rmem_tensor((self._head_group,), cutlass.Float32)
            row_sum = cute.make_rmem_tensor((self._head_group,), cutlass.Float32)
            row_max.fill(-cutlass.Float32.inf)
            row_sum.fill(0.0)

            # Stream KV tiles
            num_kv_tiles = cute.ceil_div(swa_len, self._kv_tile)

            for kv_tile_idx in range(num_kv_tiles):
                # 1. Read swa_indices for this tile
                # 2. Gather KV from paged cache
                # 3. Dequantize fp8 → bf16
                # 4. Compute Q × K^T
                # 5. Update online softmax
                # 6. Accumulate weighted V
                pass  # TODO: implement tile processing

            # Normalize output by row_sum
            # Heads whose row_sum is still zero are skipped, which avoids a
            # division by zero and leaves their accumulator rows unchanged.
            for h in cutlass.range_constexpr(self._head_group):
                if row_sum[h] != 0.0:
                    inv_sum = 1.0 / row_sum[h]
                    for d in range(self._head_dim):
                        acc_O[h, d] = acc_O[h, d] * inv_sum

            # Write output
            for h in cutlass.range_constexpr(self._head_group):
                for d in range(self._head_dim):
                    mO[token_idx, head_group_idx * self._head_group + h, d] = acc_O[h, d].to(cutlass.BFloat16)
|
||||
Reference in New Issue
Block a user