Files
nvfp4-megamoe-kernel/tests/archive/test_decode_pipeline.py
biondizzle 3fb3c925af Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

141 lines
6.0 KiB
Python

#!/usr/bin/env python3
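"""
Integration test: decode attention pipeline on Blackwell.
Follows the end-to-end path that _attention_impl_blackwell uses:
1. Project Q, KV (simulated with random tensors)
2. Apply RoPE to Q (in-place)
3. Write KV to paged cache (RoPE + fp8 quantize + insert)
4. Native SWA decode attention (CuTeDSL kernel)
5. Inverse RoPE on output (not exercised by this script)
6. wo_a + wo_b projections (not exercised by this script)
Steps 1-4 are run here; the output of step 4 is compared against a
pure-PyTorch reference path that dequantizes the same paged cache and calls
scaled_dot_product_attention.
"""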
import sys, torch, torch.nn.functional as F, math
sys.path.insert(0, "/root/dsv4-nvfp4-workspace/vllm")
sys.path.insert(0, "/root/dsv4-nvfp4-workspace/kernel")
from vllm.model_executor.layers.csa_attention import (
fused_qnorm_rope_kv_insert_py,
blackwell_attention_kv_write,
causal_prefill_attention,
kv_dequantize_fp8,
apply_gptj_rope,
apply_inv_gptj_rope,
)
from dsv4.ops.decode_swa import native_swa_decode_attention
# Fix the RNG seed so the random Q / KV inputs drawn below are reproducible.
torch.manual_seed(42)
# Make GPU 0 the current CUDA device; the decode inputs, caches and index
# tensors further down are all allocated with device="cuda:0".
torch.cuda.set_device(0)
# ── Model params (DeepSeek-V4) ──────────────────────────────────────
# Attention geometry shared by every stage of the test. q is allocated as
# (tokens, NH, HD) and kv as (tokens, HD) further down.
NH = 128
HD = 512
ROPE_DIM = 64
NOPE_DIM = HD - ROPE_DIM  # 448: the part of HD not covered by ROPE_DIM
BLOCK_SIZE = 256          # slots per block of the paged KV cache
WINDOW_SIZE = 128         # sliding-window length handed to the decode kernel
NUM_LAYERS = 61           # not referenced anywhere else in the visible script
SCALE = HD ** -0.5        # softmax scale, 1 / sqrt(HD)
EPS = 1e-6                # forwarded to fused_qnorm_rope_kv_insert_py
# ── Cos/sin cache ────────────────────────────────────────────────────
# Row p holds [cos(p * f_0..f_31), sin(p * f_0..f_31)] with
# f_k = 10000 ** (-2k / ROPE_DIM).
# NOTE(review): no device is given, so this table lives on PyTorch's default
# device in float32 while `positions` is created on cuda:0 — confirm the
# consumers move / cast it as needed.
MAX_POS = 4096
half_rope = ROPE_DIM // 2
inv_freq = 1.0 / (10000 ** (torch.arange(0, ROPE_DIM, 2).float() / ROPE_DIM))
t = torch.arange(MAX_POS).float()
freqs = torch.outer(t, inv_freq)  # (MAX_POS, ROPE_DIM // 2) rotation angles
cos_sin_cache = torch.cat([torch.cos(freqs), torch.sin(freqs)], dim=-1)  # (MAX_POS, ROPE_DIM)
# ── Simulate decode tokens ──────────────────────────────────────────
# Four independent decode steps, one token each, at fixed absolute positions.
num_decode_tokens = 4
positions = torch.tensor([100, 200, 300, 400], dtype=torch.int64, device="cuda:0")
# Random stand-ins for the projected Q and KV (post-norm, pre-RoPE).
# q is drawn before kv so the seeded RNG stream yields the same values.
_bf16_cuda = dict(dtype=torch.bfloat16, device="cuda:0")
q = torch.randn((num_decode_tokens, NH, HD), **_bf16_cuda).mul(0.1)
kv = torch.randn((num_decode_tokens, HD), **_bf16_cuda).mul(0.5)
# ── Apply RoPE to Q ─────────────────────────────────────────────────
# NOTE(review): the two None arguments and the trailing positional 0 are
# passed through as-is; their meaning is defined by the helper — confirm
# against its signature.
fused_qnorm_rope_kv_insert_py(
    q,
    kv,
    None,
    None,
    positions,
    cos_sin_cache,
    EPS,
    0,
    nope_dim=NOPE_DIM,
    rope_dim=ROPE_DIM,
)
# From here on q is treated as already rotated: the helper updates it in place.
# ── Create paged KV cache and write KV ──────────────────────────────
# Zero-initialised byte cache of num_blocks x BLOCK_SIZE slots, HD bytes per
# slot, plus one bf16 inverse-scale value per slot.
num_blocks = 8
max_slots = num_blocks * BLOCK_SIZE
swa_kv_cache = torch.zeros((num_blocks, BLOCK_SIZE, HD), dtype=torch.uint8, device="cuda:0")
swa_inv_scale = torch.zeros((max_slots, 1), dtype=torch.bfloat16, device="cuda:0")
# Slot mapping: each decode token gets a unique slot, and for simplicity the
# slot number is just the token's position.
slot_mapping = positions.to(dtype=torch.int64, device="cuda:0", copy=True)
blackwell_attention_kv_write(
    kv,
    positions,
    swa_kv_cache,
    swa_inv_scale,
    slot_mapping,
    BLOCK_SIZE,
    cos_sin_cache,
    nope_dim=NOPE_DIM,
    rope_dim=ROPE_DIM,
)
# ── Build SWA indices for decode ─────────────────────────────────────
# Token i may see positions 0..pos (inclusive), capped to the last WINDOW_SIZE
# of them. Row i of swa_indices lists those slots oldest-first; unused tail
# entries are -1 and swa_lens[i] is the number of valid entries.
window_offsets = torch.arange(WINDOW_SIZE, dtype=torch.int64, device="cuda:0")
swa_lens = (positions + 1).clamp(max=WINDOW_SIZE)
first_slot = positions - (swa_lens - 1)
candidate_slots = (first_slot.unsqueeze(1) + window_offsets).clamp(min=0)
is_padding = window_offsets.unsqueeze(0) >= swa_lens.unsqueeze(1)
swa_indices = candidate_slots.masked_fill(is_padding, -1)
# ── Native SWA decode attention ──────────────────────────────────────
o_native = native_swa_decode_attention(
    q,
    swa_kv_cache,
    swa_inv_scale,
    swa_indices,
    swa_lens,
    BLOCK_SIZE,
    SCALE,
    WINDOW_SIZE,
)
# ── Reference: full BF16 attention ──────────────────────────────────
# For every token: gather its window of cached rows, dequantize fp8 -> bf16
# with the per-slot inverse scale, and run plain SDPA with the single shared
# KV row set broadcast to all NH heads (K and V are the same tensor).
# NOTE(review): only the slots listed in slot_mapping appear to be written
# before this point, so most of each window is presumably still the zero
# initialisation — confirm that this is the intended coverage.
o_ref = torch.zeros_like(o_native)
for i, pos in enumerate(positions.tolist()):
    window_len = min(pos + 1, WINDOW_SIZE)
    slots = torch.arange(
        pos + 1 - window_len, pos + 1, dtype=torch.int64, device="cuda:0"
    ).clamp(min=0)
    # Split the flat slot id into (block, offset-in-block) for the paged cache.
    block_idx = torch.div(slots, BLOCK_SIZE, rounding_mode="floor")
    offsets = torch.remainder(slots, BLOCK_SIZE)
    # Reinterpret the stored bytes as fp8 (e4m3) before scaling back to bf16.
    fp8_rows = swa_kv_cache[block_idx, offsets].view(torch.float8_e4m3fn)
    kv_cached = (fp8_rows.to(torch.bfloat16) * swa_inv_scale[slots]).to(torch.bfloat16)
    qi_t = q[i:i + 1].permute(1, 0, 2)                  # (NH, 1, HD)
    kv_exp = kv_cached.unsqueeze(0).expand(NH, -1, -1)  # (NH, window_len, HD)
    out = F.scaled_dot_product_attention(qi_t, kv_exp, kv_exp, is_causal=False, scale=SCALE)
    o_ref[i] = out.squeeze(1)                           # (NH, HD)
# ── Compare ──────────────────────────────────────────────────────────
def _flat_cosine(a, b):
    """Return the cosine similarity of two tensors, each flattened to one fp32 vector."""
    return F.cosine_similarity(
        a.flatten().unsqueeze(0).float(),
        b.flatten().unsqueeze(0).float(),
    ).item()


# Whole-output agreement; the verdict is only printed, the exit status of the
# script is not affected.
cos = _flat_cosine(o_ref, o_native)
print(f"Full pipeline cosine (ref vs native): {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}")
# Per-token
for i, (ref_tok, native_tok) in enumerate(zip(o_ref, o_native)):
    ct = _flat_cosine(ref_tok, native_tok)
    print(f" Token {i} (pos={positions[i].item()}) cosine: {ct:.6f}")
# Check for NaN
print(f"NaN in native output: {torch.isnan(o_native).any()}")
print(f"Native amax: {o_native.amax():.4f}")