""" DeepSeek-V4 Blackwell Attention — Our own kernel. Replaces vLLM's broken FlashMLA Blackwell path with a proper KV cache-based attention pipeline. Does NOT depend on FlashMLA, fp8_ds_mla, or any vLLM fused CUDA kernel. Architecture: - KV: (T, HD=512) single head latent, shared across all 128 Q heads - KV Cache: fp8_e4m3 paged cache with per-token inverse scale - RoPE: GPT-J style, applied to Q and KV before caching - Attention: BF16 (NVFP4 is too lossy for Q×K^T, cosine 0.86) - CSA/HCA: Compressed KV for sparse attention (compress_ratio 4 or 128) - SWA: Sliding window attention (compress_ratio 0/1) Pipeline: Prefill: 1. hidden → q_a_proj → q_norm → q_b_proj → (T, NH, HD) → RoPE on Q 2. hidden → kv_proj → kv_norm → (T, HD) → RoPE → fp8 quant → write to paged cache 3. Read all cached KV → BF16 causal attention → output Decode: 1. Same projections as prefill 2. Write new KV to cache 3. Read ALL cached KV → BF16 attention (1 query vs N KVs) → output Output: 1. inverse RoPE on attention output 2. o_a: BMM with wo_a (BF16) 3. o_b: NVFP4 GEMM with wo_b """ import torch import torch.nn.functional as F def apply_gptj_rope(x, positions, cos_sin_cache, nope_dim, rope_dim): """Apply GPT-J style RoPE. Works on (T, HD) or (T, NH, HD).""" if rope_dim == 0 or x.numel() == 0: return x half = rope_dim // 2 cos = cos_sin_cache[positions, :half].to(x.dtype) sin = cos_sin_cache[positions, half:2 * half].to(x.dtype) if x.dim() == 3: cos = cos.unsqueeze(1) sin = sin.unsqueeze(1) x_rope = x[..., nope_dim:].clone() even = x_rope[..., 0::2] odd = x_rope[..., 1::2] out = x.clone() out[..., nope_dim:][..., 0::2] = even * cos - odd * sin out[..., nope_dim:][..., 1::2] = even * sin + odd * cos return out def apply_inv_gptj_rope(x, positions, cos_sin_cache, nope_dim, rope_dim): """Inverse GPT-J RoPE (sin → -sin).""" if rope_dim == 0 or x.numel() == 0: return x half = rope_dim // 2 cos = cos_sin_cache[positions, :half].to(x.dtype) sin = cos_sin_cache[positions, half:2 * half].to(x.dtype) if x.dim() == 3: cos = cos.unsqueeze(1) sin = sin.unsqueeze(1) x_rope = x[..., nope_dim:].clone() even = x_rope[..., 0::2] odd = x_rope[..., 1::2] out = x.clone() out[..., nope_dim:][..., 0::2] = even * cos + odd * sin out[..., nope_dim:][..., 1::2] = -even * sin + odd * cos return out # ── KV Cache Operations ────────────────────────────────────────────── def kv_quantize_fp8(kv_bf16): """BF16 KV → fp8_e4m3 with per-token inverse scale.""" amax = kv_bf16.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) fp8_max = torch.tensor(448.0, dtype=torch.float32, device=kv_bf16.device) scale = fp8_max / amax kv_fp8 = (kv_bf16.float() * scale).to(torch.float8_e4m3fn) inv_scale = (amax / fp8_max).to(torch.bfloat16) return kv_fp8, inv_scale def kv_dequantize_fp8(kv_fp8, inv_scale): """fp8 KV → BF16.""" return (kv_fp8.to(torch.bfloat16) * inv_scale).to(torch.bfloat16) def paged_kv_write(kv_data, slot_mapping, cache, block_size): """Write KV into paged cache. Works for fp8 or bf16. kv_data: (T, D) tensor to write slot_mapping: (T,) slot indices cache: (num_blocks, block_size, D) cache tensor """ for t in range(kv_data.shape[0]): slot = slot_mapping[t].item() block_idx = slot // block_size offset = slot % block_size if block_idx < cache.shape[0] and offset < cache.shape[1]: cache[block_idx, offset] = kv_data[t] def paged_kv_read(slot_mapping, cache, block_size, num_tokens, head_dim): """Read KV from paged cache.""" device = cache.device kv = torch.zeros(num_tokens, head_dim, dtype=cache.dtype, device=device) for t in range(num_tokens): slot = slot_mapping[t].item() block_idx = slot // block_size offset = slot % block_size if block_idx < cache.shape[0] and offset < cache.shape[1]: kv[t] = cache[block_idx, offset] return kv # ── Attention ───────────────────────────────────────────────────────── def causal_prefill_attention(q, kv, scale): """Full causal self-attention for prefill. q: (T, NH, HD), kv: (T, HD).""" T, NH, HD = q.shape q_t = q.permute(1, 0, 2) # (NH, T, HD) kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) # (NH, T, HD) out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=True, scale=scale) return out.permute(1, 0, 2) # (T, NH, HD) def decode_attention(q, kv, scale): """Decode attention: 1 query vs N cached KVs. q: (1, NH, HD) — single decode token kv: (N, HD) — all cached KV (already with RoPE) """ NH = q.shape[1] HD = q.shape[2] q_t = q.permute(1, 0, 2) # (NH, 1, HD) kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) # (NH, N, HD) out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=False, scale=scale) return out.permute(1, 0, 2) # (1, NH, HD) def swa_attention(q, kv, positions, scale, window_size): """Sliding window attention. q: (T, NH, HD) with RoPE kv: (total_len, HD) — ALL cached KV with RoPE positions: (T,) — absolute positions of the query tokens """ T, NH, HD = q.shape total_len = kv.shape[0] output = torch.zeros_like(q) for t in range(T): pos = positions[t].item() window_start = max(0, pos - window_size + 1) window_len = pos - window_start + 1 if window_len <= 0: continue kv_window = kv[window_start:pos + 1] # (window_len, HD) q_t = q[t:t + 1] # (1, NH, HD) output[t] = decode_attention(q_t, kv_window, scale).squeeze(0) return output # ── Full Pipeline ───────────────────────────────────────────────────── def blackwell_attention_forward( # Inputs q, # (T, NH, HD) with RoPE already applied kv, # (T, HD) kv_normed, RoPE'd — the NEW tokens' KV positions, # (T,) absolute positions # KV Cache swa_kv_cache, # (num_blocks, block_size, HD) fp8 paged cache swa_inv_scale, # (num_blocks * block_size, 1) per-token inv scale slot_mapping, # (T,) slot indices for writing block_size, # tokens per block seq_lens, # (num_seqs,) total sequence lengths (prefill + history) num_prefills, # number of prefill sequences num_decode_tokens, # number of decode tokens # Params scale, # 1/sqrt(HD) nope_dim, # 448 rope_dim, # 64 window_size, # 128 compress_ratio, # 0, 1, 4, or 128 cos_sin_cache, # (max_pos, rope_dim) for RoPE attn_sink, # (NH,) sink weights ): """Full attention forward for Blackwell (SM100+). This is what replaces vLLM's _attention_impl_blackwell. Steps: 1. Quantize + write new KV to paged cache 2. Read ALL cached KV for each sequence 3. Attention (prefill: causal, decode: full) 4. Return attention output (T, NH, HD) """ T = q.shape[0] NH = q.shape[1] HD = q.shape[2] device = q.device # Step 1: Quantize new KV and write to cache # kv already has RoPE applied (done by caller) kv_fp8, kv_inv_scale = kv_quantize_fp8(kv) paged_kv_write(kv_fp8, slot_mapping, swa_kv_cache, block_size) # Write inv_scale to flat cache for t in range(T): slot = slot_mapping[t].item() swa_inv_scale[slot] = kv_inv_scale[t] # Step 2 & 3: Read cached KV and attend # For simplicity in this initial version, we separate prefill and decode output = torch.zeros(T, NH, HD, dtype=torch.bfloat16, device=device) if num_decode_tokens > 0: # Decode tokens: each needs ALL prior KV from cache for t in range(num_decode_tokens): pos = positions[t].item() # Read all KV from position 0 to pos all_slots = torch.arange(pos + 1, dtype=torch.int64, device=device) kv_cached_fp8 = paged_kv_read(all_slots, swa_kv_cache, block_size, pos + 1, HD) kv_inv_scales = swa_inv_scale[all_slots] kv_cached = kv_dequantize_fp8(kv_cached_fp8, kv_inv_scales) # Apply SWA window window_start = max(0, pos - window_size + 1) kv_window = kv_cached[window_start:] q_t = q[t:t + 1] # (1, NH, HD) output[t] = decode_attention(q_t, kv_window, scale).squeeze(0) if num_prefills > 0: # Prefill tokens: causal attention using the NEW kv (not from cache, # since all KV is available from the current forward pass) # But we DO write to cache for future decode steps prefill_slice = slice(num_decode_tokens, T) output[prefill_slice] = causal_prefill_attention( q[prefill_slice], kv[prefill_slice], scale ) return output