#!/usr/bin/env python3
"""
Integration test: decode attention pipeline on Blackwell.

Exercises the front half of the path that _attention_impl_blackwell uses:
  1. Project Q, KV (simulated with random tensors)
  2. Apply RoPE to Q (in-place)
  3. Write KV to paged cache (RoPE + fp8 quantize + insert)
  4. Native SWA decode attention (CuTeDSL kernel)

The remaining production steps -- inverse RoPE on the attention output and
the wo_a + wo_b projections -- are NOT exercised by this script.

The native attention output is compared against a pure-PyTorch float32
attention computed over the same paged fp8 cache. Because the reference
dequantizes the cache written in step 3, this checks the decode kernel
(step 4) but does not independently validate the KV write itself.

Exits with status 1 if the cosine check fails or the native output has NaNs.
"""
import math
import sys

import torch
import torch.nn.functional as F

sys.path.insert(0, "/root/dsv4-nvfp4-workspace/vllm")
sys.path.insert(0, "/root/dsv4-nvfp4-workspace/kernel")

# Project imports must follow the sys.path setup above.
# causal_prefill_attention, kv_dequantize_fp8, apply_gptj_rope and
# apply_inv_gptj_rope are currently unused by this script.
from vllm.model_executor.layers.csa_attention import (  # noqa: E402
    fused_qnorm_rope_kv_insert_py,
    blackwell_attention_kv_write,
    causal_prefill_attention,
    kv_dequantize_fp8,
    apply_gptj_rope,
    apply_inv_gptj_rope,
)
from cutedsl.native_swa_decode import native_swa_decode_attention  # noqa: E402

# ── Model params (DeepSeek-V4) ──────────────────────────────────────
NH = 128
HD = 512
NOPE_DIM = 448
ROPE_DIM = 64
BLOCK_SIZE = 256
WINDOW_SIZE = 128
NUM_LAYERS = 61
SCALE = HD ** -0.5
EPS = 1e-6

# ── Test setup ──────────────────────────────────────────────────────
MAX_POS = 4096
DEVICE = torch.device("cuda:0")
COS_THRESHOLD = 0.99


def _build_cos_sin_cache(
    max_pos: int, rope_dim: int, base: float = 10000.0
) -> torch.Tensor:
    """Return a float32 (max_pos, rope_dim) RoPE table laid out as [cos | sin].

    Built on the CPU (the caller moves it to the target device) so values
    match a host-side computation bit-for-bit.
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, rope_dim, 2).float() / rope_dim))
    t = torch.arange(max_pos).float()
    freqs = torch.outer(t, inv_freq)  # (max_pos, rope_dim // 2)
    return torch.cat([freqs.cos(), freqs.sin()], dim=-1)  # (max_pos, rope_dim)


def _build_swa_indices(
    positions: torch.Tensor, window_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build sliding-window slot indices for each decode token.

    Assumes slot == position (as set up in ``main``). For token ``i`` the
    window covers the last ``lens[i] = min(positions[i] + 1, window_size)``
    positions up to and including ``positions[i]``, oldest first.

    Returns:
        indices: int64 (T, window_size); entries past ``lens[i]`` are -1.
        lens: int64 (T,) number of valid entries per token.
    """
    lens = (positions + 1).clamp(max=window_size)  # (T,)
    offs = torch.arange(window_size, dtype=torch.int64, device=positions.device)
    # First visible slot is positions - lens + 1, which is always >= 0
    # because lens <= positions + 1, so no lower clamp is needed.
    slots = (positions - lens + 1)[:, None] + offs[None, :]  # (T, W)
    indices = slots.masked_fill(offs[None, :] >= lens[:, None], -1)
    return indices, lens


def _reference_decode(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    inv_scale: torch.Tensor,
    positions: torch.Tensor,
    window_size: int,
    block_size: int,
    scale: float,
) -> torch.Tensor:
    """Float32 reference attention over the dequantized paged fp8 cache.

    Each token attends over the slots of its window (slot == position). The
    cached bytes are reinterpreted as float8_e4m3fn and multiplied by the
    per-slot ``inv_scale``. One KV row is shared by every query head and
    serves as both key and value.

    Returns a float32 tensor shaped like ``q`` (T, NH, HD).
    """
    num_heads = q.shape[1]
    out = torch.empty(q.shape, dtype=torch.float32, device=q.device)
    for i, pos in enumerate(positions.tolist()):
        n = min(pos + 1, window_size)
        slots = torch.arange(pos - n + 1, pos + 1, dtype=torch.int64, device=q.device)
        raw = kv_cache[slots // block_size, slots % block_size].view(torch.float8_e4m3fn)
        kv_i = raw.float() * inv_scale[slots].float()  # (n, HD)
        kv_i = kv_i.unsqueeze(0).expand(num_heads, -1, -1)  # (NH, n, HD)
        q_i = q[i].float().unsqueeze(1)  # (NH, 1, HD)
        attn = F.scaled_dot_product_attention(q_i, kv_i, kv_i, is_causal=False, scale=scale)
        out[i] = attn.squeeze(1)  # (NH, HD)
    return out


def main() -> int:
    """Run the decode pipeline, compare against the reference, return exit code."""
    torch.manual_seed(42)
    torch.cuda.set_device(DEVICE)

    # RoPE table must live on the same device as the positions/tensors that
    # index it; a CPU table cannot be indexed by CUDA positions or read by a
    # device kernel.
    cos_sin_cache = _build_cos_sin_cache(MAX_POS, ROPE_DIM).to(DEVICE)

    # ── Simulate decode tokens ──────────────────────────────────────
    positions = torch.tensor([100, 200, 300, 400], dtype=torch.int64, device=DEVICE)
    num_decode_tokens = positions.numel()

    num_blocks = 8
    max_slots = num_blocks * BLOCK_SIZE
    max_pos = int(positions.max())
    if max_pos >= min(max_slots, MAX_POS):
        raise ValueError(
            f"position {max_pos} does not fit the cache ({max_slots} slots) "
            f"or the RoPE table ({MAX_POS} positions)"
        )

    # Q and KV (post-norm, pre-RoPE).
    q = torch.randn(num_decode_tokens, NH, HD, dtype=torch.bfloat16, device=DEVICE) * 0.1
    kv = torch.randn(num_decode_tokens, HD, dtype=torch.bfloat16, device=DEVICE) * 0.5

    # ── Apply RoPE to Q (in-place) ──────────────────────────────────
    # NOTE(review): whether this call also modifies kv in place is not
    # visible from this file; kv is RoPE'd again by the cache write below,
    # so confirm it is left untouched here.
    fused_qnorm_rope_kv_insert_py(
        q, kv, None, None, positions, cos_sin_cache, EPS, 0,
        nope_dim=NOPE_DIM,
        rope_dim=ROPE_DIM,
    )

    # ── Create paged KV cache and write KV ──────────────────────────
    swa_kv_cache = torch.zeros(
        num_blocks, BLOCK_SIZE, HD, dtype=torch.uint8, device=DEVICE,
    )
    swa_inv_scale = torch.zeros(max_slots, 1, dtype=torch.bfloat16, device=DEVICE)
    slot_mapping = positions.clone()  # slot = position for simplicity

    # NOTE: only the decode tokens' own KV is written. Every other slot in
    # each window stays zero-filled, so most keys the kernel sees are zeros.
    blackwell_attention_kv_write(
        kv, positions, swa_kv_cache, swa_inv_scale, slot_mapping,
        BLOCK_SIZE, cos_sin_cache, nope_dim=NOPE_DIM, rope_dim=ROPE_DIM,
    )

    # ── Native SWA decode attention ─────────────────────────────────
    swa_indices, swa_lens = _build_swa_indices(positions, WINDOW_SIZE)
    o_native = native_swa_decode_attention(
        q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
        BLOCK_SIZE, SCALE, WINDOW_SIZE,
    )

    # ── Reference: float32 attention over the same cache ────────────
    o_ref = _reference_decode(
        q, swa_kv_cache, swa_inv_scale, positions, WINDOW_SIZE, BLOCK_SIZE, SCALE,
    )

    # ── Compare ─────────────────────────────────────────────────────
    o_nat = o_native.float()
    cos = F.cosine_similarity(o_ref.flatten(), o_nat.flatten(), dim=0).item()
    cos_ok = cos >= COS_THRESHOLD  # False for NaN as well
    print(f"Full pipeline cosine (ref vs native): {cos:.6f} {'PASS' if cos_ok else 'FAIL'}")

    per_token = F.cosine_similarity(
        o_ref.reshape(num_decode_tokens, -1), o_nat.reshape(num_decode_tokens, -1), dim=1
    )
    for i, (pos, ct) in enumerate(zip(positions.tolist(), per_token.tolist())):
        print(f"  Token {i} (pos={pos}) cosine: {ct:.6f}")

    has_nan = bool(torch.isnan(o_native).any())
    print(f"NaN in native output: {has_nan}")
    print(f"Native amax: {o_native.amax().item():.4f}")

    return 0 if cos_ok and not has_nan else 1


if __name__ == "__main__":
    sys.exit(main())