- native_swa_decode.py: BlackwellSWADecodeKernel
  - CTA mapping: one CTA per (decode_token, q_head_group)
  - Online softmax with streamed KV tiles (16 tokens per tile)
  - KV is pre-dequantized to bf16 (fp8 dequant happens on the host, because MLIR `cvt_fpext`
    requires a 32-bit-aligned vector and there is no scalar fp8->bf16 conversion)
  - Cosine similarity 0.9999+ vs. a batched PyTorch SDPA reference
  - Falls back to `_fallback_batched_sdp` when CuTeDSL is unavailable
- native_sparse_decode.py: BlackwellSparseDecodeKernel
  - Combines SWA and compressed KV in a single attention pass
  - Supports CSA (cr=4) and HCA (cr=128) layers
  - Attention-sink weights are merged on the host side
  - Cosine similarity 0.9999+ vs. a combined SDPA reference
- fp8_bf16.py: documents the MLIR limitation (`cvt_fpext` requires
  `vector<4xf8>`; there is no scalar support). Pre-dequantization is the workaround.
- vLLM wiring (attention.py):
  - SWA-only layers: `native_swa_decode_attention`
  - CSA/HCA layers: `native_sparse_decode_attention` with topk + `attn_sink`
  - csa_attention.py updated to use the native kernels
- Tests: test_decode_pipeline.py and test_sparse_decode.py both pass
141 lines
6.0 KiB
Python
#!/usr/bin/env python3
"""
Integration test: decode attention pipeline on Blackwell.

Exercises the decode-time path used by _attention_impl_blackwell, up to and
including the attention kernel:

1. Simulate projected Q and KV with random tensors (post-norm, pre-RoPE).
2. Apply RoPE to Q in place via fused_qnorm_rope_kv_insert_py. It is called
   with None for its third and fourth arguments; presumably that skips the
   KV-cache insert, so only the Q side is exercised there (TODO confirm).
3. Write KV into the paged SWA cache via blackwell_attention_kv_write
   (RoPE + fp8 quantize + insert).
4. Run native SWA decode attention (CuTeDSL kernel).

The native output is compared against a pure-PyTorch reference. The
reference reads the same paged cache, dequantizes it, and runs SDPA for each
decode token.

Not exercised by this script: inverse RoPE on the attention output and the
wo_a / wo_b output projections of _attention_impl_blackwell.
"""

import math  # noqa: F401  (not used below; kept from the original import list)
import os
import sys

import torch
import torch.nn.functional as F

# Make the workspace checkouts importable. The root can be overridden with the
# DSV4_WORKSPACE environment variable; the default is the original location.
# "kernel" is inserted last, so it ends up ahead of "vllm" on sys.path.
_WORKSPACE = os.environ.get("DSV4_WORKSPACE", "/root/dsv4-nvfp4-workspace")
sys.path.insert(0, os.path.join(_WORKSPACE, "vllm"))
sys.path.insert(0, os.path.join(_WORKSPACE, "kernel"))

# Several of these helpers are not called below. Importing them still fails
# fast if the module's public surface changes (presumably intentional; confirm).
from vllm.model_executor.layers.csa_attention import (  # noqa: E402, F401
    fused_qnorm_rope_kv_insert_py,
    blackwell_attention_kv_write,
    causal_prefill_attention,
    kv_dequantize_fp8,
    apply_gptj_rope,
    apply_inv_gptj_rope,
)
from cutedsl.native_swa_decode import native_swa_decode_attention  # noqa: E402

torch.manual_seed(42)
torch.cuda.set_device(0)

# ── Model params (DeepSeek-V4) ──────────────────────────────────────
NH = 128            # query heads; Q is (tokens, NH, HD)
HD = 512            # per-head width, also the width of each KV row
NOPE_DIM = 448      # non-rotary slice of each head
ROPE_DIM = 64       # rotary slice of each head
BLOCK_SIZE = 256    # tokens per paged-cache block
WINDOW_SIZE = 128   # sliding-window length used for SWA decode
NUM_LAYERS = 61     # not used by this script
SCALE = HD ** -0.5  # softmax scale, 1/sqrt(HD)
EPS = 1e-6          # epsilon passed to fused_qnorm_rope_kv_insert_py

# The nope/rope split is passed alongside full-width HD tensors, so the two
# parts must tile the head exactly.
if NOPE_DIM + ROPE_DIM != HD:
    raise ValueError(
        f"NOPE_DIM + ROPE_DIM must equal HD, got {NOPE_DIM} + {ROPE_DIM} != {HD}"
    )

# ── Cos/sin cache ────────────────────────────────────────────────────
MAX_POS = 4096
half_rope = ROPE_DIM // 2  # number of rotary frequencies (pairs of dims)
freqs = 1.0 / (10000 ** (torch.arange(0, ROPE_DIM, 2).float() / ROPE_DIM))
t = torch.arange(MAX_POS).float()
freqs = torch.outer(t, freqs)  # (MAX_POS, half_rope): rotation angle per position
# Layout [cos | sin] along the last dim -> (MAX_POS, ROPE_DIM). The values are
# computed on the CPU exactly as before and then moved to the device that holds
# positions/q/kv. Position-indexed lookups into this cache therefore never
# depend on implicit cross-device indexing.
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1).to("cuda:0")

# ── Simulate decode tokens ──────────────────────────────────────────
# One decode token at each of these absolute positions.
positions = torch.tensor([100, 200, 300, 400], dtype=torch.int64, device="cuda:0")
num_decode_tokens = len(positions)

# Stand-ins for the projected Q and KV (post-norm, pre-RoPE).
# Q is drawn before KV so the seeded RNG stream is consumed in the same order.
q = 0.1 * torch.randn(num_decode_tokens, NH, HD, dtype=torch.bfloat16, device="cuda:0")
kv = 0.5 * torch.randn(num_decode_tokens, HD, dtype=torch.bfloat16, device="cuda:0")

# ── Apply RoPE to Q ─────────────────────────────────────────────────
# The positional arguments are q, kv, two None slots, positions, the cos/sin
# cache, EPS and 0. The meaning of the None slots and of the trailing 0 is
# defined by the helper; presumably they disable the cache insert and select
# an index/offset (TODO confirm against csa_attention).
fused_qnorm_rope_kv_insert_py(
    q,
    kv,
    None,
    None,
    positions,
    cos_sin_cache,
    EPS,
    0,
    nope_dim=NOPE_DIM,
    rope_dim=ROPE_DIM,
)
# From here on q holds RoPE'd values, updated in place by the helper. Given the
# EPS argument it may also be normalized (confirm).

# ── Create paged KV cache and write KV ──────────────────────────────
# Paged SWA cache: fp8 bytes stored as uint8 in (num_blocks, BLOCK_SIZE, HD),
# plus one bf16 inverse scale per slot.
num_blocks = 8
max_slots = num_blocks * BLOCK_SIZE
swa_kv_cache = torch.zeros(num_blocks, BLOCK_SIZE, HD, dtype=torch.uint8, device="cuda:0")
swa_inv_scale = torch.zeros(max_slots, 1, dtype=torch.bfloat16, device="cuda:0")

# Slot mapping: slot == position, for simplicity. This identity mapping is only
# valid while every position fits inside the cache.
max_position = int(positions.max())
if max_position >= max_slots:
    raise ValueError(
        f"position {max_position} does not fit in {max_slots} cache slots; "
        "increase num_blocks"
    )
slot_mapping = positions.clone()

# Pre-fill every earlier (non-decode) position with KV. Otherwise almost every
# slot inside each decode token's window would never be written (zero bytes
# with a zero inverse scale). Both the kernel and the reference would then
# attend mostly over zero vectors, and the comparison would exercise very
# little real data.
all_positions = torch.arange(max_position + 1, dtype=torch.int64, device="cuda:0")
history_positions = all_positions[~torch.isin(all_positions, positions)]
history_kv = 0.5 * torch.randn(
    history_positions.numel(), HD, dtype=torch.bfloat16, device="cuda:0"
)
blackwell_attention_kv_write(
    history_kv, history_positions, swa_kv_cache, swa_inv_scale,
    history_positions.clone(), BLOCK_SIZE, cos_sin_cache,
    nope_dim=NOPE_DIM, rope_dim=ROPE_DIM,
)

# Write the decode tokens' own KV into their slots.
blackwell_attention_kv_write(
    kv, positions, swa_kv_cache, swa_inv_scale,
    slot_mapping, BLOCK_SIZE, cos_sin_cache,
    nope_dim=NOPE_DIM, rope_dim=ROPE_DIM,
)

# ── Build SWA indices for decode ─────────────────────────────────────
# Decode token i sees the last min(pos_i + 1, WINDOW_SIZE) positions, ending at
# pos_i inclusive. The KV write uses slot == position, so those positions are
# also the cache slots. Entries past swa_lens[i] are padded with -1.
# This is built with tensor ops instead of a per-element Python loop. The loop
# issued one tiny device write, plus a host sync via .item(), for each of the
# num_decode_tokens * WINDOW_SIZE entries.
window_offsets = torch.arange(WINDOW_SIZE, dtype=torch.int64, device="cuda:0")
swa_lens = torch.clamp(positions + 1, max=WINDOW_SIZE)   # (T,), int64
window_start = positions + 1 - swa_lens                   # first visible position, always >= 0
swa_indices = window_start.unsqueeze(1) + window_offsets.unsqueeze(0)  # (T, WINDOW_SIZE)
swa_indices.masked_fill_(window_offsets.unsqueeze(0) >= swa_lens.unsqueeze(1), -1)

# ── Native SWA decode attention ──────────────────────────────────────
o_native = native_swa_decode_attention(
    q, swa_kv_cache, swa_inv_scale,
    swa_indices, swa_lens,
    BLOCK_SIZE, SCALE, WINDOW_SIZE,
)

# ── Reference: full BF16 attention ──────────────────────────────────
# For each decode token:
#   1. Gather its window straight from the paged cache.
#   2. Dequantize it (fp8 bytes * per-slot inverse scale).
#   3. Run SDPA using the same rows as both K and V, shared by all NH query heads.
o_ref = torch.zeros_like(o_native)
for tok in range(num_decode_tokens):
    end = int(positions[tok])               # last visible position, inclusive
    start = max(end - WINDOW_SIZE + 1, 0)   # window holds min(end + 1, WINDOW_SIZE) slots
    window_slots = torch.arange(start, end + 1, dtype=torch.int64, device="cuda:0")

    # slot -> (block, offset within block)
    blk = window_slots // BLOCK_SIZE
    off = window_slots % BLOCK_SIZE
    fp8_rows = swa_kv_cache[blk, off].view(torch.float8_e4m3fn)
    kv_window = (fp8_rows.to(torch.bfloat16) * swa_inv_scale[window_slots]).to(torch.bfloat16)

    heads_q = q[tok:tok + 1].permute(1, 0, 2)             # (NH, 1, HD)
    heads_kv = kv_window.unsqueeze(0).expand(NH, -1, -1)  # (NH, n, HD)
    attn = F.scaled_dot_product_attention(
        heads_q, heads_kv, heads_kv, is_causal=False, scale=SCALE
    )
    o_ref[tok] = attn.squeeze(1)                          # (NH, HD)

# ── Compare ──────────────────────────────────────────────────────────
# The cosine-similarity gate between reference and native output is applied to
# the whole output and to every decode token. NaNs make the cosine NaN, which
# fails the >= test. Any failure exits non-zero so the script can gate CI
# instead of only printing FAIL.
COS_THRESHOLD = 0.99
failures = []

ref_f32 = o_ref.float()
native_f32 = o_native.float()

cos = F.cosine_similarity(
    ref_f32.flatten().unsqueeze(0), native_f32.flatten().unsqueeze(0)
).item()
full_ok = cos >= COS_THRESHOLD
print(f"Full pipeline cosine (ref vs native): {cos:.6f} {'PASS' if full_ok else 'FAIL'}")
if not full_ok:
    failures.append(f"full cosine {cos:.6f} < {COS_THRESHOLD}")

# Per-token
for i in range(num_decode_tokens):
    ct = F.cosine_similarity(
        ref_f32[i].flatten().unsqueeze(0), native_f32[i].flatten().unsqueeze(0)
    ).item()
    pos = positions[i].item()
    token_ok = ct >= COS_THRESHOLD
    print(f" Token {i} (pos={pos}) cosine: {ct:.6f}{'' if token_ok else ' FAIL'}")
    if not token_ok:
        failures.append(f"token {i} (pos={pos}) cosine {ct:.6f} < {COS_THRESHOLD}")

# Check for NaN
has_nan = bool(torch.isnan(o_native).any())
print(f"NaN in native output: {has_nan}")
if has_nan:
    failures.append("NaN in native output")

print(f"Native amax: {float(o_native.amax()):.4f}")

if failures:
    sys.exit("FAIL: " + "; ".join(failures))