# nvfp4-megamoe-kernel/dsv4/reference/attention.py
"""
DeepSeek-V4 Blackwell Attention — our own kernel.
Replaces vLLM's broken FlashMLA Blackwell path with a proper KV cache-based
attention pipeline. Does NOT depend on FlashMLA, fp8_ds_mla, or any vLLM
fused CUDA kernel.
Architecture:
- KV: (T, HD=512) single head latent, shared across all 128 Q heads
- KV Cache: fp8_e4m3 paged cache with per-token inverse scale
- RoPE: GPT-J style, applied to Q and KV before caching
- Attention: BF16 (NVFP4 is too lossy for Q×K^T, cosine 0.86)
- CSA/HCA: Compressed KV for sparse attention (compress_ratio 4 or 128)
- SWA: Sliding window attention (compress_ratio 0/1)
Pipeline:
Prefill:
1. hidden → q_a_proj → q_norm → q_b_proj → (T, NH, HD) → RoPE on Q
2. hidden → kv_proj → kv_norm → (T, HD) → RoPE → fp8 quant → write to paged cache
3. Read all cached KV → BF16 causal attention → output
Decode:
1. Same projections as prefill
2. Write new KV to cache
3. Read ALL cached KV → BF16 attention (1 query vs N KVs) → output
Output:
1. inverse RoPE on attention output
2. o_a: BMM with wo_a (BF16)
3. o_b: NVFP4 GEMM with wo_b
"""
import torch
import torch.nn.functional as F
def _gptj_rotate(x, positions, cos_sin_cache, nope_dim, rope_dim, inverse):
    """Rotate the channels of ``x`` from ``nope_dim`` onward, GPT-J (interleaved) style.

    Shared implementation for the forward and inverse RoPE entry points; the
    only difference between the two is the sign of the sine term.

    Args:
        x: Tensor whose last dimension holds the channels. A 3-D input gets the
            cos/sin tables broadcast over its middle dimension.
        positions: Index (tensor) used to select rows of ``cos_sin_cache``.
        cos_sin_cache: Table whose row ``p`` stores ``rope_dim // 2`` cosines
            followed by ``rope_dim // 2`` sines.
        nope_dim: Number of leading channels left untouched.
        rope_dim: Number of rotated channels; ``0`` disables the rotation.
        inverse: When true, rotate by the negated angle.

    Returns:
        A new tensor with the rotation applied. When ``rope_dim`` is 0 or ``x``
        is empty, ``x`` itself is returned (no copy).
    """
    if rope_dim == 0 or x.numel() == 0:
        return x

    half = rope_dim // 2
    cos = cos_sin_cache[positions, :half].to(x.dtype)
    sin = cos_sin_cache[positions, half:2 * half].to(x.dtype)
    if inverse:
        # Rotating by -theta only flips the sign of sin.
        sin = -sin
    if x.dim() == 3:
        # Insert a singleton axis so the tables broadcast over dimension 1.
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

    # Read straight from ``x``: ``out`` is a separate clone, so the source
    # values are not clobbered by the in-place writes below.
    src = x[..., nope_dim:]
    even = src[..., 0::2]
    odd = src[..., 1::2]

    out = x.clone()
    dst = out[..., nope_dim:]  # view into ``out``; writes land in ``out``
    dst[..., 0::2] = even * cos - odd * sin
    dst[..., 1::2] = even * sin + odd * cos
    return out


def apply_gptj_rope(x, positions, cos_sin_cache, nope_dim, rope_dim):
    """Apply GPT-J style RoPE. Works on (T, HD) or (T, NH, HD).

    Channels before ``nope_dim`` pass through unchanged; the remaining channels
    are treated as interleaved (even, odd) pairs and rotated by the angle whose
    cos/sin are looked up in ``cos_sin_cache`` at ``positions``.

    Returns a new tensor, except that ``x`` itself is returned when
    ``rope_dim`` is 0 or ``x`` is empty.
    """
    return _gptj_rotate(x, positions, cos_sin_cache, nope_dim, rope_dim, inverse=False)


def apply_inv_gptj_rope(x, positions, cos_sin_cache, nope_dim, rope_dim):
    """Inverse GPT-J RoPE (sin → -sin).

    Undoes :func:`apply_gptj_rope` for the same ``positions`` and table by
    rotating the same channel pairs by the negated angle.

    Returns a new tensor, except that ``x`` itself is returned when
    ``rope_dim`` is 0 or ``x`` is empty.
    """
    return _gptj_rotate(x, positions, cos_sin_cache, nope_dim, rope_dim, inverse=True)
# ── KV Cache Operations ──────────────────────────────────────────────
def kv_quantize_fp8(kv_bf16):
"""BF16 KV → fp8_e4m3 with per-token inverse scale."""
amax = kv_bf16.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
fp8_max = torch.tensor(448.0, dtype=torch.float32, device=kv_bf16.device)
scale = fp8_max / amax
kv_fp8 = (kv_bf16.float() * scale).to(torch.float8_e4m3fn)
inv_scale = (amax / fp8_max).to(torch.bfloat16)
return kv_fp8, inv_scale
def kv_dequantize_fp8(kv_fp8, inv_scale):
"""fp8 KV → BF16."""
return (kv_fp8.to(torch.bfloat16) * inv_scale).to(torch.bfloat16)
def paged_kv_write(kv_data, slot_mapping, cache, block_size):
    """Write KV into paged cache. Works for fp8 or bf16.

    Args:
        kv_data: (T, D) tensor to write.
        slot_mapping: (T,) slot indices; slot ``s`` addresses
            ``cache[s // block_size, s % block_size]``.
        cache: (num_blocks, block_size, D) cache tensor, modified in place.
        block_size: Number of slots per block used to split a slot index.

    Slots that do not address a valid cache entry — negative slots, or slots
    past the end of ``cache`` — are skipped without error. Rows are written in
    order, so if a slot repeats the last row wins.

    Raises:
        IndexError: If ``slot_mapping`` has fewer entries than ``kv_data`` has rows.
    """
    num_tokens = kv_data.shape[0]
    num_blocks, slots_per_block = cache.shape[0], cache.shape[1]

    # One device-to-host transfer for the whole mapping instead of an
    # ``.item()`` synchronisation per token.
    slots = slot_mapping[:num_tokens].tolist()
    if len(slots) < num_tokens:
        raise IndexError(
            f"slot_mapping has {len(slots)} entries but kv_data has {num_tokens} rows"
        )

    for t, slot in enumerate(slots):
        if slot < 0:
            # Floor division would turn a negative slot into block -1, which
            # passes the upper-bound check and wraps onto the last block.
            continue
        block_idx, offset = divmod(slot, block_size)
        if block_idx < num_blocks and offset < slots_per_block:
            cache[block_idx, offset] = kv_data[t]
def paged_kv_read(slot_mapping, cache, block_size, num_tokens, head_dim):
    """Read KV from paged cache.

    Args:
        slot_mapping: Slot indices to gather; only the first ``num_tokens``
            entries are used. Slot ``s`` addresses
            ``cache[s // block_size, s % block_size]``.
        cache: (num_blocks, block_size, D) cache tensor.
        block_size: Number of slots per block used to split a slot index.
        num_tokens: Number of rows to return.
        head_dim: Width of each returned row.

    Returns:
        A ``(num_tokens, head_dim)`` tensor with the dtype and device of
        ``cache``. Rows whose slot does not address a valid cache entry
        (negative, or past the end of ``cache``) are left as zeros.

    Raises:
        IndexError: If ``slot_mapping`` has fewer than ``num_tokens`` entries.
    """
    num_blocks, slots_per_block = cache.shape[0], cache.shape[1]
    kv = torch.zeros(num_tokens, head_dim, dtype=cache.dtype, device=cache.device)

    # One device-to-host transfer for the whole mapping instead of an
    # ``.item()`` synchronisation per token.
    slots = slot_mapping[:num_tokens].tolist()
    if len(slots) < num_tokens:
        raise IndexError(
            f"slot_mapping has {len(slots)} entries but {num_tokens} tokens were requested"
        )

    for t, slot in enumerate(slots):
        if slot < 0:
            # Floor division would turn a negative slot into block -1, which
            # passes the upper-bound check and reads from the last block.
            continue
        block_idx, offset = divmod(slot, block_size)
        if block_idx < num_blocks and offset < slots_per_block:
            kv[t] = cache[block_idx, offset]
    return kv
# ── Attention ─────────────────────────────────────────────────────────
def causal_prefill_attention(q, kv, scale):
"""Full causal self-attention for prefill. q: (T, NH, HD), kv: (T, HD)."""
T, NH, HD = q.shape
q_t = q.permute(1, 0, 2) # (NH, T, HD)
kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) # (NH, T, HD)
out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=True, scale=scale)
return out.permute(1, 0, 2) # (T, NH, HD)
def decode_attention(q, kv, scale):
"""Decode attention: 1 query vs N cached KVs.
q: (1, NH, HD) single decode token
kv: (N, HD) all cached KV (already with RoPE)
"""
NH = q.shape[1]
HD = q.shape[2]
q_t = q.permute(1, 0, 2) # (NH, 1, HD)
kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) # (NH, N, HD)
out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=False, scale=scale)
return out.permute(1, 0, 2) # (1, NH, HD)
def swa_attention(q, kv, positions, scale, window_size):
    """Sliding window attention.

    Each query token at absolute position ``p`` attends to the KV rows
    ``kv[max(0, p - window_size + 1) : p + 1]``.

    q: (T, NH, HD) with RoPE
    kv: (total_len, HD) ALL cached KV with RoPE, indexed by absolute position
    positions: (T,) absolute positions of the query tokens

    Returns a tensor shaped like ``q``. Rows whose window is empty (for example
    a negative position) are left as zeros.
    """
    num_tokens = q.shape[0]
    output = torch.zeros_like(q)

    # One device-to-host transfer for all positions instead of an ``.item()``
    # synchronisation per token.
    query_positions = positions.tolist()

    for t in range(num_tokens):
        pos = query_positions[t]
        window_start = max(0, pos - window_size + 1)
        if pos < window_start:
            # Empty window: nothing to attend to, keep the zero row.
            continue
        kv_window = kv[window_start:pos + 1]  # (window_len, HD)
        q_t = q[t:t + 1]  # (1, NH, HD)
        output[t] = decode_attention(q_t, kv_window, scale).squeeze(0)
    return output
# ── Full Pipeline ─────────────────────────────────────────────────────
def blackwell_attention_forward(
    # Inputs
    q,  # (T, NH, HD) with RoPE already applied
    kv,  # (T, HD) kv_normed, RoPE'd — the NEW tokens' KV
    positions,  # (T,) absolute positions
    # KV Cache
    swa_kv_cache,  # (num_blocks, block_size, HD) fp8 paged cache
    swa_inv_scale,  # (num_blocks * block_size, 1) per-token inv scale
    slot_mapping,  # (T,) slot indices for writing
    block_size,  # tokens per block
    seq_lens,  # (num_seqs,) total sequence lengths (prefill + history)
    num_prefills,  # number of prefill sequences
    num_decode_tokens,  # number of decode tokens
    # Params
    scale,  # 1/sqrt(HD)
    nope_dim,  # 448
    rope_dim,  # 64
    window_size,  # 128
    compress_ratio,  # 0, 1, 4, or 128
    cos_sin_cache,  # (max_pos, rope_dim) for RoPE
    attn_sink,  # (NH,) sink weights
):
    """Full attention forward for Blackwell (SM100+).

    This is what replaces vLLM's _attention_impl_blackwell.

    Steps:
        1. Quantize the new KV to fp8 and write it, together with its
           per-token inverse scale, into the paged cache (in place).
        2. Decode tokens (the first ``num_decode_tokens`` rows): read the
           sliding window of cached KV ending at the token's position,
           dequantize it, and run non-causal attention of the single query
           against it.
        3. Prefill tokens (the remaining rows, when ``num_prefills > 0``):
           causal attention over the new, un-quantized ``kv`` rows.
        4. Return the attention output, a bfloat16 (T, NH, HD) tensor.

    Simplifications in this version, as implemented below:
        - Decode reads slot ``p`` for absolute position ``p``; it does not go
          through ``slot_mapping`` or any block table.
        - All prefill rows are attended as one causal sequence, and
          ``window_size`` is not applied to them.
        - ``seq_lens``, ``nope_dim``, ``rope_dim``, ``compress_ratio``,
          ``cos_sin_cache`` and ``attn_sink`` are accepted but not used.

    Inverse scales are stored only for slots that address a valid cache entry;
    negative slots and slots past the end of ``swa_kv_cache`` are skipped, the
    same way ``paged_kv_write`` skips over-range slots.
    """
    T, NH, HD = q.shape[0], q.shape[1], q.shape[2]
    device = q.device

    # Step 1: Quantize new KV and write to cache.
    # kv already has RoPE applied (done by caller).
    kv_fp8, kv_inv_scale = kv_quantize_fp8(kv)
    paged_kv_write(kv_fp8, slot_mapping, swa_kv_cache, block_size)

    # Write inv_scale to the flat per-slot table. The mapping is copied to the
    # host once rather than synchronising with ``.item()`` for every token.
    num_blocks, slots_per_block = swa_kv_cache.shape[0], swa_kv_cache.shape[1]
    slots = slot_mapping.tolist()
    for t in range(T):
        slot = slots[t]
        if slot < 0:
            # A negative index would wrap around and overwrite the last entry.
            continue
        block_idx, offset = divmod(slot, block_size)
        if block_idx < num_blocks and offset < slots_per_block:
            swa_inv_scale[slot] = kv_inv_scale[t]

    # Step 2 & 3: Read cached KV and attend.
    # For simplicity in this initial version, we separate prefill and decode.
    output = torch.zeros(T, NH, HD, dtype=torch.bfloat16, device=device)

    if num_decode_tokens > 0:
        decode_positions = positions[:num_decode_tokens].tolist()
        for t in range(num_decode_tokens):
            pos = decode_positions[t]
            # Only the sliding window is read and dequantized; slots before
            # window_start never contribute to the result.
            window_start = max(0, pos - window_size + 1)
            window_end = max(window_start, pos + 1)
            window_slots = torch.arange(
                window_start, window_end, dtype=torch.int64, device=device
            )
            kv_window_fp8 = paged_kv_read(
                window_slots, swa_kv_cache, block_size, window_end - window_start, HD
            )
            kv_window = kv_dequantize_fp8(kv_window_fp8, swa_inv_scale[window_slots])
            q_t = q[t:t + 1]  # (1, NH, HD)
            output[t] = decode_attention(q_t, kv_window, scale).squeeze(0)

    if num_prefills > 0:
        # Prefill tokens: causal attention using the NEW kv (not from cache,
        # since all KV is available from the current forward pass).
        # The cache write above still happens so future decode steps see it.
        prefill_slice = slice(num_decode_tokens, T)
        output[prefill_slice] = causal_prefill_attention(
            q[prefill_slice], kv[prefill_slice], scale
        )

    return output