Files
nvfp4-megamoe-kernel/dsv4/reference/attention.py
biondizzle 3fb3c925af Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

248 lines
9.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
DeepSeek-V4 Blackwell attention — our own PyTorch reference implementation.
Replaces vLLM's broken FlashMLA Blackwell path with a proper KV cache-based
attention pipeline. Does NOT depend on FlashMLA, fp8_ds_mla, or any vLLM
fused CUDA kernel.
Architecture:
- KV: (T, HD=512) single head latent, shared across all 128 Q heads
- KV Cache: fp8_e4m3 paged cache with per-token inverse scale
- RoPE: GPT-J style, applied to Q and KV before caching
- Attention: BF16 (NVFP4 is too lossy for Q×K^T, cosine 0.86)
- CSA/HCA: Compressed KV for sparse attention (compress_ratio 4 or 128)
- SWA: Sliding window attention (compress_ratio 0/1)
Pipeline:
Prefill:
1. hidden → q_a_proj → q_norm → q_b_proj → (T, NH, HD) → RoPE on Q
2. hidden → kv_proj → kv_norm → (T, HD) → RoPE → fp8 quant → write to paged cache
3. BF16 causal attention over the new tokens' KV (used directly from this forward pass, not re-read from the cache) → output
Decode:
1. Same projections as prefill
2. Write new KV to cache
3. Read the cached KV inside the sliding window (the last window_size positions) → BF16 attention (1 query vs N ≤ window_size KVs) → output
Output:
1. inverse RoPE on attention output
2. o_a: BMM with wo_a (BF16)
3. o_b: NVFP4 GEMM with wo_b
"""
import torch
import torch.nn.functional as F
def _gptj_rotate(x, positions, cos_sin_cache, nope_dim, rope_dim, sin_sign):
    """Rotate the interleaved (even, odd) pairs of x's RoPE slice.

    The RoPE slice is ``x[..., nope_dim : nope_dim + 2 * (rope_dim // 2)]``;
    every other channel of ``x`` is copied through unchanged. ``sin_sign`` is
    +1 for the forward rotation and -1 for the inverse rotation. Returns a
    new tensor; ``x`` is not modified.
    """
    half = rope_dim // 2
    # cos_sin_cache rows are laid out as [cos(half) | sin(half)].
    cos = cos_sin_cache[positions, :half].to(x.dtype)
    sin = cos_sin_cache[positions, half:2 * half].to(x.dtype)
    if sin_sign < 0:
        # Rotating by -theta: cos is even, sin is odd.
        sin = -sin
    # Insert singleton dims between the per-token leading dims and the last
    # dim so the tables broadcast over e.g. the head axis of (T, NH, HD).
    lead = tuple(cos.shape[:-1])
    n_extra = max(x.dim() - 1 - len(lead), 0)
    bcast_shape = lead + (1,) * n_extra + (half,)
    cos = cos.reshape(bcast_shape)
    sin = sin.reshape(bcast_shape)

    rope_end = nope_dim + 2 * half
    even = x[..., nope_dim:rope_end:2]
    odd = x[..., nope_dim + 1:rope_end:2]
    out = x.clone()
    # Both right-hand sides read from x (not out), so the first write cannot
    # leak into the second.
    out[..., nope_dim:rope_end:2] = even * cos - odd * sin
    out[..., nope_dim + 1:rope_end:2] = even * sin + odd * cos
    return out


def apply_gptj_rope(x, positions, cos_sin_cache, nope_dim, rope_dim):
    """Apply GPT-J style (interleaved-pair) RoPE to the last ``rope_dim`` channels after ``nope_dim``.

    Works on (T, HD) or (T, NH, HD) — more generally any (*positions.shape, ..., HD).

    Args:
        x: input tensor; channels ``[0, nope_dim)`` are left untouched.
        positions: integer index tensor selecting rows of ``cos_sin_cache``.
        cos_sin_cache: (max_pos, >= rope_dim) table, cos in the first
            ``rope_dim // 2`` columns and sin in the next ``rope_dim // 2``.
        nope_dim: number of leading channels without RoPE.
        rope_dim: number of channels rotated.

    Returns:
        ``x`` itself when ``rope_dim == 0`` or ``x`` is empty, otherwise a new
        tensor with the RoPE slice rotated.
    """
    if rope_dim == 0 or x.numel() == 0:
        return x
    return _gptj_rotate(x, positions, cos_sin_cache, nope_dim, rope_dim, sin_sign=1)


def apply_inv_gptj_rope(x, positions, cos_sin_cache, nope_dim, rope_dim):
    """Inverse of :func:`apply_gptj_rope` (rotation by -theta, i.e. sin → -sin).

    Same arguments and return convention as :func:`apply_gptj_rope`.
    """
    if rope_dim == 0 or x.numel() == 0:
        return x
    return _gptj_rotate(x, positions, cos_sin_cache, nope_dim, rope_dim, sin_sign=-1)
# ── KV Cache Operations ──────────────────────────────────────────────
def kv_quantize_fp8(kv_bf16):
"""BF16 KV → fp8_e4m3 with per-token inverse scale."""
amax = kv_bf16.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
fp8_max = torch.tensor(448.0, dtype=torch.float32, device=kv_bf16.device)
scale = fp8_max / amax
kv_fp8 = (kv_bf16.float() * scale).to(torch.float8_e4m3fn)
inv_scale = (amax / fp8_max).to(torch.bfloat16)
return kv_fp8, inv_scale
def kv_dequantize_fp8(kv_fp8, inv_scale):
"""fp8 KV → BF16."""
return (kv_fp8.to(torch.bfloat16) * inv_scale).to(torch.bfloat16)
def paged_kv_write(kv_data, slot_mapping, cache, block_size):
    """Write per-token rows into a paged cache. Works for fp8 or bf16.

    Args:
        kv_data: (T, D) tensor to write; row ``t`` goes to ``slot_mapping[t]``.
        slot_mapping: (T,) flat slot indices, ``slot = block * block_size + offset``.
            Negative slots (e.g. vLLM's ``PAD_SLOT_ID == -1`` for padding
            tokens) are skipped. Only the first T entries are used.
        cache: (num_blocks, block_size, D) cache tensor, modified in place.
        block_size: tokens per cache block.

    Slots that map outside ``cache`` are silently dropped.

    Raises:
        IndexError: if ``slot_mapping`` has fewer than T entries.
    """
    num_tokens = kv_data.shape[0]
    # One device->host transfer instead of an ``.item()`` sync per token.
    slots = slot_mapping.reshape(-1)[:num_tokens].tolist()
    if len(slots) < num_tokens:
        raise IndexError(
            f"slot_mapping has {len(slots)} entries but kv_data has {num_tokens} rows"
        )
    num_blocks, cache_block_len = cache.shape[0], cache.shape[1]
    for row, slot in zip(kv_data, slots):
        # Without this guard a -1 pad slot yields block_idx == -1, which passes
        # the bounds check below and overwrites the last slot of the last block.
        if slot < 0:
            continue
        block_idx, offset = divmod(slot, block_size)
        if block_idx < num_blocks and offset < cache_block_len:
            cache[block_idx, offset] = row
def paged_kv_read(slot_mapping, cache, block_size, num_tokens, head_dim):
    """Gather per-token rows from a paged cache.

    Args:
        slot_mapping: flat slot indices; the first ``num_tokens`` are read.
            Negative slots (e.g. vLLM's ``PAD_SLOT_ID == -1``) read as zeros.
        cache: (num_blocks, block_size, D) cache tensor.
        block_size: tokens per cache block.
        num_tokens: number of rows to return.
        head_dim: width of each returned row.

    Returns:
        (num_tokens, head_dim) tensor in ``cache``'s dtype and device; rows
        whose slot is negative or maps outside ``cache`` stay zero.

    Raises:
        IndexError: if ``slot_mapping`` has fewer than ``num_tokens`` entries.
    """
    kv = torch.zeros(num_tokens, head_dim, dtype=cache.dtype, device=cache.device)
    # One device->host transfer instead of an ``.item()`` sync per token.
    slots = slot_mapping.reshape(-1)[:num_tokens].tolist()
    if len(slots) < num_tokens:
        raise IndexError(
            f"slot_mapping has {len(slots)} entries but num_tokens is {num_tokens}"
        )
    num_blocks, cache_block_len = cache.shape[0], cache.shape[1]
    for t, slot in enumerate(slots):
        # A negative slot would otherwise wrap to the last block via Python's
        # floor division and read an unrelated token.
        if slot < 0:
            continue
        block_idx, offset = divmod(slot, block_size)
        if block_idx < num_blocks and offset < cache_block_len:
            kv[t] = cache[block_idx, offset]
    return kv
# ── Attention ─────────────────────────────────────────────────────────
def causal_prefill_attention(q, kv, scale):
"""Full causal self-attention for prefill. q: (T, NH, HD), kv: (T, HD)."""
T, NH, HD = q.shape
q_t = q.permute(1, 0, 2) # (NH, T, HD)
kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) # (NH, T, HD)
out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=True, scale=scale)
return out.permute(1, 0, 2) # (T, NH, HD)
def decode_attention(q, kv, scale):
"""Decode attention: 1 query vs N cached KVs.
q: (1, NH, HD) — single decode token
kv: (N, HD) — all cached KV (already with RoPE)
"""
NH = q.shape[1]
HD = q.shape[2]
q_t = q.permute(1, 0, 2) # (NH, 1, HD)
kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) # (NH, N, HD)
out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=False, scale=scale)
return out.permute(1, 0, 2) # (1, NH, HD)
def swa_attention(q, kv, positions, scale, window_size):
    """Sliding-window attention, evaluated one query at a time.

    Args:
        q: (T, NH, HD) queries with RoPE applied.
        kv: (total_len, HD) cached KV with RoPE; row index is the absolute position.
        positions: (T,) absolute position of each query.
        scale: softmax scale applied to Q·K^T.
        window_size: number of most recent positions (inclusive of the query's
            own) each query may attend to.

    Returns:
        Tensor shaped like ``q``. Query ``t`` at position ``p`` attends to
        ``kv[max(0, p - window_size + 1) : p + 1]``; a query whose window is
        empty keeps a zero output.
    """
    result = torch.zeros_like(q)
    pos_list = positions.tolist()
    for row in range(q.shape[0]):
        pos = pos_list[row]
        lo = max(0, pos - window_size + 1)
        if pos - lo + 1 <= 0:
            continue
        attended = decode_attention(q[row:row + 1], kv[lo:pos + 1], scale)
        result[row] = attended.squeeze(0)
    return result
# ── Full Pipeline ─────────────────────────────────────────────────────
def _write_new_kv_to_swa_cache(kv, slot_mapping, kv_cache, inv_scale_cache, block_size):
    """Quantize the new tokens' KV to fp8 and store it, plus per-token inverse scales.

    Tokens whose slot is negative (e.g. vLLM's ``PAD_SLOT_ID == -1`` for
    padding) are dropped before quantization so they never touch either cache.
    """
    slots = slot_mapping.reshape(-1)[: kv.shape[0]]
    keep = slots >= 0
    real_slots = slots[keep]
    # Quantization is per row, so quantizing only the kept rows gives the same
    # fp8 values and scales as quantizing everything and selecting afterwards.
    kv_fp8, kv_inv_scale = kv_quantize_fp8(kv[keep.to(kv.device)])
    paged_kv_write(kv_fp8, real_slots, kv_cache, block_size)
    # Scales live in a flat per-slot table; one host transfer for all slots.
    for row, slot in enumerate(real_slots.tolist()):
        inv_scale_cache[slot] = kv_inv_scale[row]


def _swa_decode_one(q_t, pos, kv_cache, inv_scale_cache, block_size, window_size, scale):
    """Attend one decode query (1, NH, HD) to the dequantized cached KV window ending at ``pos``."""
    head_dim = q_t.shape[2]
    # Clamp so a non-positive window gives an empty range instead of an arange error.
    window_start = min(max(0, pos - window_size + 1), pos + 1)
    # NOTE(review): cache slots are taken to equal absolute positions, i.e. the
    # sequence is assumed to occupy physical slots 0..pos (identity block table).
    # No block table is passed in, so several decode sequences would all read
    # the same slots — confirm against the caller.
    window_slots = torch.arange(window_start, pos + 1, dtype=torch.int64, device=q_t.device)
    kv_fp8 = paged_kv_read(window_slots, kv_cache, block_size, window_slots.numel(), head_dim)
    kv_window = kv_dequantize_fp8(kv_fp8, inv_scale_cache[window_slots])
    return decode_attention(q_t, kv_window, scale)


def _swa_prefill(q_prefill, kv_prefill, window_size, scale):
    """Causal sliding-window attention of prefill queries (T, NH, HD) over the new tokens' KV (T, HD).

    Query ``i`` attends to keys ``j`` with ``0 <= i - j < window_size`` —
    the same window the decode path applies.
    """
    num_tokens, num_heads, _ = q_prefill.shape
    if num_tokens <= window_size:
        # Every causal key already lies inside the window.
        return causal_prefill_attention(q_prefill, kv_prefill, scale)
    idx = torch.arange(num_tokens, device=q_prefill.device)
    lag = idx[:, None] - idx[None, :]  # query index minus key index
    in_window = (lag >= 0) & (lag < window_size)  # True = may attend
    q_heads = q_prefill.permute(1, 0, 2)  # (NH, T, HD)
    kv_heads = kv_prefill.unsqueeze(0).expand(num_heads, -1, -1)  # (NH, T, HD)
    out = F.scaled_dot_product_attention(
        q_heads, kv_heads, kv_heads, attn_mask=in_window, scale=scale
    )
    return out.permute(1, 0, 2)


def blackwell_attention_forward(
    # Inputs
    q,                  # (T, NH, HD) with RoPE already applied
    kv,                 # (T, HD) kv_normed, RoPE'd — the NEW tokens' KV
    positions,          # (T,) absolute positions
    # KV Cache
    swa_kv_cache,       # (num_blocks, block_size, HD) fp8 paged cache
    swa_inv_scale,      # (num_blocks * block_size, 1) per-token inv scale
    slot_mapping,       # (T,) slot indices for writing
    block_size,         # tokens per block
    seq_lens,           # (num_seqs,) total sequence lengths (prefill + history)
    num_prefills,       # number of prefill sequences
    num_decode_tokens,  # number of decode tokens
    # Params
    scale,              # 1/sqrt(HD)
    nope_dim,           # 448
    rope_dim,           # 64
    window_size,        # 128
    compress_ratio,     # 0, 1, 4, or 128
    cos_sin_cache,      # (max_pos, rope_dim) for RoPE
    attn_sink,          # (NH,) sink weights
):
    """Full attention forward for Blackwell (SM100+).

    This is what replaces vLLM's _attention_impl_blackwell. Rows are laid out
    as ``num_decode_tokens`` decode tokens first, then the prefill tokens.

    Steps:
      1. Quantize the new tokens' KV (RoPE already applied by the caller) to
         fp8 and write it, with per-token inverse scales, to the paged cache.
         Tokens with a negative (padding) slot are not written.
      2. Decode rows: read the last ``window_size`` cached positions,
         dequantize, and attend (1 query vs N KVs).
      3. Prefill rows: causal sliding-window attention over the new tokens'
         BF16 KV directly (no cache read), using the same window as decode.
      4. Return the attention output (T, NH, HD) in bf16.

    NOTE(review): ``seq_lens``, ``nope_dim``, ``rope_dim``, ``compress_ratio``,
    ``cos_sin_cache`` and ``attn_sink`` are accepted but unused; the CSA/HCA
    compressed path and attention sinks are not implemented here.
    NOTE(review): all prefill rows are attended as ONE sequence, so with
    ``num_prefills > 1`` later sequences can see earlier sequences' KV, and any
    cached history of a chunked prefill is ignored — confirm callers only pass
    a single fresh prefill.
    """
    T, NH, HD = q.shape

    _write_new_kv_to_swa_cache(kv, slot_mapping, swa_kv_cache, swa_inv_scale, block_size)

    output = torch.zeros(T, NH, HD, dtype=torch.bfloat16, device=q.device)

    if num_decode_tokens > 0:
        # Decode rows run after the cache write, so each sees its own new KV.
        for t, pos in enumerate(positions[:num_decode_tokens].tolist()):
            attended = _swa_decode_one(
                q[t:t + 1], pos, swa_kv_cache, swa_inv_scale, block_size, window_size, scale
            )
            output[t] = attended.squeeze(0)

    if num_prefills > 0:
        prefill_slice = slice(num_decode_tokens, T)
        output[prefill_slice] = _swa_prefill(
            q[prefill_slice], kv[prefill_slice], window_size, scale
        )

    return output