Files
nvfp4-megamoe-kernel/dsv4/ops/decode_sparse.py
biondizzle 3fb3c925af Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

354 lines
15 KiB
Python

"""
Native CuTeDSL Sparse SWA Decode Attention for DeepSeek-V4 on Blackwell (SM100).
Handles CSA (C4A, compress_ratio=4) and HCA (C128A, compress_ratio=128).
Attends to BOTH the SWA window AND top-k compressed KV, merged with sink weights.
Sink weight merge (FlashMLA formula):
o = (exp(lse_sparse) * o_sparse + exp(attn_sink) * exp(lse_swa) * o_swa)
    / (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa))
where o_sparse = sum(exp(s)*v) / sum(exp(s)) from compressed KV
o_swa = sum(exp(s)*v) / sum(exp(s)) from SWA KV
lse_sparse = log(sum(exp(s))) from compressed KV
lse_swa = log(sum(exp(s))) from SWA KV
attn_sink = per-head learnable parameter (NH,)
"""
import torch
import torch.nn.functional as F
from typing import Optional
try:
import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
import cuda.bindings.driver as cuda
HAS_CUTEDSL = True
except ImportError:
HAS_CUTEDSL = False
# Process-wide memo of compiled CuTeDSL kernels, keyed by launch configuration.
_compiled_sparse_kernel_cache = dict()

# Kernel tiling defaults: query heads processed per CTA and KV rows staged in
# shared memory per tile.
HEAD_GROUP, KV_TILE = 16, 16
# Default head dimension and threads launched per CTA.
HEAD_DIM, NUM_THREADS = 512, 128
def _pack_valid_kv(
    swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
    compressed_kv_cache, compressed_inv_scale, topk_indices, topk_lens,
    num_tokens, head_dim, block_size, window_size, device,
):
    """Gather, dequantize and pack each token's valid KV rows to the front.

    A row is valid when its position is below the token's length
    (``swa_lens`` / ``topk_lens``) and its slot index is non-negative. Rows
    are first laid out as ``[SWA (window_size slots) | compressed (topk
    slots)]`` and then reordered per token so the valid ones come first, in
    their original order.

    Returns:
        ``(kv_packed, lens)`` where ``kv_packed`` is bf16 of shape
        ``(T, window_size + topk_width, head_dim)`` and ``lens`` (dtype of
        ``swa_lens``) counts each token's valid rows, or ``None`` when no
        token has any valid row.
    """

    def _dequant(cache, inv_scale, slots, bs):
        # Paged lookup: flat slot -> [slot // bs, slot % bs]. Negative slots are
        # clamped here and excluded later via the validity mask.
        safe = slots.clamp(min=0)
        raw = cache[safe // bs, safe % bs]
        if cache.dtype == torch.uint8:
            raw = raw.view(torch.float8_e4m3fn)
        return (raw.to(torch.bfloat16) * inv_scale[safe]).to(torch.bfloat16)

    if swa_indices.dim() == 3:
        swa_indices = swa_indices.squeeze(0)
    if topk_indices.dim() == 1:
        topk_indices = topk_indices.unsqueeze(-1)
    # Only the first window_size SWA columns are ever attended.
    swa_idx = swa_indices[:num_tokens, :window_size]
    topk_idx = topk_indices[:num_tokens]
    swa_width = swa_idx.shape[1]
    topk_width = topk_idx.shape[1]
    total_len = window_size + topk_width

    swa_valid = (
        torch.arange(swa_width, device=device) < swa_lens[:num_tokens].unsqueeze(1)
    ) & (swa_idx >= 0)
    topk_valid = (
        torch.arange(topk_width, device=device) < topk_lens[:num_tokens].unsqueeze(1)
    ) & (topk_idx >= 0)
    has_swa = bool(swa_valid.any())
    has_topk = bool(topk_valid.any())
    if not (has_swa or has_topk):
        return None

    kv_padded = torch.zeros(
        num_tokens, total_len, head_dim, dtype=torch.bfloat16, device=device)
    valid = torch.zeros(num_tokens, total_len, dtype=torch.bool, device=device)
    valid[:, :swa_width] = swa_valid
    valid[:, window_size:] = topk_valid
    if has_swa:
        kv_padded[:, :swa_width] = _dequant(
            swa_kv_cache, swa_inv_scale, swa_idx, block_size)
    if has_topk:
        kv_padded[:, window_size:] = _dequant(
            compressed_kv_cache, compressed_inv_scale, topk_idx,
            compressed_kv_cache.shape[1])

    # A stable sort on the "invalid" flag moves valid rows to the front while
    # keeping their relative order, so the kernel's length-prefix view of each
    # token sees exactly the valid rows.
    order = torch.sort((~valid).to(torch.int32), dim=1, stable=True).indices
    kv_packed = torch.gather(
        kv_padded, 1, order.unsqueeze(-1).expand(-1, -1, head_dim))
    lens = valid.sum(dim=1).to(swa_lens.dtype)
    return kv_packed, lens


def native_sparse_decode_attention(
    q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
    compressed_kv_cache, compressed_inv_scale, topk_indices, topk_lens,
    attn_sink,
    block_size, scale, window_size=128, compress_ratio=4,
):
    """Sparse decode attention over SWA + top-k compressed KV (CuTeDSL path).

    Gathers and dequantizes each token's sliding-window rows and selected
    compressed rows, packs the valid ones to the front of a single
    ``(T, window_size + topk_width, HD)`` buffer, and launches
    ``BlackwellSparseDecodeKernel``, which attends over the first
    ``combined_lens[t]`` rows of token ``t`` with a single softmax.

    A single joint softmax equals the sink-weighted merge described in the
    module docstring only when ``exp(attn_sink) == 1``. The call is therefore
    delegated to ``_fallback_sparse_sdp`` whenever any sink is non-zero, and
    also when CuTeDSL is unavailable, ``q`` is not on CUDA, or ``NH`` is not a
    multiple of ``HEAD_GROUP`` (the kernel grid only covers whole head groups).

    Args:
        q: (T, NH, HD) decode queries (the kernel reads them as bf16 — TODO
            confirm callers always pass bf16).
        swa_kv_cache / compressed_kv_cache: paged caches indexed
            ``[slot // block, slot % block]``; uint8 is reinterpreted as FP8 E4M3.
        swa_inv_scale / compressed_inv_scale: dequant scales indexed by flat slot.
        swa_indices: (T, W) or (1, T, W) flat SWA slot ids; negative = unused.
            Only the first ``window_size`` columns are used.
        swa_lens: (T,) number of leading SWA entries to consider.
        topk_indices: (T, K) flat compressed slot ids; negative = unused.
        topk_lens: (T,) number of leading top-k entries to consider.
        attn_sink: (NH,) per-head sink logit, or None for no sink.
        block_size: SWA cache block size (the compressed block size is read
            from ``compressed_kv_cache.shape[1]``).
        scale: softmax scale applied to q.k.
        window_size: maximum number of SWA entries per token.
        compress_ratio: only used as part of the compiled-kernel cache key.

    Returns:
        (T, NH, HD) bf16 output; all zeros for tokens with no valid KV row.
    """
    num_tokens, NH, HD = q.shape
    device = q.device
    # Conditions are ordered so the device-syncing sink check runs last.
    if (not HAS_CUTEDSL
            or device.type != "cuda"
            or NH % HEAD_GROUP != 0
            or (attn_sink is not None and bool(torch.any(attn_sink != 0)))):
        return _fallback_sparse_sdp(
            q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
            compressed_kv_cache, compressed_inv_scale, topk_indices, topk_lens,
            attn_sink, block_size, scale, window_size,
        )
    if num_tokens == 0:
        return torch.zeros(0, NH, HD, dtype=torch.bfloat16, device=device)

    packed = _pack_valid_kv(
        swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
        compressed_kv_cache, compressed_inv_scale, topk_indices, topk_lens,
        num_tokens, HD, block_size, window_size, device,
    )
    if packed is None:
        return torch.zeros(num_tokens, NH, HD, dtype=torch.bfloat16, device=device)
    kv_combined, combined_lens = packed
    total_len = kv_combined.shape[1]

    q = q.contiguous()
    output = torch.zeros(num_tokens, NH, HD, dtype=torch.bfloat16, device=device)
    scale_tensor = torch.tensor([scale], dtype=torch.float32, device=device)

    def to_cute(t):
        ct = cutlass_torch.from_dlpack(t)
        return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))

    q_c = to_cute(q)
    kv_c = to_cute(kv_combined)
    len_c = to_cute(combined_lens)
    out_c = to_cute(output)
    scale_c = to_cute(scale_tensor)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # total_len is baked into the compiled kernel; the lens dtype is part of
    # the compiled signature.
    cache_key = (num_tokens, NH, HD, window_size, total_len - window_size,
                 compress_ratio, str(combined_lens.dtype), str(device))
    entry = _compiled_sparse_kernel_cache.get(cache_key)
    if entry is None:
        kernel = BlackwellSparseDecodeKernel(
            head_dim=HD, head_group=HEAD_GROUP, kv_tile=KV_TILE,
            total_len=total_len,
        )
        entry = {'compiled': cute.compile(
            kernel, q_c, kv_c, len_c, out_c, scale_c, stream)}
        _compiled_sparse_kernel_cache[cache_key] = entry
    # Exactly one launch per call, including the call that compiled the kernel.
    entry['compiled'](q_c, kv_c, len_c, out_c, scale_c, stream)
    return output
def _gather_dequant_kv(kv_cache, inv_scale, slot_indices, block_size):
    """Gather paged KV rows for flat slot ids and dequantize them to bf16.

    ``kv_cache`` is indexed as ``[slot // block_size, slot % block_size]`` and
    ``inv_scale`` by the flat slot id; the scale is broadcast against the
    gathered rows. Negative ids are clamped to slot 0 here, so callers must
    mask those positions. uint8 caches are reinterpreted as FP8 E4M3.
    """
    safe = slot_indices.clamp(min=0)
    raw = kv_cache[safe // block_size, safe % block_size]
    if kv_cache.dtype == torch.uint8:
        raw = raw.view(torch.float8_e4m3fn)
    return (raw.to(torch.bfloat16) * inv_scale[safe]).to(torch.bfloat16)


def _attention_with_lse(q, kv, invalid, scale):
    """Masked single-query attention per (token, head) in float32.

    Args:
        q: (T, NH, HD) queries.
        kv: (T, L, HD) rows used as both keys and values.
        invalid: (T, L) bool, True where a row must be ignored.
        scale: softmax scale applied to q.k.

    Returns:
        ``(out, lse)``: ``out`` is (T, NH, HD) float32 and all-zero for rows
        whose keys are all masked (instead of NaN); ``lse`` is (T, NH)
        float32 log-sum-exp of the scaled scores, ``-inf`` when fully masked.
    """
    num_tokens, num_heads, head_dim = q.shape
    if kv.shape[1] == 0:
        return (q.new_zeros((num_tokens, num_heads, head_dim), dtype=torch.float32),
                torch.full((num_tokens, num_heads), float("-inf"),
                           dtype=torch.float32, device=q.device))
    kv_f = kv.float()
    scores = torch.einsum("thd,tld->thl", q.float(), kv_f) * scale
    scores = scores.masked_fill(invalid.unsqueeze(1), float("-inf"))
    row_max = scores.amax(dim=-1, keepdim=True)
    # Fully masked rows have max == -inf; shift by 0 instead so exp() yields 0
    # rather than NaN from (-inf) - (-inf).
    row_max = row_max.masked_fill(~torch.isfinite(row_max), 0.0)
    weights = torch.exp(scores - row_max)
    denom = weights.sum(dim=-1, keepdim=True)
    lse = (denom.log() + row_max).squeeze(-1)  # log(0) -> -inf when fully masked
    out = torch.einsum("thl,tld->thd", weights / denom.clamp(min=1e-30), kv_f)
    return out, lse


def _fallback_sparse_sdp(
    q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
    compressed_kv_cache, compressed_inv_scale, topk_indices, topk_lens,
    attn_sink, block_size, scale, window_size,
):
    """Pure-PyTorch sparse decode attention with sink-weighted merge.

    Each query head attends separately to (a) the token's sliding-window KV
    rows and (b) its top-k compressed KV rows. Every row is used as both key
    and value. The two results are then merged as

        o = (exp(lse_sparse) * o_sparse + exp(attn_sink) * exp(lse_swa) * o_swa)
            / (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa))

    The merge is evaluated in log space so that large ``lse`` values cannot
    overflow. A row is ignored when its position is >= the token's length or
    its slot id is negative. A branch with no valid rows contributes nothing,
    and a token with no valid rows at all yields zeros (never NaN).

    Args:
        q: (T, NH, HD) queries.
        swa_kv_cache / compressed_kv_cache: paged caches indexed
            ``[slot // block, slot % block]``; uint8 is read as FP8 E4M3.
        swa_inv_scale / compressed_inv_scale: dequant scales indexed by flat slot.
        swa_indices: (T, W) or (1, T, W) flat SWA slot ids; only the first
            ``window_size`` columns are used.
        swa_lens / topk_lens: (T,) number of leading entries to consider.
        topk_indices: (T, K) flat compressed slot ids ((T,) is read as K=1).
        attn_sink: (NH,) per-head sink logit.
        block_size: SWA cache block size (the compressed block size is
            ``compressed_kv_cache.shape[1]``).
        scale: softmax scale.
        window_size: maximum number of SWA entries per token.

    Returns:
        (T, NH, HD) bf16 tensor.
    """
    num_tokens, NH, HD = q.shape
    device = q.device
    if num_tokens == 0:
        return torch.zeros(0, NH, HD, dtype=torch.bfloat16, device=device)

    # --- Sliding-window branch ---
    if swa_indices.dim() == 3:
        swa_indices = swa_indices.squeeze(0)
    swa_idx = swa_indices[:num_tokens, :window_size]
    swa_pos = torch.arange(swa_idx.shape[1], device=device).unsqueeze(0)
    swa_invalid = (swa_pos >= swa_lens[:num_tokens].unsqueeze(1)) | (swa_idx < 0)
    swa_kv = _gather_dequant_kv(swa_kv_cache, swa_inv_scale, swa_idx, block_size)
    o_swa, lse_swa = _attention_with_lse(q, swa_kv, swa_invalid, scale)

    # --- Compressed (top-k) branch ---
    if topk_indices.dim() == 1:
        topk_indices = topk_indices.unsqueeze(-1)
    topk_idx = topk_indices[:num_tokens]
    topk_width = topk_idx.shape[1]
    if topk_width > 0 and bool((topk_lens[:num_tokens] > 0).any()):
        topk_pos = torch.arange(topk_width, device=device).unsqueeze(0)
        topk_invalid = (topk_pos >= topk_lens[:num_tokens].unsqueeze(1)) | (topk_idx < 0)
        comp_kv = _gather_dequant_kv(
            compressed_kv_cache, compressed_inv_scale, topk_idx,
            compressed_kv_cache.shape[1])
        o_sparse, lse_sparse = _attention_with_lse(q, comp_kv, topk_invalid, scale)
    else:
        o_sparse = torch.zeros(num_tokens, NH, HD, dtype=torch.float32, device=device)
        lse_sparse = torch.full((num_tokens, NH), float("-inf"),
                                dtype=torch.float32, device=device)

    # --- Sink-weighted merge, in log space ---
    sink = attn_sink.to(device=device, dtype=torch.float32).unsqueeze(0)  # (1, NH)
    swa_logit = lse_swa + sink  # log(exp(sink) * sum(exp(s_swa)))
    ref = torch.maximum(lse_sparse, swa_logit)
    # Both branches empty -> ref is -inf; use 0 so both weights become exactly 0.
    ref = ref.masked_fill(~torch.isfinite(ref), 0.0)
    w_sparse = torch.exp(lse_sparse - ref).unsqueeze(-1)
    w_swa = torch.exp(swa_logit - ref).unsqueeze(-1)
    merged = (w_sparse * o_sparse + w_swa * o_swa) / (w_sparse + w_swa).clamp(min=1e-30)
    return merged.to(torch.bfloat16)
if HAS_CUTEDSL:
    class BlackwellSparseDecodeKernel:
        """CuTeDSL decode-attention kernel over a per-token KV buffer.

        The launch grid is ``(num_heads // head_group, num_tokens)``. Each CTA
        loads ``head_group`` query heads of one token into registers. It then
        streams the token's KV rows through shared memory ``kv_tile`` rows at
        a time, keeping a running max / running sum softmax.

        Only the first ``mLens[t]`` rows of ``mKV[t]`` (additionally capped at
        ``total_len``) take part; each row is used as both key and value.
        Rows beyond that are excluded from the max and the sum, not merely
        zero-filled. Heads of a token with no valid row are written as zeros.

        NOTE(review): all ``NUM_THREADS`` threads of a CTA execute the full
        loop nest redundantly (identical shared-memory and output stores);
        nothing is partitioned by ``tidx``. Results are correct, but the
        kernel does no intra-CTA parallel work.
        """

        def __init__(self, head_dim=HEAD_DIM, head_group=HEAD_GROUP,
                     kv_tile=KV_TILE, total_len=128):
            self._head_dim = head_dim
            self._head_group = head_group
            self._kv_tile = kv_tile
            self._total_len = total_len  # rows per token in mKV (compile-time)
            self._num_threads = NUM_THREADS

        @cute.jit
        def __call__(self, mQ, mKV, mLens, mO, mScale, stream):
            num_tokens = mQ.shape[0]
            num_head_groups = mQ.shape[1] // self._head_group
            self._kernel(mQ, mKV, mLens, mO, mScale).launch(
                grid=(num_head_groups, num_tokens, 1),
                block=[self._num_threads, 1, 1],
                stream=stream,
            )

        @cute.kernel
        def _kernel(self, mQ, mKV, mLens, mO, mScale):
            """Per-CTA body: attention for ``head_group`` heads of one token."""
            tidx, _, _ = cute.arch.thread_idx()
            hg_idx, tok_idx, _ = cute.arch.block_idx()
            HG = self._head_group
            HD = self._head_dim
            KT = self._kv_tile
            TL = self._total_len
            softmax_scale = mScale[0]

            @cute.struct
            class SharedStorage:
                kv_tile: cute.struct.MemRange[cutlass.BFloat16, KT * HD]

            smem = utils.SmemAllocator()
            storage = smem.allocate(SharedStorage)
            sKV = cute.make_tensor(
                storage.kv_tile.data_ptr(),
                cute.make_layout((KT, HD), stride=(HD, 1)),
            )

            # Number of leading rows of mKV[tok_idx] that are real KV entries.
            kv_len = mLens[tok_idx]
            has_kv = kv_len > 0

            q_reg = cute.make_rmem_tensor((HG, HD), cutlass.BFloat16)
            for h in cutlass.range_constexpr(HG):
                qh = hg_idx * HG + h
                for d in range(HD):
                    q_reg[h, d] = mQ[tok_idx, qh, d]

            acc_O = cute.make_rmem_tensor((HG, HD), cutlass.Float32)
            acc_O.fill(0.0)
            row_max = cute.make_rmem_tensor((HG,), cutlass.Float32)
            row_sum = cute.make_rmem_tensor((HG,), cutlass.Float32)
            row_max.fill(-1e30)
            row_sum.fill(0.0)

            max_tiles = (TL + KT - 1) // KT
            for tile_idx in range(max_tiles):
                tile_start = tile_idx * KT
                # Stage one KV tile in shared memory. Rows past the valid
                # length, or past the end of the buffer, are zero-filled and
                # never read from global memory.
                for kv_pos in range(KT):
                    global_kv = tile_start + kv_pos
                    for d in range(HD):
                        val = cutlass.BFloat16(0.0)
                        if global_kv < kv_len and global_kv < TL:
                            val = mKV[tok_idx, global_kv, d]
                        sKV[kv_pos, d] = val
                cute.arch.sync_threads()

                # Scaled q.k scores. Invalid rows get a -1e30 sentinel so they
                # never raise the running max.
                scores = cute.make_rmem_tensor((HG, KT), cutlass.Float32)
                scores.fill(0.0)
                for h in cutlass.range_constexpr(HG):
                    for kv_pos in range(KT):
                        dot = cutlass.Float32(0.0)
                        for d in range(HD):
                            q_val = q_reg[h, d].to(cutlass.Float32)
                            k_val = sKV[kv_pos, d].to(cutlass.Float32)
                            dot = dot + q_val * k_val
                        global_kv = tile_start + kv_pos
                        score = cutlass.Float32(-1e30)
                        if global_kv < kv_len and global_kv < TL:
                            score = dot * softmax_scale
                        scores[h, kv_pos] = score

                # Online-softmax update: rescale the running state to the new
                # max, then accumulate only valid rows (invalid rows add 0).
                for h in cutlass.range_constexpr(HG):
                    tile_max = cutlass.Float32(-1e30)
                    for kv_pos in range(KT):
                        s = scores[h, kv_pos]
                        if s > tile_max:
                            tile_max = s
                    new_max = row_max[h]
                    if tile_max > new_max:
                        new_max = tile_max
                    rescale = cutlass.Float32(0.0)
                    if row_max[h] > cutlass.Float32(-1e29):
                        rescale = cute.exp(row_max[h] - new_max)
                    for d in range(HD):
                        acc_O[h, d] = acc_O[h, d] * rescale
                    row_sum[h] = row_sum[h] * rescale
                    for kv_pos in range(KT):
                        global_kv = tile_start + kv_pos
                        exp_score = cutlass.Float32(0.0)
                        if global_kv < kv_len and global_kv < TL:
                            exp_score = cute.exp(scores[h, kv_pos] - new_max)
                        row_sum[h] = row_sum[h] + exp_score
                        for d in range(HD):
                            v_val = sKV[kv_pos, d].to(cutlass.Float32)
                            acc_O[h, d] = acc_O[h, d] + exp_score * v_val
                    row_max[h] = new_max
                # Keep the next tile's shared-memory stores from racing reads.
                cute.arch.sync_threads()

            for h in cutlass.range_constexpr(HG):
                qh = hg_idx * HG + h
                for d in range(HD):
                    val_f32 = cutlass.Float32(0.0)
                    if has_kv and row_sum[h] > cutlass.Float32(1e-30):
                        val_f32 = acc_O[h, d] / row_sum[h]
                    mO[tok_idx, qh, d] = val_f32.to(cutlass.BFloat16)