- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py - Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc. - Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda) - Moved PyTorch bridges to dsv4/ops/ - Moved nn.Module layers to dsv4layers/ - Moved reference implementations to dsv4/reference/ - Moved vendored CUTLASS code to vendored/ - Archived ~190 debug tests to tests/archive/ - Kept ~15 canonical tests in tests/unit/ - Updated all import paths - Added stubs for future components (model/, cache/, loader/) - Updated pyproject.toml: dsv4-inference package name
248 lines
9.1 KiB
Python
248 lines
9.1 KiB
Python
"""
|
||
DeepSeek-V4 Blackwell Attention — Our own kernel.
|
||
|
||
Replaces vLLM's broken FlashMLA Blackwell path with a proper KV cache-based
|
||
attention pipeline. Does NOT depend on FlashMLA, fp8_ds_mla, or any vLLM
|
||
fused CUDA kernel.
|
||
|
||
Architecture:
|
||
- KV: (T, HD=512) single head latent, shared across all 128 Q heads
|
||
- KV Cache: fp8_e4m3 paged cache with per-token inverse scale
|
||
- RoPE: GPT-J style, applied to Q and KV before caching
|
||
- Attention: BF16 (NVFP4 is too lossy for Q×K^T, cosine 0.86)
|
||
- CSA/HCA: Compressed KV for sparse attention (compress_ratio 4 or 128)
|
||
- SWA: Sliding window attention (compress_ratio 0/1)
|
||
|
||
Pipeline:
|
||
Prefill:
|
||
1. hidden → q_a_proj → q_norm → q_b_proj → (T, NH, HD) → RoPE on Q
|
||
2. hidden → kv_proj → kv_norm → (T, HD) → RoPE → fp8 quant → write to paged cache
|
||
3. BF16 causal attention over the new tokens' KV (taken directly from this forward pass, not re-read from the cache) → output
|
||
|
||
Decode:
|
||
1. Same projections as prefill
|
||
2. Write new KV to cache
|
||
3. Read cached KV and keep the last window_size positions → BF16 attention (1 query vs up to window_size KVs) → output
|
||
|
||
Output:
|
||
1. inverse RoPE on attention output
|
||
2. o_a: BMM with wo_a (BF16)
|
||
3. o_b: NVFP4 GEMM with wo_b
|
||
"""
|
||
|
||
import torch
|
||
import torch.nn.functional as F
|
||
|
||
|
||
def apply_gptj_rope(x, positions, cos_sin_cache, nope_dim, rope_dim):
    """Rotate the trailing RoPE channels of ``x`` in GPT-J (interleaved) style.

    ``x`` may be (T, HD) or (T, NH, HD). Channels ``[:nope_dim]`` pass through
    unchanged; channels ``[nope_dim:]`` are treated as interleaved
    (even, odd) pairs. For token ``t``, row ``positions[t]`` of
    ``cos_sin_cache`` holds the cosines in its first ``rope_dim // 2`` columns
    and the sines in the following ``rope_dim // 2`` columns.

    Returns a new tensor. ``x`` itself is returned unchanged when
    ``rope_dim == 0`` or ``x`` is empty.
    """
    if x.numel() == 0 or rope_dim == 0:
        return x

    half = rope_dim // 2
    table = cos_sin_cache[positions]
    cos = table[..., :half].to(x.dtype)
    sin = table[..., half:2 * half].to(x.dtype)
    if x.dim() == 3:
        # One angle per token, broadcast across every head.
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

    pairs = x[..., nope_dim:]
    x_even, x_odd = pairs[..., 0::2], pairs[..., 1::2]

    # Writes go into a fresh copy, so reading x_even/x_odd as views of x is safe.
    rotated = x.clone()
    rotated[..., nope_dim::2] = x_even * cos - x_odd * sin
    rotated[..., nope_dim + 1::2] = x_even * sin + x_odd * cos
    return rotated
|
||
|
||
|
||
def apply_inv_gptj_rope(x, positions, cos_sin_cache, nope_dim, rope_dim):
    """Undo a GPT-J (interleaved) RoPE rotation on the trailing channels of ``x``.

    Same layout and cache convention as the forward rotation. The rotation
    runs with the sine negated: even' = e*cos + o*sin and
    odd' = -e*sin + o*cos. Channels ``[:nope_dim]`` are copied through
    unchanged.

    Returns a new tensor. ``x`` itself is returned unchanged when
    ``rope_dim == 0`` or ``x`` is empty.
    """
    if x.numel() == 0 or rope_dim == 0:
        return x

    half = rope_dim // 2
    table = cos_sin_cache[positions]
    cos = table[..., :half].to(x.dtype)
    sin = table[..., half:2 * half].to(x.dtype)
    if x.dim() == 3:
        # One angle per token, broadcast across every head.
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

    pairs = x[..., nope_dim:]
    x_even, x_odd = pairs[..., 0::2], pairs[..., 1::2]

    # Writes go into a fresh copy, so reading x_even/x_odd as views of x is safe.
    restored = x.clone()
    restored[..., nope_dim::2] = x_even * cos + x_odd * sin
    restored[..., nope_dim + 1::2] = -x_even * sin + x_odd * cos
    return restored
|
||
|
||
|
||
# ── KV Cache Operations ──────────────────────────────────────────────
|
||
|
||
def kv_quantize_fp8(kv_bf16):
|
||
"""BF16 KV → fp8_e4m3 with per-token inverse scale."""
|
||
amax = kv_bf16.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
|
||
fp8_max = torch.tensor(448.0, dtype=torch.float32, device=kv_bf16.device)
|
||
scale = fp8_max / amax
|
||
kv_fp8 = (kv_bf16.float() * scale).to(torch.float8_e4m3fn)
|
||
inv_scale = (amax / fp8_max).to(torch.bfloat16)
|
||
return kv_fp8, inv_scale
|
||
|
||
|
||
def kv_dequantize_fp8(kv_fp8, inv_scale):
|
||
"""fp8 KV → BF16."""
|
||
return (kv_fp8.to(torch.bfloat16) * inv_scale).to(torch.bfloat16)
|
||
|
||
|
||
def paged_kv_write(kv_data, slot_mapping, cache, block_size):
|
||
"""Write KV into paged cache. Works for fp8 or bf16.
|
||
|
||
kv_data: (T, D) tensor to write
|
||
slot_mapping: (T,) slot indices
|
||
cache: (num_blocks, block_size, D) cache tensor
|
||
"""
|
||
for t in range(kv_data.shape[0]):
|
||
slot = slot_mapping[t].item()
|
||
block_idx = slot // block_size
|
||
offset = slot % block_size
|
||
if block_idx < cache.shape[0] and offset < cache.shape[1]:
|
||
cache[block_idx, offset] = kv_data[t]
|
||
|
||
|
||
def paged_kv_read(slot_mapping, cache, block_size, num_tokens, head_dim):
|
||
"""Read KV from paged cache."""
|
||
device = cache.device
|
||
kv = torch.zeros(num_tokens, head_dim, dtype=cache.dtype, device=device)
|
||
for t in range(num_tokens):
|
||
slot = slot_mapping[t].item()
|
||
block_idx = slot // block_size
|
||
offset = slot % block_size
|
||
if block_idx < cache.shape[0] and offset < cache.shape[1]:
|
||
kv[t] = cache[block_idx, offset]
|
||
return kv
|
||
|
||
|
||
# ── Attention ─────────────────────────────────────────────────────────
|
||
|
||
def causal_prefill_attention(q, kv, scale):
|
||
"""Full causal self-attention for prefill. q: (T, NH, HD), kv: (T, HD)."""
|
||
T, NH, HD = q.shape
|
||
q_t = q.permute(1, 0, 2) # (NH, T, HD)
|
||
kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) # (NH, T, HD)
|
||
out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=True, scale=scale)
|
||
return out.permute(1, 0, 2) # (T, NH, HD)
|
||
|
||
|
||
def decode_attention(q, kv, scale):
|
||
"""Decode attention: 1 query vs N cached KVs.
|
||
|
||
q: (1, NH, HD) — single decode token
|
||
kv: (N, HD) — all cached KV (already with RoPE)
|
||
"""
|
||
NH = q.shape[1]
|
||
HD = q.shape[2]
|
||
q_t = q.permute(1, 0, 2) # (NH, 1, HD)
|
||
kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) # (NH, N, HD)
|
||
out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=False, scale=scale)
|
||
return out.permute(1, 0, 2) # (1, NH, HD)
|
||
|
||
|
||
def swa_attention(q, kv, positions, scale, window_size):
|
||
"""Sliding window attention.
|
||
|
||
q: (T, NH, HD) with RoPE
|
||
kv: (total_len, HD) — ALL cached KV with RoPE
|
||
positions: (T,) — absolute positions of the query tokens
|
||
"""
|
||
T, NH, HD = q.shape
|
||
total_len = kv.shape[0]
|
||
output = torch.zeros_like(q)
|
||
|
||
for t in range(T):
|
||
pos = positions[t].item()
|
||
window_start = max(0, pos - window_size + 1)
|
||
window_len = pos - window_start + 1
|
||
if window_len <= 0:
|
||
continue
|
||
kv_window = kv[window_start:pos + 1] # (window_len, HD)
|
||
q_t = q[t:t + 1] # (1, NH, HD)
|
||
output[t] = decode_attention(q_t, kv_window, scale).squeeze(0)
|
||
|
||
return output
|
||
|
||
|
||
# ── Full Pipeline ─────────────────────────────────────────────────────
|
||
|
||
def _write_inv_scales(inv_scale, slot_mapping, inv_scale_cache):
    """Scatter per-token inverse scales into the flat per-slot scale cache.

    Negative slots (e.g. -1 padding) and slots past the end of the cache are
    skipped. Without this, a -1 slot would wrap around and overwrite the last
    entry.
    """
    num_tokens = inv_scale.shape[0]
    slots = slot_mapping[:num_tokens].to(device=inv_scale_cache.device, dtype=torch.int64)
    values = inv_scale.to(device=inv_scale_cache.device, dtype=inv_scale_cache.dtype)
    valid = (slots >= 0) & (slots < inv_scale_cache.shape[0])
    inv_scale_cache[slots[valid]] = values[valid]


def _contiguous_position_runs(positions):
    """Split token indices into [start, end) runs whose positions increase by exactly 1."""
    pos = positions.tolist()
    starts = [0] + [i for i in range(1, len(pos)) if pos[i] != pos[i - 1] + 1]
    ends = starts[1:] + [len(pos)]
    return list(zip(starts, ends))


def _prefill_swa_attention(q, kv, positions, scale, window_size):
    """Causal sliding-window attention over the prefill tokens' fresh KV.

    q: (T, NH, HD); kv: (T, HD); positions: (T,). Query ``i`` attends to keys
    ``j`` of the same run with ``0 <= i - j < window_size``. Long runs are
    processed in query chunks of ``window_size``, so each attention mask is
    at most window_size x (2 * window_size - 1) instead of T x T.
    """
    num_tokens, num_heads, _ = q.shape
    out = torch.zeros_like(q)
    if num_tokens == 0:
        return out

    # NOTE(review): a new sequence is assumed to start wherever positions stop
    # increasing by exactly 1 (no query_start_loc is passed in) — confirm.
    for run_start, run_end in _contiguous_position_runs(positions):
        if run_end - run_start <= window_size:
            # Every earlier key in a short run lies inside the window.
            out[run_start:run_end] = causal_prefill_attention(
                q[run_start:run_end], kv[run_start:run_end], scale
            )
            continue
        for q_start in range(run_start, run_end, window_size):
            q_end = min(q_start + window_size, run_end)
            k_start = max(run_start, q_start - window_size + 1)
            q_idx = torch.arange(q_start, q_end, device=q.device).unsqueeze(1)
            k_idx = torch.arange(k_start, q_end, device=q.device).unsqueeze(0)
            dist = q_idx - k_idx
            mask = (dist >= 0) & (dist < window_size)  # True = may attend
            q_t = q[q_start:q_end].permute(1, 0, 2)  # (NH, Lq, HD)
            kv_exp = kv[k_start:q_end].unsqueeze(0).expand(num_heads, -1, -1)
            chunk = F.scaled_dot_product_attention(
                q_t, kv_exp, kv_exp, attn_mask=mask, scale=scale
            )
            out[q_start:q_end] = chunk.permute(1, 0, 2)
    return out


def blackwell_attention_forward(
    # Inputs
    q,  # (T, NH, HD) with RoPE already applied
    kv,  # (T, HD) kv_normed, RoPE'd — the NEW tokens' KV
    positions,  # (T,) absolute positions
    # KV Cache
    swa_kv_cache,  # (num_blocks, block_size, HD) fp8 paged cache
    swa_inv_scale,  # (num_blocks * block_size, 1) per-token inv scale
    slot_mapping,  # (T,) slot indices for writing
    block_size,  # tokens per block
    seq_lens,  # (num_seqs,) total sequence lengths (prefill + history)
    num_prefills,  # number of prefill sequences
    num_decode_tokens,  # number of decode tokens
    # Params
    scale,  # 1/sqrt(HD)
    nope_dim,  # 448
    rope_dim,  # 64
    window_size,  # 128
    compress_ratio,  # 0, 1, 4, or 128
    cos_sin_cache,  # (max_pos, rope_dim) for RoPE
    attn_sink,  # (NH,) sink weights
):
    """Sliding-window attention forward for Blackwell (SM100+).

    Intended to replace vLLM's ``_attention_impl_blackwell``. Rows
    ``[:num_decode_tokens]`` of ``q``/``kv``/``positions`` are decode tokens.
    The remaining rows are prefill tokens.

    Steps:
      1. Quantize the new tokens' KV to fp8 and write it to the paged cache,
         along with its per-token inverse scale. Inverse scales are not
         written for padding (< 0) slots.
      2. Decode tokens: read the last ``window_size`` cached positions,
         dequantize them, and attend (1 query vs up to ``window_size`` keys).
      3. Prefill tokens: causal sliding-window attention over the fresh BF16
         KV of this forward pass, kept within each run of consecutive
         positions.
      4. Return the attention output (T, NH, HD) in BF16.

    Raises:
        ValueError: if ``window_size < 1``.

    NOTE(review): ``seq_lens``, ``compress_ratio``, ``nope_dim``, ``rope_dim``,
    ``cos_sin_cache`` and ``attn_sink`` are accepted but unused here. Sink
    weights and CSA/HCA compressed KV are therefore not applied — confirm
    this is intended.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    num_tokens, num_heads, head_dim = q.shape
    device = q.device

    # Step 1: persist the new KV (RoPE already applied by the caller) for
    # future decode steps.
    kv_fp8, kv_inv_scale = kv_quantize_fp8(kv)
    paged_kv_write(kv_fp8, slot_mapping, swa_kv_cache, block_size)
    _write_inv_scales(kv_inv_scale, slot_mapping, swa_inv_scale)

    output = torch.zeros(num_tokens, num_heads, head_dim, dtype=torch.bfloat16, device=device)

    # Step 2: decode tokens attend to their window of cached (fp8) KV.
    for t, pos in enumerate(positions[:num_decode_tokens].tolist()):
        window_start = max(0, pos - window_size + 1)
        # NOTE(review): cache slot == absolute position is assumed (no block
        # table is passed in). This only holds for one sequence whose slots
        # were allocated in order — confirm with the caller.
        window_slots = torch.arange(window_start, pos + 1, dtype=torch.int64, device=device)
        kv_window_fp8 = paged_kv_read(
            window_slots, swa_kv_cache, block_size, window_slots.numel(), head_dim
        )
        kv_window = kv_dequantize_fp8(kv_window_fp8, swa_inv_scale[window_slots])
        output[t] = decode_attention(q[t:t + 1], kv_window, scale).squeeze(0)

    # Step 3: prefill tokens use the fresh (unquantized) KV from this pass.
    # NOTE(review): cached history from earlier chunks (chunked prefill) is
    # not attended to here.
    if num_prefills > 0:
        prefill = slice(num_decode_tokens, num_tokens)
        output[prefill] = _prefill_swa_attention(
            q[prefill], kv[prefill], positions[prefill], scale, window_size
        )

    return output
|