Files
nvfp4-megamoe-kernel/tests/test_decode_pipeline.py
biondizzle bbba289bd8 feat: GPU-native SWA + sparse decode attention kernels (CuTeDSL)
- native_swa_decode.py: BlackwellSWADecodeKernel
  - CTA mapping: 1 CTA per (decode_token, q_head_group)
  - Online softmax with KV tile streaming (16 tokens/tile)
  - Pre-dequantized bf16 KV (fp8 dequant on host - MLIR cvt_fpext
    requires 32-bit aligned vector, no scalar fp8->bf16 support)
  - Cosine 0.9999+ vs PyTorch batched SDPA reference
  - Fallback _fallback_batched_sdp when CuTeDSL unavailable

- native_sparse_decode.py: BlackwellSparseDecodeKernel
  - Combined SWA + compressed KV in single attention pass
  - Supports CSA (cr=4) and HCA (cr=128) layers
  - Sink weight merge on host side
  - Cosine 0.9999+ vs combined SDPA reference

- fp8_bf16.py: Documents MLIR limitation (cvt_fpext requires
  vector<4xf8>, no scalar support). Pre-dequant is the workaround.

- vLLM wiring (attention.py):
  - SWA-only layers: native_swa_decode_attention
  - CSA/HCA layers: native_sparse_decode_attention with topk + attn_sink
  - csa_attention.py updated to use native kernels

- Tests: test_decode_pipeline.py, test_sparse_decode.py both passing
2026-05-20 05:46:15 +00:00

141 lines
6.0 KiB
Python

#!/usr/bin/env python3
"""
Integration test: full decode attention pipeline on Blackwell.
Tests the end-to-end path that _attention_impl_blackwell uses:
1. Project Q, KV (simulated)
2. Apply RoPE to Q (in-place)
3. Write KV to paged cache (RoPE + fp8 quantize + insert)
4. Native SWA decode attention (CuTeDSL kernel)
5. Inverse RoPE on output
6. wo_a + wo_b projections
Compares against a pure-PyTorch reference path.
"""
import sys, torch, torch.nn.functional as F, math
sys.path.insert(0, "/root/dsv4-nvfp4-workspace/vllm")
sys.path.insert(0, "/root/dsv4-nvfp4-workspace/kernel")
from vllm.model_executor.layers.csa_attention import (
fused_qnorm_rope_kv_insert_py,
blackwell_attention_kv_write,
causal_prefill_attention,
kv_dequantize_fp8,
apply_gptj_rope,
apply_inv_gptj_rope,
)
from cutedsl.native_swa_decode import native_swa_decode_attention
# Fix the RNG seed so the random Q/KV tensors drawn later are reproducible.
torch.manual_seed(42)
# Make CUDA device 0 the current device; every tensor below is created on "cuda:0".
torch.cuda.set_device(0)
# ── Model params (DeepSeek-V4) ──────────────────────────────────────
NH = 128                    # number of query heads (second axis of q)
NOPE_DIM = 448              # head channels forwarded as nope_dim
ROPE_DIM = 64               # head channels forwarded as rope_dim
HD = NOPE_DIM + ROPE_DIM    # full head dim: 448 + 64 = 512
BLOCK_SIZE = 256            # slots per block of the paged KV cache
WINDOW_SIZE = 128           # sliding-window length used for decode
NUM_LAYERS = 61             # not referenced by this script
SCALE = HD ** -0.5          # softmax scale, 1/sqrt(HD)
EPS = 1e-6                  # epsilon forwarded to the fused q-norm/RoPE helper

# ── Cos/sin cache ────────────────────────────────────────────────────
MAX_POS = 4096
half_rope = ROPE_DIM // 2   # not referenced by this script

# One inverse frequency per rotary pair: 10000^(-2k / ROPE_DIM), k = 0..ROPE_DIM/2-1.
rope_exponents = torch.arange(0, ROPE_DIM, 2).float() / ROPE_DIM
inv_freq = 1.0 / (10000 ** rope_exponents)

# Rotation angle for every (position, frequency) pair.
position_ids = torch.arange(MAX_POS).float()
angles = torch.outer(position_ids, inv_freq)

# Row p holds [cos(angles[p]), sin(angles[p])] -> shape (MAX_POS, ROPE_DIM).
# The table is float32 and lives on the default (CPU) device.
cos_sin_cache = torch.cat((angles.cos(), angles.sin()), dim=-1)
# ── Simulate decode tokens ──────────────────────────────────────────
# Four decode tokens at widely spaced absolute positions.
positions = torch.tensor([100, 200, 300, 400], dtype=torch.int64, device="cuda:0")
num_decode_tokens = positions.numel()

# Random stand-ins for the projected Q and KV (post-norm, pre-RoPE).
# q is drawn before kv so the seeded RNG stream yields the same values.
q_shape = (num_decode_tokens, NH, HD)
kv_shape = (num_decode_tokens, HD)
q = 0.1 * torch.randn(q_shape, dtype=torch.bfloat16, device="cuda:0")
kv = 0.5 * torch.randn(kv_shape, dtype=torch.bfloat16, device="cuda:0")

# ── Apply RoPE to Q ─────────────────────────────────────────────────
# Arguments stay positional because the helper's signature is not visible here.
# NOTE(review): the two None values presumably disable the KV-insert half of
# the fused helper, and the trailing 0 is an opaque positional flag — confirm
# against the helper's definition.
fused_qnorm_rope_kv_insert_py(
    q, kv,
    None, None,
    positions, cos_sin_cache,
    EPS, 0,
    nope_dim=NOPE_DIM, rope_dim=ROPE_DIM,
)
# q has been rotated in place at this point.
# ── Create paged KV cache and write KV ──────────────────────────────
num_blocks = 8
# Byte-typed cache storage: num_blocks blocks of BLOCK_SIZE slots, HD bytes per slot.
swa_kv_cache = torch.zeros(
    num_blocks, BLOCK_SIZE, HD, dtype=torch.uint8, device="cuda:0",
)
max_slots = num_blocks * BLOCK_SIZE
# One bf16 scale entry per flat slot (flat slot = block * BLOCK_SIZE + offset).
swa_inv_scale = torch.zeros(max_slots, 1, dtype=torch.bfloat16, device="cuda:0")

# Slot mapping: each decode token gets a unique slot, and for simplicity the
# slot equals the token's position. Copying the int64 position tensor gives
# the same mapping as filling it element by element, without one host/device
# round-trip per token.
slot_mapping = positions.clone()

# Write the new tokens' KV into the paged cache. Per the module docstring this
# applies RoPE and fp8 quantisation before the insert.
blackwell_attention_kv_write(
    kv, positions, swa_kv_cache, swa_inv_scale,
    slot_mapping, BLOCK_SIZE, cos_sin_cache,
    nope_dim=NOPE_DIM, rope_dim=ROPE_DIM,
)
# ── Build SWA indices for decode ─────────────────────────────────────
# Each decode token attends to the last WINDOW_SIZE positions up to and
# including its own. Built with tensor ops instead of a per-element Python
# loop, which issued one tiny GPU write per (token, window slot) pair.

# A token at position p can see positions 0..p inclusive, capped at the window.
swa_lens = torch.clamp(positions + 1, max=WINDOW_SIZE)

# Column j of a row holds slot (p - (len - 1) + j) while j < len, else -1.
window_offsets = torch.arange(WINDOW_SIZE, dtype=torch.int64, device="cuda:0")
first_slot = positions - (swa_lens - 1)
candidate_slots = first_slot.unsqueeze(1) + window_offsets.unsqueeze(0)
in_window = window_offsets.unsqueeze(0) < swa_lens.unsqueeze(1)

# clamp(min=0) mirrors the original max(0, slot) guard; entries past the
# valid length are padded with -1.
swa_indices = torch.where(
    in_window,
    candidate_slots.clamp(min=0),
    torch.full_like(candidate_slots, -1),
)
# ── Native SWA decode attention ──────────────────────────────────────
# Run the kernel under test on the RoPE'd queries and the paged cache.
# Inputs: queries, cache bytes + per-slot scales, per-token slot indices and
# valid lengths, then block size, softmax scale and window size.
# NOTE(review): arguments are positional and the callee's signature is not
# visible from this file — confirm the order against native_swa_decode.py.
# The result is later used as one (NH, HD) slice per decode token.
o_native = native_swa_decode_attention(
    q, swa_kv_cache, swa_inv_scale,
    swa_indices, swa_lens,
    BLOCK_SIZE, SCALE, WINDOW_SIZE,
)
# ── Reference: full BF16 attention ──────────────────────────────────
# For each decode token: gather its window from the paged cache, dequantise
# to bf16, and attend with PyTorch SDPA.
o_ref = torch.zeros_like(o_native)

# Pull the positions to the host once rather than calling .item() several
# times per iteration.
for i, pos in enumerate(positions.tolist()):
    num_cached = min(pos + 1, WINDOW_SIZE)

    # Flat slots of the visible window, oldest first, ending at the token itself.
    slots = torch.arange(pos - num_cached + 1, pos + 1, dtype=torch.int64, device="cuda:0")
    slots = slots.clamp(min=0)
    block_idx = slots // BLOCK_SIZE
    offsets = slots % BLOCK_SIZE

    # Reinterpret the stored bytes as fp8 (e4m3) and apply the per-slot scale.
    kv_cached_raw = swa_kv_cache[block_idx, offsets].view(torch.float8_e4m3fn)
    inv_s = swa_inv_scale[slots]
    kv_cached = (kv_cached_raw.to(torch.bfloat16) * inv_s).to(torch.bfloat16)

    qi = q[i:i+1]                # (1, NH, HD)
    qi_t = qi.permute(1, 0, 2)   # (NH, 1, HD)
    # The same cached rows serve as both K and V, shared across all heads.
    kv_exp = kv_cached.unsqueeze(0).expand(NH, -1, -1)
    out = F.scaled_dot_product_attention(qi_t, kv_exp, kv_exp, is_causal=False, scale=SCALE)
    o_ref[i] = out.permute(1, 0, 2).squeeze(0)
# ── Compare ──────────────────────────────────────────────────────────
def _flat_cosine(a, b):
    """Return the cosine similarity of two tensors, flattened and cast to float32."""
    return F.cosine_similarity(
        a.flatten().unsqueeze(0).float(),
        b.flatten().unsqueeze(0).float(),
    ).item()


cos = _flat_cosine(o_ref, o_native)
# A NaN cosine compares False here, so NaN output is reported as FAIL.
passed = cos >= 0.99
verdict = "PASS" if passed else "FAIL"
print(f"Full pipeline cosine (ref vs native): {cos:.6f} {verdict}")

# Per-token
for i, pos in enumerate(positions.tolist()):
    ct = _flat_cosine(o_ref[i], o_native[i])
    print(f" Token {i} (pos={pos}) cosine: {ct:.6f}")

# Check for NaN
print(f"NaN in native output: {torch.isnan(o_native).any()}")
print(f"Native amax: {o_native.amax():.4f}")

# Signal failure through the exit status, after all diagnostics are printed,
# so automation running this script does not have to parse its output.
if not passed:
    raise SystemExit(1)