# Files
#
# 652 lines
# 26 KiB
# Python
# Raw Permalink Normal View History

"""
CSA / HCA Token-Level Compressor for DeepSeek-V4.
Implements Section 2.3 of the DeepSeek-V4 paper exactly:
- CSA (m=4): overlapping weighted sum over 2m hidden states per block
- HCA (m'=128): non-overlapping weighted sum over m' hidden states per block
Both produce compressed KV entries C^Comp ∈ R^{n/m × c} where each entry
is a weighted sum of hidden states using softmax-normalised gate weights.
CSA additionally produces compressed indexer keys K^IComp ∈ R^{n/m × c_I}
for the Lightning Indexer (top-k sparse selection).
V4-Pro reference dimensions (Section 4.2.1):
d = 7168 hidden dim
c = 512 head dim (CSA and HCA both)
m = 4 CSA compression ratio
m' = 128 HCA compression ratio
c_I = 128 indexer head dim
n_I_h = 64 num indexer query heads
n_win = 128 sliding window size (separate, not handled here)
rope_dim = 64 partial RoPE on last 64 dims of each head
Design notes
------------
* BF16 matmuls throughout — swap the _proj() calls for your NVFP4 GEMMs.
* No batch dimension: one sequence at a time (matching decode latency path).
* CompressorState carries the incomplete-block tail and the previous block's
raw projections needed for the CSA overlap.
* Partial RoPE is applied to the last rope_dim=64 dims of C^Comp before
the entry is stored in the compressed KV cache, using the representative
position = last token of the block. The inverse-RoPE on attention outputs
is handled by your attention kernel (already in blackwell_attention.py).
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from typing import Optional
import torch
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# State
# ---------------------------------------------------------------------------
@dataclass
class CompressorState:
    """
    Mutable, per-sequence bookkeeping shared by the CSA and HCA compressors.

    Fields
    ------
    tail_hidden            hidden states received so far that do not yet
                           form a full compression block; shape (tail_len, d)
                           with 0 <= tail_len < m, or None when empty.
    prev_hidden            the m hidden states of the most recently committed
                           block, consumed by CSA's overlap term (C^b / Z^b).
                           None until the first block is committed; the HCA
                           compressor never sets it (no overlap).
    compressed_kv          stacked C^Comp entries, shape (n_blocks, c).
    compressed_indexer_kv  stacked K^IComp entries, shape (n_blocks, c_I);
                           stays None for HCA layers.
    num_blocks             number of committed compressed blocks.
    """
    tail_hidden: Optional[torch.Tensor] = None            # (tail_len, d)
    prev_hidden: Optional[torch.Tensor] = None            # (m, d), CSA only
    compressed_kv: Optional[torch.Tensor] = None          # (n_blocks, c)
    compressed_indexer_kv: Optional[torch.Tensor] = None  # (n_blocks, c_I)
    num_blocks: int = 0

    def reset(self):
        """Return the state to its freshly-constructed (empty) condition."""
        self.tail_hidden = self.prev_hidden = None
        self.compressed_kv = self.compressed_indexer_kv = None
        self.num_blocks = 0
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _apply_partial_rope(
x: torch.Tensor, # (..., c)
positions: torch.Tensor, # (...,) int64
cos_sin_cache: torch.Tensor, # (max_pos, rope_dim)
nope_dim: int,
rope_dim: int,
) -> torch.Tensor:
"""GPT-J style RoPE on the last rope_dim dimensions only."""
if rope_dim == 0:
return x
half = rope_dim // 2
cos = cos_sin_cache[positions, :half].to(x.dtype) # (..., half)
sin = cos_sin_cache[positions, half:].to(x.dtype) # (..., half)
out = x.clone()
rope_part = out[..., nope_dim:] # (..., rope_dim)
even = rope_part[..., 0::2]
odd = rope_part[..., 1::2]
out[..., nope_dim:][..., 0::2] = even * cos - odd * sin
out[..., nope_dim:][..., 1::2] = even * sin + odd * cos
return out
def _proj(x: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
"""
Linear projection: x @ W.
x: (..., d) W: (d, out_dim) (..., out_dim)
*** SWAP THIS FOR YOUR NVFP4 GEMM. ***
The W tensor would become your NVFP4 weight + scale_b + gsb triple,
and x would be quantised to NVFP4 activation before the call.
"""
return x.to(W.dtype) @ W
# ---------------------------------------------------------------------------
# CSA compressor
# ---------------------------------------------------------------------------
class CSACompressor:
    """
    Compressed Sparse Attention token-level compressor.
    Paper equations (11) and (12), Section 2.3.1:
      C^a = H · W^a_KV    Z^a = H · W^a_Z    (current block)
      C^b = H · W^b_KV    Z^b = H · W^b_Z    (prev-block overlap)
    For compressed block i (0-indexed), with the softmax taken over the
    position axis independently for every channel:
      [S^a ; S^b] = softmax( [Z^a_{cur} + B^a ; Z^b_{prev} + B^b] )
      C^Comp_i    = Σ_j S^a_j ⊙ C^a_j + Σ_j S^b_j ⊙ C^b_j
    When i=0: Z^b / C^b are padded with -inf / 0 so only Z^a contributes.
    The same compression is applied independently to produce indexer keys,
    using separate projections W^I_KV, W^I_Z, B^I_a, B^I_b.
    """

    def __init__(
        self,
        hidden_dim: int = 7168,       # d
        head_dim: int = 512,          # c
        compress_ratio: int = 4,      # m
        indexer_head_dim: int = 128,  # c_I
        num_indexer_heads: int = 64,  # n_I_h (not used in compressor itself)
        nope_dim: int = 448,          # c - rope_dim = 512 - 64
        rope_dim: int = 64,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.d = hidden_dim
        self.c = head_dim
        self.m = compress_ratio
        self.c_I = indexer_head_dim
        self.n_I_h = num_indexer_heads
        self.nope = nope_dim
        self.rope = rope_dim
        self.device = device
        self.dtype = dtype
        # ── Main KV projection weights ──────────────────────────────
        # W^a_{KV}, W^b_{KV}: (d, c)
        # W^a_Z, W^b_Z:       (d, c)
        self.W_a_KV = self._param(hidden_dim, head_dim)
        self.W_b_KV = self._param(hidden_dim, head_dim)
        self.W_a_Z = self._param(hidden_dim, head_dim)
        self.W_b_Z = self._param(hidden_dim, head_dim)
        # Positional biases B^a, B^b: (m, c) — learnable per-position offsets
        # added to the gate logits before softmax.
        self.B_a = self._param(compress_ratio, head_dim)
        self.B_b = self._param(compress_ratio, head_dim)
        # ── Indexer key projection weights ──────────────────────────
        # Same overlap structure, separate projections, output dim c_I.
        self.W_I_a_KV = self._param(hidden_dim, indexer_head_dim)
        self.W_I_b_KV = self._param(hidden_dim, indexer_head_dim)
        self.W_I_a_Z = self._param(hidden_dim, indexer_head_dim)
        self.W_I_b_Z = self._param(hidden_dim, indexer_head_dim)
        self.B_I_a = self._param(compress_ratio, indexer_head_dim)
        self.B_I_b = self._param(compress_ratio, indexer_head_dim)

    def _param(self, *shape) -> torch.Tensor:
        """Uninitialised placeholder — replace with checkpoint-loaded tensor."""
        return torch.empty(*shape, dtype=self.dtype, device=self.device)

    def load_weights(
        self,
        W_a_KV, W_b_KV, W_a_Z, W_b_Z, B_a, B_b,
        W_I_a_KV, W_I_b_KV, W_I_a_Z, W_I_b_Z, B_I_a, B_I_b,
    ):
        """Assign weights from checkpoint. All tensors moved to device/dtype."""
        def _cvt(t): return t.to(device=self.device, dtype=self.dtype)
        self.W_a_KV = _cvt(W_a_KV); self.W_b_KV = _cvt(W_b_KV)
        self.W_a_Z = _cvt(W_a_Z); self.W_b_Z = _cvt(W_b_Z)
        self.B_a = _cvt(B_a); self.B_b = _cvt(B_b)
        self.W_I_a_KV = _cvt(W_I_a_KV); self.W_I_b_KV = _cvt(W_I_b_KV)
        self.W_I_a_Z = _cvt(W_I_a_Z); self.W_I_b_Z = _cvt(W_I_b_Z)
        self.B_I_a = _cvt(B_I_a); self.B_I_b = _cvt(B_I_b)

    # ----------------------------------------------------------------
    # Core: compress one block of m hidden states
    # ----------------------------------------------------------------
    def _compress_block(
        self,
        cur_hidden: torch.Tensor,             # (m, d) current block
        prev_hidden: Optional[torch.Tensor],  # (m, d) or None if block 0
        cos_sin_cache: Optional[torch.Tensor],
        block_end_pos: int,                   # position of last token in block
        for_indexer: bool = False,
    ) -> torch.Tensor:
        """
        Compress one block into a single C^Comp (or K^IComp) entry.
        Returns: (c,) or (c_I,) compressed entry.
        The overlap computation (equations 11-12):
          Z_cat = [Z^a + B^a ; Z^b + B^b]   shape (2m, c), or (m, c) for block 0
          S     = softmax(Z_cat, dim=0)     normalise over the position axis
          C_out = (S[:m] * C^a).sum(0) + (S[m:] * C^b).sum(0)
        Partial RoPE (at block_end_pos) is applied only to main-KV entries,
        and only when cos_sin_cache is given and rope_dim > 0.
        """
        m = self.m
        assert cur_hidden.shape[0] == m
        if for_indexer:
            W_a_KV, W_b_KV = self.W_I_a_KV, self.W_I_b_KV
            W_a_Z, W_b_Z = self.W_I_a_Z, self.W_I_b_Z
            B_a, B_b = self.B_I_a, self.B_I_b
        else:
            W_a_KV, W_b_KV = self.W_a_KV, self.W_b_KV
            W_a_Z, W_b_Z = self.W_a_Z, self.W_b_Z
            B_a, B_b = self.B_a, self.B_b
        # Current block projections: (m, c)
        C_a = _proj(cur_hidden, W_a_KV)  # KV candidates
        Z_a = _proj(cur_hidden, W_a_Z)   # gate logits
        if prev_hidden is None:
            # Block 0: no previous block → softmax over m entries only
            # (paper pads Z^b with -inf, C^b with 0)
            Z_cat = Z_a + B_a                                      # (m, c)
            S = F.softmax(Z_cat.float(), dim=0).to(self.dtype)    # (m, c)
            C_out = (S * C_a).sum(dim=0)                           # (c,)
        else:
            # Blocks 1..: overlap with previous block
            C_b = _proj(prev_hidden, W_b_KV)  # (m, c)
            Z_b = _proj(prev_hidden, W_b_Z)   # (m, c)
            # Concatenate along the position axis → (2m, c)
            Z_cat = torch.cat([Z_a + B_a, Z_b + B_b], dim=0)
            S = F.softmax(Z_cat.float(), dim=0).to(self.dtype)    # (2m, c)
            S_a, S_b = S[:m], S[m:]                                # each (m, c)
            C_out = (S_a * C_a).sum(dim=0) + (S_b * C_b).sum(dim=0)  # (c,)
        # Partial RoPE on the last rope_dim dims using the block's end position
        if cos_sin_cache is not None and self.rope > 0 and not for_indexer:
            pos_t = torch.tensor([block_end_pos], dtype=torch.long,
                                 device=cur_hidden.device)
            C_out = _apply_partial_rope(
                C_out.unsqueeze(0), pos_t, cos_sin_cache,
                self.nope, self.rope
            ).squeeze(0)
        return C_out  # (c,) or (c_I,)

    # ----------------------------------------------------------------
    # Prefill: all n tokens at once
    # ----------------------------------------------------------------
    def prefill(
        self,
        hidden: torch.Tensor,                    # (n, d)
        cos_sin_cache: Optional[torch.Tensor],   # (max_pos, rope_dim)
        start_pos: int = 0,
        state: Optional[CompressorState] = None,
    ) -> CompressorState:
        """
        Process n new tokens in one shot (prefill / context ingestion).

        Args:
            hidden: (n, d) hidden states of the new tokens.
            cos_sin_cache: RoPE table, or None to skip RoPE.
            start_pos: absolute position of hidden[0], i.e. of the first
                *new* token.
            state: state from a previous call; a fresh one is created if None.

        Tail tokens left in ``state`` by an earlier call are prepended to
        ``hidden``. They occupy positions start_pos - tail_len .. start_pos - 1,
        and block positions are measured from that shifted origin. This
        assumes start_pos >= tail_len. Tokens that don't fill a complete block
        of m are kept in state.tail_hidden for later decode/prefill calls.

        Returns the updated CompressorState (the same object if one was given).
        """
        if state is None:
            state = CompressorState()
        m = self.m
        # Prepend leftover tail tokens; they precede start_pos in the sequence.
        tail_len = 0
        if state.tail_hidden is not None and state.tail_hidden.shape[0] > 0:
            tail_len = state.tail_hidden.shape[0]
            hidden = torch.cat([state.tail_hidden, hidden], dim=0)
        base_pos = start_pos - tail_len  # absolute position of hidden[0]
        n = hidden.shape[0]
        n_complete_blocks = n // m
        n_tail = n % m
        kv_list = []
        indexer_kv_list = []
        prev_hidden = state.prev_hidden  # None on first ever call
        for i in range(n_complete_blocks):
            cur = hidden[i * m : (i + 1) * m]  # (m, d)
            # Absolute position of the last token in this block
            block_end = base_pos + i * m + (m - 1)
            kv_list.append(self._compress_block(
                cur, prev_hidden, cos_sin_cache, block_end, for_indexer=False
            ))
            indexer_kv_list.append(self._compress_block(
                cur, prev_hidden, None, block_end, for_indexer=True
            ))
            prev_hidden = cur  # this block becomes the "previous" for the next
        if kv_list:
            new_kv = torch.stack(kv_list, dim=0)           # (n_blocks, c)
            new_I = torch.stack(indexer_kv_list, dim=0)    # (n_blocks, c_I)
            if state.compressed_kv is None:
                state.compressed_kv = new_kv
                state.compressed_indexer_kv = new_I
            else:
                state.compressed_kv = torch.cat([state.compressed_kv, new_kv], dim=0)
                state.compressed_indexer_kv = torch.cat(
                    [state.compressed_indexer_kv, new_I], dim=0
                )
            state.num_blocks += n_complete_blocks
            # `cur` is a view into `hidden`: clone so the state neither pins the
            # whole (n, d) prefill buffer nor changes if the caller reuses it.
            state.prev_hidden = prev_hidden.clone()
        state.tail_hidden = hidden[n_complete_blocks * m :].clone() if n_tail > 0 else None
        return state

    # ----------------------------------------------------------------
    # Decode: single new token, incremental
    # ----------------------------------------------------------------
    def decode_step(
        self,
        hidden_new: torch.Tensor,  # (1, d) or (d,)
        cos_sin_cache: Optional[torch.Tensor],
        current_pos: int,
        state: CompressorState,
    ) -> tuple[CompressorState, bool]:
        """
        Ingest one new token (at absolute position current_pos) into the state.
        Returns (updated_state, new_block_committed).
        new_block_committed is True when the new token completes a block
        and a new compressed entry has been appended to state.compressed_kv.
        The caller only needs to re-run the Lightning Indexer when True.
        """
        h = hidden_new.reshape(1, self.d)
        # Append to tail
        if state.tail_hidden is None:
            # Copy: `h` may be a view of a buffer the caller overwrites
            # before the next decode step.
            state.tail_hidden = h.clone()
        else:
            state.tail_hidden = torch.cat([state.tail_hidden, h], dim=0)
        tail_len = state.tail_hidden.shape[0]
        if tail_len < self.m:
            # Block not yet complete — nothing to compress
            return state, False
        # Tail is exactly m tokens — compress
        assert tail_len == self.m, f"tail_len={tail_len} should equal m={self.m}"
        cur = state.tail_hidden  # (m, d)
        block_end = current_pos  # last token in block
        c_kv = self._compress_block(
            cur, state.prev_hidden, cos_sin_cache, block_end, for_indexer=False
        )
        c_I = self._compress_block(
            cur, state.prev_hidden, None, block_end, for_indexer=True
        )
        # Append to accumulated KV
        c_kv_2d = c_kv.unsqueeze(0)  # (1, c)
        c_I_2d = c_I.unsqueeze(0)    # (1, c_I)
        if state.compressed_kv is None:
            state.compressed_kv = c_kv_2d
            state.compressed_indexer_kv = c_I_2d
        else:
            state.compressed_kv = torch.cat([state.compressed_kv, c_kv_2d], dim=0)
            state.compressed_indexer_kv = torch.cat([state.compressed_indexer_kv, c_I_2d], dim=0)
        state.num_blocks += 1
        state.prev_hidden = cur    # save for next block's overlap
        state.tail_hidden = None   # clear tail
        return state, True
# ---------------------------------------------------------------------------
# HCA compressor
# ---------------------------------------------------------------------------
class HCACompressor:
    """
    Heavily Compressed Attention token-level compressor.
    Paper Section 2.3.2. Simpler than CSA:
      - No overlap: each block of m' tokens is self-contained.
      - No indexer: HCA uses dense MQA over all compressed entries,
        so no top-k selection is needed and there are no indexer keys.
    For compressed block i (softmax over the m' positions, per channel):
      C        = H · W_KV                         ∈ R^{n × c}
      Z        = H · W_Z                          ∈ R^{n × c}
      S_i      = softmax(Z[m'i:m'(i+1)] + B)      ∈ R^{m' × c}
      C^Comp_i = (S_i ⊙ C[m'i:m'(i+1)]).sum(0)    ∈ R^{c}
    """

    def __init__(
        self,
        hidden_dim: int = 7168,
        head_dim: int = 512,
        compress_ratio: int = 128,  # m'
        nope_dim: int = 448,
        rope_dim: int = 64,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.d = hidden_dim
        self.c = head_dim
        self.m = compress_ratio  # m' in the paper
        self.nope = nope_dim
        self.rope = rope_dim
        self.device = device
        self.dtype = dtype
        # W_KV: (d, c)   W_Z: (d, c)
        self.W_KV = self._param(hidden_dim, head_dim)
        self.W_Z = self._param(hidden_dim, head_dim)
        # Positional bias B: (m', c)
        self.B = self._param(compress_ratio, head_dim)

    def _param(self, *shape) -> torch.Tensor:
        """Uninitialised placeholder — replace with checkpoint-loaded tensor."""
        return torch.empty(*shape, dtype=self.dtype, device=self.device)

    def load_weights(self, W_KV, W_Z, B):
        """Assign weights from checkpoint. All tensors moved to device/dtype."""
        def _cvt(t): return t.to(device=self.device, dtype=self.dtype)
        self.W_KV = _cvt(W_KV)
        self.W_Z = _cvt(W_Z)
        self.B = _cvt(B)

    def _compress_block(
        self,
        block_hidden: torch.Tensor,  # (m', d)
        cos_sin_cache: Optional[torch.Tensor],
        block_end_pos: int,
    ) -> torch.Tensor:
        """Compress one block of m' tokens → (c,); RoPE at block_end_pos if enabled."""
        m = self.m
        assert block_hidden.shape[0] == m
        C = _proj(block_hidden, self.W_KV)  # (m', c)
        Z = _proj(block_hidden, self.W_Z)   # (m', c)
        S = F.softmax((Z + self.B).float(), dim=0).to(self.dtype)  # (m', c)
        C_out = (S * C).sum(dim=0)  # (c,)
        if cos_sin_cache is not None and self.rope > 0:
            pos_t = torch.tensor([block_end_pos], dtype=torch.long,
                                 device=block_hidden.device)
            C_out = _apply_partial_rope(
                C_out.unsqueeze(0), pos_t, cos_sin_cache,
                self.nope, self.rope
            ).squeeze(0)
        return C_out

    def prefill(
        self,
        hidden: torch.Tensor,  # (n, d)
        cos_sin_cache: Optional[torch.Tensor],
        start_pos: int = 0,
        state: Optional[CompressorState] = None,
    ) -> CompressorState:
        """
        Process n new tokens (prefill). Tail stored for later decode.

        start_pos is the absolute position of hidden[0] (the first *new*
        token). Tail tokens left in ``state`` are prepended and sit at
        start_pos - tail_len .. start_pos - 1, so block positions are measured
        from that shifted origin (assumes start_pos >= tail_len).
        """
        if state is None:
            state = CompressorState()
        m = self.m
        tail_len = 0
        if state.tail_hidden is not None and state.tail_hidden.shape[0] > 0:
            tail_len = state.tail_hidden.shape[0]
            hidden = torch.cat([state.tail_hidden, hidden], dim=0)
        base_pos = start_pos - tail_len  # absolute position of hidden[0]
        n = hidden.shape[0]
        n_complete = n // m
        n_tail = n % m
        kv_list = [
            self._compress_block(
                hidden[i * m : (i + 1) * m], cos_sin_cache, base_pos + i * m + (m - 1)
            )
            for i in range(n_complete)
        ]
        if kv_list:
            new_kv = torch.stack(kv_list, dim=0)
            if state.compressed_kv is None:
                state.compressed_kv = new_kv
            else:
                state.compressed_kv = torch.cat([state.compressed_kv, new_kv], dim=0)
            state.num_blocks += n_complete
        # Clone so the state does not pin (or alias) the caller's prefill buffer.
        state.tail_hidden = hidden[n_complete * m :].clone() if n_tail > 0 else None
        # HCA has no prev_hidden needed (no overlap)
        return state

    def decode_step(
        self,
        hidden_new: torch.Tensor,  # (1, d) or (d,)
        cos_sin_cache: Optional[torch.Tensor],
        current_pos: int,
        state: CompressorState,
    ) -> tuple[CompressorState, bool]:
        """Ingest one token at current_pos. Returns (state, new_block_committed)."""
        h = hidden_new.reshape(1, self.d)
        if state.tail_hidden is None:
            # Copy: `h` may be a view of a buffer the caller reuses next step.
            state.tail_hidden = h.clone()
        else:
            state.tail_hidden = torch.cat([state.tail_hidden, h], dim=0)
        if state.tail_hidden.shape[0] < self.m:
            return state, False
        # Full block ready
        block = state.tail_hidden  # (m', d)
        c_kv = self._compress_block(block, cos_sin_cache, current_pos)
        c_kv_2d = c_kv.unsqueeze(0)
        if state.compressed_kv is None:
            state.compressed_kv = c_kv_2d
        else:
            state.compressed_kv = torch.cat([state.compressed_kv, c_kv_2d], dim=0)
        state.num_blocks += 1
        state.tail_hidden = None
        return state, True
# ---------------------------------------------------------------------------
# Quick smoke test
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(42)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16
    # ── V4-Pro dims ─────────────────────────────────────────────────
    D, C, M, C_I = 7168, 512, 4, 128
    M_HCA = 128
    NOPE, ROPE = 448, 64
    MAX_POS = 4096

    # A genuine cos/sin table in the layout _apply_partial_rope reads:
    # columns [:ROPE//2] hold cosines, columns [ROPE//2:] hold sines.
    # (Base 10000 is illustrative only.)
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, ROPE, 2, dtype=torch.float32) / ROPE))
    angles = torch.outer(torch.arange(MAX_POS, dtype=torch.float32), inv_freq)
    cos_sin_cache = torch.cat([angles.cos(), angles.sin()], dim=-1).to(
        device=device, dtype=dtype
    )

    def _rand(*shape: int) -> torch.Tensor:
        """Small random weights.

        The constructors only allocate torch.empty, whose contents are
        arbitrary (possibly NaN/Inf), so real values must be loaded before
        any numeric check is meaningful.
        """
        return torch.randn(*shape, dtype=dtype, device=device) * 0.02

    # ── CSA ─────────────────────────────────────────────────────────
    print("=== CSA compressor ===")
    csa = CSACompressor(D, C, M, C_I, num_indexer_heads=64,
                        nope_dim=NOPE, rope_dim=ROPE, device=device, dtype=dtype)
    csa.load_weights(
        W_a_KV=_rand(D, C), W_b_KV=_rand(D, C),
        W_a_Z=_rand(D, C), W_b_Z=_rand(D, C),
        B_a=_rand(M, C), B_b=_rand(M, C),
        W_I_a_KV=_rand(D, C_I), W_I_b_KV=_rand(D, C_I),
        W_I_a_Z=_rand(D, C_I), W_I_b_Z=_rand(D, C_I),
        B_I_a=_rand(M, C_I), B_I_b=_rand(M, C_I),
    )
    # Prefill 20 tokens (5 complete blocks, 0 tail)
    n_prefill = 20
    h_prefill = torch.randn(n_prefill, D, dtype=dtype, device=device)
    state = csa.prefill(h_prefill, cos_sin_cache, start_pos=0)
    print(f"  After prefill {n_prefill} tokens:")
    print(f"    num_blocks:    {state.num_blocks}")                    # 5
    print(f"    compressed_kv: {state.compressed_kv.shape}")          # (5, 512)
    print(f"    indexer_kv:    {state.compressed_indexer_kv.shape}")  # (5, 128)
    print(f"    tail_len:      {0 if state.tail_hidden is None else state.tail_hidden.shape[0]}")
    # Prefill 6 more (1 complete block + 2 tail)
    h2 = torch.randn(6, D, dtype=dtype, device=device)
    state = csa.prefill(h2, cos_sin_cache, start_pos=n_prefill, state=state)
    print("\n  After prefill 6 more:")
    print(f"    num_blocks:    {state.num_blocks}")            # 6
    print(f"    compressed_kv: {state.compressed_kv.shape}")  # (6, 512)
    print(f"    tail_len:      {state.tail_hidden.shape[0]}")  # 2
    # Decode 2 tokens (fills tail → 1 new block)
    for tok_i in range(2):
        h_tok = torch.randn(1, D, dtype=dtype, device=device)
        pos = n_prefill + 6 + tok_i
        state, committed = csa.decode_step(h_tok, cos_sin_cache, pos, state)
        print(f"  decode tok {tok_i}: committed={committed}")
    print(f"  num_blocks now: {state.num_blocks}")  # 7

    # ── HCA ─────────────────────────────────────────────────────────
    print("\n=== HCA compressor ===")
    hca = HCACompressor(D, C, M_HCA, nope_dim=NOPE, rope_dim=ROPE,
                        device=device, dtype=dtype)
    hca.load_weights(W_KV=_rand(D, C), W_Z=_rand(D, C), B=_rand(M_HCA, C))
    n_hca = 384  # 3 complete blocks + 0 tail
    h_hca = torch.randn(n_hca, D, dtype=dtype, device=device)
    hca_state = hca.prefill(h_hca, cos_sin_cache, start_pos=0)
    print(f"  After prefill {n_hca} tokens:")
    print(f"    num_blocks:    {hca_state.num_blocks}")            # 3
    print(f"    compressed_kv: {hca_state.compressed_kv.shape}")  # (3, 512)
    # Decode 128 tokens → exactly 1 new HCA block
    for tok_i in range(M_HCA):
        h_tok = torch.randn(1, D, dtype=dtype, device=device)
        pos = n_hca + tok_i
        hca_state, committed = hca.decode_step(h_tok, cos_sin_cache, pos, hca_state)
    print(f"  After decode {M_HCA} tokens:")
    print(f"    num_blocks:    {hca_state.num_blocks}")            # 4
    print(f"    compressed_kv: {hca_state.compressed_kv.shape}")  # (4, 512)

    # ── Correctness sanity: prefill == incremental decode == chunked ─
    print("\n=== Equivalence check: prefill vs incremental decode ===")
    csa2 = CSACompressor(D, C, M, C_I, nope_dim=NOPE, rope_dim=ROPE,
                         device=device, dtype=dtype)
    # Copy (now initialised) weights
    for attr in ("W_a_KV", "W_b_KV", "W_a_Z", "W_b_Z", "B_a", "B_b",
                 "W_I_a_KV", "W_I_b_KV", "W_I_a_Z", "W_I_b_Z", "B_I_a", "B_I_b"):
        setattr(csa2, attr, getattr(csa, attr))
    n_check = 12  # 3 complete blocks with m=4
    h_check = torch.randn(n_check, D, dtype=dtype, device=device)

    def _max_diff(a: torch.Tensor, b: torch.Tensor) -> float:
        """Largest absolute elementwise difference (NaN propagates)."""
        return (a.float() - b.float()).abs().max().item()

    # Reference: one-shot prefill
    s_batch = csa2.prefill(h_check, cos_sin_cache, start_pos=0)
    # Token-by-token decode
    s_incr = CompressorState()
    for i in range(n_check):
        s_incr, _ = csa2.decode_step(h_check[i:i + 1], cos_sin_cache, i, s_incr)
    # Chunked prefill whose first chunk leaves a 2-token tail, so the second
    # chunk's first block starts before its start_pos (exercises tail offset).
    s_chunk = csa2.prefill(h_check[:6], cos_sin_cache, start_pos=0)
    s_chunk = csa2.prefill(h_check[6:], cos_sin_cache, start_pos=6, state=s_chunk)

    for label, other in (("decode", s_incr), ("chunked prefill", s_chunk)):
        assert other.num_blocks == s_batch.num_blocks == n_check // M
        kv_diff = _max_diff(s_batch.compressed_kv, other.compressed_kv)
        idx_diff = _max_diff(s_batch.compressed_indexer_kv, other.compressed_indexer_kv)
        print(f"  max |prefill - {label}|: kv={kv_diff:.6f}  indexer={idx_diff:.6f}")
        assert kv_diff < 1e-3 and idx_diff < 1e-3, f"Mismatch between prefill and {label} paths!"
    print("  PASSED")
    print("\nAll checks done.")