- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: package name is now dsv4-inference
354 lines
15 KiB
Python
354 lines
15 KiB
Python
"""
|
|
Native CuTeDSL Sparse SWA Decode Attention for DeepSeek-V4 on Blackwell (SM100).
|
|
|
|
Handles CSA (C4A, compress_ratio=4) and HCA (C128A, compress_ratio=128).
|
|
Attends to BOTH the SWA window AND top-k compressed KV, merged with sink weights.
|
|
|
|
Sink weight merge (FlashMLA formula):
|
|
o = exp(lse_sparse) * o_sparse + exp(attn_sink) * exp(lse_swa) * o_swa
|
|
/ (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa))
|
|
|
|
where o_sparse = sum(exp(s)*v) / sum(exp(s)) from compressed KV
|
|
o_swa = sum(exp(s)*v) / sum(exp(s)) from SWA KV
|
|
lse_sparse = log(sum(exp(s))) from compressed KV
|
|
lse_swa = log(sum(exp(s))) from SWA KV
|
|
attn_sink = per-head learnable parameter (NH,)
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from typing import Optional
|
|
|
|
try:
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
import cutlass.torch as cutlass_torch
|
|
import cutlass.utils as utils
|
|
import cuda.bindings.driver as cuda
|
|
HAS_CUTEDSL = True
|
|
except ImportError:
|
|
HAS_CUTEDSL = False
|
|
|
|
# Compiled CuTeDSL kernels, keyed by the shape/config tuple built in
# native_sparse_decode_attention; each entry is a dict holding the
# 'compiled' callable.
_compiled_sparse_kernel_cache: dict = {}

# Launch / tiling configuration for BlackwellSparseDecodeKernel.
HEAD_GROUP: int = 16    # query heads handled by one CTA (grid dim 0 = NH // HEAD_GROUP)
KV_TILE: int = 16       # KV rows staged through shared memory per loop iteration
HEAD_DIM: int = 512     # default head_dim for BlackwellSparseDecodeKernel
NUM_THREADS: int = 128  # threads per CTA
|
|
|
|
|
|
def native_sparse_decode_attention(
    q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
    compressed_kv_cache, compressed_inv_scale, topk_indices, topk_lens,
    attn_sink,
    block_size, scale, window_size=128, compress_ratio=4,
):
    """Sparse SWA + top-k compressed-KV decode attention (CuTeDSL kernel path).

    Gathers and dequantizes each token's SWA KV (at most ``window_size`` rows)
    and top-k compressed KV into one bf16 buffer of shape
    (num_tokens, window_size + topk_max, HD). It then moves every token's
    *valid* rows to the front, because the kernel attends exactly the first
    ``combined_lens[t]`` rows. Finally it runs ``BlackwellSparseDecodeKernel``
    over that buffer. When CuTeDSL is not importable it delegates to
    ``_fallback_sparse_sdp``.

    Args:
        q: queries, unpacked as (num_tokens, NH, HD).
        swa_kv_cache / compressed_kv_cache: paged caches indexed as
            ``cache[slot // page_size, slot % page_size]``. uint8 storage is
            reinterpreted as float8_e4m3fn.
        swa_inv_scale / compressed_inv_scale: per-slot dequant factors indexed
            by slot. Presumably shaped (num_slots, 1) or (num_slots, HD) so they
            broadcast against the gathered rows -- TODO confirm.
        swa_indices / topk_indices: (num_tokens, width) slot indices. Negative
            entries are treated as invalid. A 1-D ``topk_indices`` is treated
            as one entry per token.
        swa_lens / topk_lens: number of leading entries to use per token.
        attn_sink: per-head sink logits.
            NOTE(review): only the fallback path reads this. The native kernel
            applies a single softmax over the union of SWA and compressed rows
            and does not apply the sink-weighted merge.
        block_size: page size of the SWA cache. The compressed cache's page
            size is read from ``compressed_kv_cache.shape[1]``.
        scale: softmax scale applied to q.k.
        window_size: SWA window; SWA entries beyond it are ignored.
        compress_ratio: only used as part of the compiled-kernel cache key.

    Returns:
        (num_tokens, NH, HD) bfloat16 attention output.
    """
    num_tokens, NH, HD = q.shape
    device = q.device

    if num_tokens == 0:
        return torch.zeros(0, NH, HD, dtype=torch.bfloat16, device=device)

    if not HAS_CUTEDSL:
        return _fallback_sparse_sdp(
            q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
            compressed_kv_cache, compressed_inv_scale, topk_indices, topk_lens,
            attn_sink, block_size, scale, window_size,
        )

    def _gather_dequant(cache, inv_scale, idx, page_size):
        # Paged gather + dequant. Slot s lives at cache[s // page_size, s % page_size].
        # Negative (invalid) slots are clamped to 0 here and excluded by the
        # validity mask below.
        safe = idx.clamp(min=0)
        raw = cache[safe // page_size, safe % page_size]
        if cache.dtype == torch.uint8:
            raw = raw.view(torch.float8_e4m3fn)
        return (raw.to(torch.bfloat16) * inv_scale[safe]).to(torch.bfloat16)

    if topk_indices.dim() == 1:
        topk_indices = topk_indices.unsqueeze(-1)

    q = q.contiguous()
    swa_lens = swa_lens[:num_tokens]
    topk_lens = topk_lens[:num_tokens]
    topk_max = topk_indices.shape[-1]

    swa_len_max = max(0, min(int(swa_lens.max().item()), window_size))
    topk_len_max = max(0, min(int(topk_lens.max().item()), topk_max))
    if swa_len_max == 0 and topk_len_max == 0:
        return torch.zeros(num_tokens, NH, HD, dtype=torch.bfloat16, device=device)

    # --- SWA part: (T, window_size, HD) plus a validity mask -------------------
    swa_idx = swa_indices[:num_tokens, :swa_len_max]
    swa_kv = _gather_dequant(swa_kv_cache, swa_inv_scale, swa_idx, block_size)
    swa_pos = torch.arange(swa_idx.shape[1], device=device).unsqueeze(0)
    swa_valid = (swa_pos < swa_lens.unsqueeze(1)) & (swa_idx >= 0)
    swa_pad = window_size - swa_idx.shape[1]
    if swa_pad > 0:
        swa_kv = torch.cat([swa_kv, swa_kv.new_zeros(num_tokens, swa_pad, HD)], dim=1)
        swa_valid = torch.cat([swa_valid, swa_valid.new_zeros(num_tokens, swa_pad)], dim=1)

    # --- Compressed part: (T, topk_max, HD) plus a validity mask ---------------
    if topk_len_max > 0:
        topk_idx = topk_indices[:num_tokens, :topk_len_max]
        comp_kv = _gather_dequant(
            compressed_kv_cache, compressed_inv_scale, topk_idx, compressed_kv_cache.shape[1]
        )
        topk_pos = torch.arange(topk_idx.shape[1], device=device).unsqueeze(0)
        comp_valid = (topk_pos < topk_lens.unsqueeze(1)) & (topk_idx >= 0)
        comp_pad = topk_max - topk_idx.shape[1]
        if comp_pad > 0:
            comp_kv = torch.cat([comp_kv, comp_kv.new_zeros(num_tokens, comp_pad, HD)], dim=1)
            comp_valid = torch.cat([comp_valid, comp_valid.new_zeros(num_tokens, comp_pad)], dim=1)
    else:
        topk_max = 0
        comp_kv = swa_kv.new_zeros(num_tokens, 0, HD)
        comp_valid = swa_valid.new_zeros(num_tokens, 0)

    # The kernel treats rows [0, len) as valid. Move each token's valid rows to
    # the front, preserving their relative order via a stable sort on "invalid".
    kv_all = torch.cat([swa_kv, comp_kv], dim=1)
    valid = torch.cat([swa_valid, comp_valid], dim=1)
    order = torch.sort((~valid).to(torch.uint8), dim=1, stable=True).indices
    kv_combined = torch.gather(kv_all, 1, order.unsqueeze(-1).expand(-1, -1, HD)).contiguous()
    lens_dtype = torch.promote_types(swa_lens.dtype, topk_lens.dtype)
    combined_lens = valid.sum(dim=1).to(lens_dtype).contiguous()
    total_len = window_size + topk_max

    output = torch.zeros(num_tokens, NH, HD, dtype=torch.bfloat16, device=device)

    def to_cute(t):
        ct = cutlass_torch.from_dlpack(t)
        return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))

    scale_tensor = torch.tensor([scale], dtype=torch.float32, device=device)
    q_c = to_cute(q)
    kv_c = to_cute(kv_combined)
    len_c = to_cute(combined_lens)
    out_c = to_cute(output)
    scale_c = to_cute(scale_tensor)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    cache_key = (num_tokens, NH, HD, window_size, topk_max, compress_ratio, str(device))
    entry = _compiled_sparse_kernel_cache.get(cache_key)
    if entry is None:
        kernel = BlackwellSparseDecodeKernel(
            head_dim=HD, head_group=HEAD_GROUP, kv_tile=KV_TILE,
            total_len=total_len,
        )
        entry = {'compiled': cute.compile(kernel, q_c, kv_c, len_c, out_c, scale_c, stream)}
        _compiled_sparse_kernel_cache[cache_key] = entry

    # Single launch; the original also ran a redundant warm-up launch on first use.
    entry['compiled'](q_c, kv_c, len_c, out_c, scale_c, stream)
    return output
|
|
|
|
|
|
def _gather_dequant_kv(cache, inv_scale, indices, page_size):
    """Gather paged KV rows for ``indices`` and dequantize them to bfloat16.

    Slot ``s`` lives at ``cache[s // page_size, s % page_size]``. Negative
    (invalid) indices are clamped to slot 0; callers must mask them out.
    uint8 storage is reinterpreted as float8_e4m3fn before scaling by
    ``inv_scale[slot]``. ``inv_scale`` is presumably (num_slots, 1) or
    (num_slots, HD) so it broadcasts over the gathered rows -- TODO confirm.
    """
    safe = indices.clamp(min=0)
    raw = cache[safe // page_size, safe % page_size]
    if cache.dtype == torch.uint8:
        raw = raw.view(torch.float8_e4m3fn)
    return (raw.to(torch.bfloat16) * inv_scale[safe]).to(torch.bfloat16)


def _masked_attention_with_lse(q, kv, invalid, scale):
    """Single-query attention in float32 where ``kv`` is both keys and values.

    Args:
        q: (T, NH, HD) queries.
        kv: (T, L, HD) per-token KV rows.
        invalid: (T, L) bool, True for rows to ignore.
        scale: softmax scale.

    Returns:
        ``(out, lse)``: ``out`` is (T, NH, HD) float32 and ``lse`` is (T, NH)
        float32. A token whose rows are all invalid gets ``out == 0`` and
        ``lse == -inf`` instead of NaN.
    """
    qf = q.float()
    kvf = kv.float()
    # Batched over tokens: (T, NH, HD) @ (T, HD, L) -> (T, NH, L). No per-head KV copies.
    scores = torch.matmul(qf, kvf.transpose(1, 2)) * scale
    scores = scores.masked_fill(invalid.unsqueeze(1), float('-inf'))
    lse = torch.logsumexp(scores, dim=-1)
    # Fully masked rows have lse == -inf. Shift by 0 so exp(-inf) == 0 rather
    # than exp(-inf - -inf) == NaN.
    safe_lse = lse.masked_fill(torch.isneginf(lse), 0.0)
    probs = torch.exp(scores - safe_lse.unsqueeze(-1))
    return torch.matmul(probs, kvf), lse


def _fallback_sparse_sdp(
    q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
    compressed_kv_cache, compressed_inv_scale, topk_indices, topk_lens,
    attn_sink, block_size, scale, window_size,
):
    """Pure-PyTorch SWA + compressed-KV decode attention with sink-weighted merge.

    Computes attention separately over each token's SWA KV and its top-k
    compressed KV (K is used as V in both). The two results are combined as

        o = (exp(lse_sparse) * o_sparse + exp(attn_sink) * exp(lse_swa) * o_swa)
            / (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa))

    evaluated in log space for numerical stability. A token with no valid
    entries in one branch simply gets that branch's weight 0. A token with no
    valid entries at all gets a zero output.

    Args:
        q: (num_tokens, NH, HD) queries.
        swa_kv_cache, swa_inv_scale: paged SWA cache and per-slot dequant scale.
        swa_indices: (num_tokens, window) SWA slot indices; a leading batch
            dim of size 1 is squeezed. Negative entries are invalid.
        swa_lens: per-token count of leading SWA entries to use.
        compressed_kv_cache, compressed_inv_scale: paged compressed cache and
            per-slot scale. The page size is taken from ``compressed_kv_cache.shape[1]``.
        topk_indices: (num_tokens, K) compressed slot indices (1-D means K=1).
            Negative entries are invalid.
        topk_lens: per-token count of leading compressed entries to use.
        attn_sink: (NH,) per-head sink logits.
        block_size: SWA cache page size.
        scale: softmax scale.
        window_size: SWA positions at or beyond this are ignored.

    Returns:
        (num_tokens, NH, HD) bfloat16 output.
    """
    num_tokens, NH, HD = q.shape
    device = q.device
    if num_tokens == 0:
        return torch.zeros(0, NH, HD, dtype=torch.bfloat16, device=device)

    # --- SWA branch ---------------------------------------------------------
    if swa_indices.dim() == 3:
        swa_indices = swa_indices.squeeze(0)
    swa_idx = swa_indices[:num_tokens]
    swa_kv = _gather_dequant_kv(swa_kv_cache, swa_inv_scale, swa_idx, block_size)
    swa_pos = torch.arange(swa_idx.shape[1], device=device).unsqueeze(0)
    swa_invalid = (
        (swa_pos >= swa_lens[:num_tokens].unsqueeze(1))
        | (swa_pos >= window_size)
        | (swa_idx < 0)
    )
    o_swa, lse_swa = _masked_attention_with_lse(q, swa_kv, swa_invalid, scale)

    # --- Compressed branch --------------------------------------------------
    if topk_indices.dim() == 1:
        topk_indices = topk_indices.unsqueeze(-1)
    topk_idx = topk_indices[:num_tokens]
    topk_len = topk_lens[:num_tokens]
    # Skip the gather entirely when no token has compressed entries. This also
    # avoids indexing an empty compressed cache with clamped indices.
    if topk_idx.shape[1] > 0 and bool((topk_len > 0).any()):
        comp_kv = _gather_dequant_kv(
            compressed_kv_cache, compressed_inv_scale, topk_idx, compressed_kv_cache.shape[1]
        )
        topk_pos = torch.arange(topk_idx.shape[1], device=device).unsqueeze(0)
        comp_invalid = (topk_pos >= topk_len.unsqueeze(1)) | (topk_idx < 0)
        o_sparse, lse_sparse = _masked_attention_with_lse(q, comp_kv, comp_invalid, scale)
    else:
        o_sparse = torch.zeros(num_tokens, NH, HD, dtype=torch.float32, device=device)
        lse_sparse = torch.full((num_tokens, NH), float('-inf'), dtype=torch.float32, device=device)

    # --- Sink-weighted merge in log space -----------------------------------
    log_w_sparse = lse_sparse
    log_w_swa = lse_swa + attn_sink.to(device=device, dtype=torch.float32).unsqueeze(0)
    m = torch.maximum(log_w_sparse, log_w_swa)
    # Tokens with no valid KV in either branch: both weights become exp(-inf) = 0.
    m = m.masked_fill(torch.isneginf(m), 0.0)
    w_sparse = (log_w_sparse - m).exp().unsqueeze(-1)
    w_swa = (log_w_swa - m).exp().unsqueeze(-1)
    output = (w_sparse * o_sparse + w_swa * o_swa) / (w_sparse + w_swa).clamp(min=1e-30)
    return output.to(torch.bfloat16)
|
|
|
|
|
|
if HAS_CUTEDSL:

    class BlackwellSparseDecodeKernel:
        """CuTeDSL decode-attention kernel over a per-token KV buffer.

        The launch grid is (num_head_groups, num_tokens, 1). Each CTA handles
        ``head_group`` consecutive query heads of one token. It streams that
        token's KV rows through shared memory ``kv_tile`` rows at a time and
        uses an online-softmax accumulation. The same rows serve as keys and
        values. Only rows ``[0, mLens[token])`` are attended. Rows at or past
        that length are excluded from both the max and the normalizer.

        NOTE(review): ``tidx`` is never read, so every thread of the CTA
        executes the full computation and writes identical values to shared
        and global memory. The work is redundant, not parallelised.
        """

        def __init__(self, head_dim=HEAD_DIM, head_group=HEAD_GROUP,
                     kv_tile=KV_TILE, total_len=128):
            self._head_dim = head_dim
            self._head_group = head_group
            self._kv_tile = kv_tile
            # Upper bound on KV rows per token; sets the tile-loop trip count.
            self._total_len = total_len
            self._num_threads = NUM_THREADS

        @cute.jit
        def __call__(self, mQ, mKV, mLens, mO, mScale, stream):
            # One CTA per (head group, token).
            num_tokens = mQ.shape[0]
            num_head_groups = mQ.shape[1] // self._head_group
            self._kernel(mQ, mKV, mLens, mO, mScale).launch(
                grid=(num_head_groups, num_tokens, 1),
                block=[self._num_threads, 1, 1],
                stream=stream,
            )

        @cute.kernel
        def _kernel(self, mQ, mKV, mLens, mO, mScale):
            tidx, _, _ = cute.arch.thread_idx()
            hg_idx, tok_idx, _ = cute.arch.block_idx()

            HG = self._head_group
            HD = self._head_dim
            KT = self._kv_tile
            TL = self._total_len
            softmax_scale = mScale[0]

            @cute.struct
            class SharedStorage:
                kv_tile: cute.struct.MemRange[cutlass.BFloat16, KT * HD]

            smem = utils.SmemAllocator()
            storage = smem.allocate(SharedStorage)
            sKV = cute.make_tensor(
                storage.kv_tile.data_ptr(),
                cute.make_layout((KT, HD), stride=(HD, 1)),
            )

            # Number of leading valid KV rows for this token.
            kv_len = mLens[tok_idx]
            has_kv = kv_len > 0

            q_reg = cute.make_rmem_tensor((HG, HD), cutlass.BFloat16)
            for h in cutlass.range_constexpr(HG):
                qh = hg_idx * HG + h
                for d in range(HD):
                    q_reg[h, d] = mQ[tok_idx, qh, d]

            acc_O = cute.make_rmem_tensor((HG, HD), cutlass.Float32)
            acc_O.fill(0.0)
            row_max = cute.make_rmem_tensor((HG,), cutlass.Float32)
            row_sum = cute.make_rmem_tensor((HG,), cutlass.Float32)
            row_max.fill(-1e30)
            row_sum.fill(0.0)

            max_tiles = (TL + KT - 1) // KT
            for tile_idx in range(max_tiles):
                tile_start = tile_idx * KT
                # Stage one KV tile in shared memory; out-of-range rows are zero-filled.
                for kv_pos in range(KT):
                    global_kv = tile_start + kv_pos
                    for d in range(HD):
                        valid = global_kv < kv_len
                        val = cutlass.BFloat16(0.0)
                        if valid:
                            val = mKV[tok_idx, global_kv, d]
                        sKV[kv_pos, d] = val

                cute.arch.sync_threads()

                # Scaled q.k scores. Rows past kv_len get -1e30 so they never
                # raise the running max.
                scores = cute.make_rmem_tensor((HG, KT), cutlass.Float32)
                scores.fill(-1e30)
                for h in cutlass.range_constexpr(HG):
                    for kv_pos in range(KT):
                        dot = cutlass.Float32(0.0)
                        for d in range(HD):
                            q_val = q_reg[h, d].to(cutlass.Float32)
                            k_val = sKV[kv_pos, d].to(cutlass.Float32)
                            dot = dot + q_val * k_val
                        s_val = cutlass.Float32(-1e30)
                        is_valid = (tile_start + kv_pos) < kv_len
                        if is_valid:
                            s_val = dot * softmax_scale
                        scores[h, kv_pos] = s_val

                # Online-softmax update of row_max / row_sum / acc_O.
                for h in cutlass.range_constexpr(HG):
                    tile_max = cutlass.Float32(-1e30)
                    for kv_pos in range(KT):
                        s = scores[h, kv_pos]
                        if s > tile_max:
                            tile_max = s
                    new_max = row_max[h]
                    if tile_max > new_max:
                        new_max = tile_max
                    rescale = cutlass.Float32(0.0)
                    if row_max[h] > cutlass.Float32(-1e29):
                        rescale = cute.exp(row_max[h] - new_max)
                    for d in range(HD):
                        acc_O[h, d] = acc_O[h, d] * rescale
                    row_sum[h] = row_sum[h] * rescale
                    for kv_pos in range(KT):
                        # Masked rows contribute exactly zero. Previously they
                        # added exp(0 - max) to the normalizer.
                        exp_score = cutlass.Float32(0.0)
                        is_valid = (tile_start + kv_pos) < kv_len
                        if is_valid:
                            exp_score = cute.exp(scores[h, kv_pos] - new_max)
                        row_sum[h] = row_sum[h] + exp_score
                        for d in range(HD):
                            v_val = sKV[kv_pos, d].to(cutlass.Float32)
                            acc_O[h, d] = acc_O[h, d] + exp_score * v_val
                    row_max[h] = new_max

                cute.arch.sync_threads()

            # Normalize and store; tokens with no valid KV write zeros.
            for h in cutlass.range_constexpr(HG):
                qh = hg_idx * HG + h
                for d in range(HD):
                    val_f32 = cutlass.Float32(0.0)
                    if has_kv and row_sum[h] > cutlass.Float32(1e-30):
                        val_f32 = acc_O[h, d] / row_sum[h]
                    mO[tok_idx, qh, d] = val_f32.to(cutlass.BFloat16)
|