Files
nvfp4-megamoe-kernel/dsv4/reference/compressor.py
biondizzle 3fb3c925af Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

652 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
CSA / HCA Token-Level Compressor for DeepSeek-V4.
Implements Section 2.3 of the DeepSeek-V4 paper exactly:
- CSA (m=4): overlapping weighted sum over 2m hidden states per block
- HCA (m'=128): non-overlapping weighted sum over m' hidden states per block
Both produce compressed KV entries C^Comp ∈ R^{n/m × c} where each entry
is a weighted sum of the block's projected hidden states (H · W_KV), using softmax-normalised gate weights.
CSA additionally produces compressed indexer keys K^IComp ∈ R^{n/m × c_I}
for the Lightning Indexer (top-k sparse selection).
V4-Pro reference dimensions (Section 4.2.1):
d = 7168 hidden dim
c = 512 head dim (CSA and HCA both)
m = 4 CSA compression ratio
m' = 128 HCA compression ratio
c_I = 128 indexer head dim
n_I_h = 64 num indexer query heads
n_win = 128 sliding window size (separate, not handled here)
rope_dim = 64 partial RoPE on last 64 dims of each head
Design notes
------------
* BF16 matmuls throughout — swap the _proj() calls for your NVFP4 GEMMs.
* No batch dimension: one sequence at a time (matching decode latency path).
* CompressorState carries the incomplete-block tail and the previous block's
raw hidden states, which are re-projected through W^b_KV / W^b_Z for the CSA overlap.
* Partial RoPE is applied to the last rope_dim=64 dims of C^Comp before
the entry is stored in the compressed KV cache, using the representative
position = last token of the block. The inverse-RoPE on attention outputs
is handled by your attention kernel (already in blackwell_attention.py).
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from typing import Optional
import torch
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# State
# ---------------------------------------------------------------------------
@dataclass
class CompressorState:
    """
    Mutable, per-sequence bookkeeping shared by the CSA and HCA compressors.

    Fields
    ------
    tail_hidden : hidden states received since the last committed block,
        shape (tail_len, d). Holds fewer than m rows between calls; None
        when empty.
    prev_hidden : the m hidden states of the most recently committed block,
        consumed by the CSA overlap (C^b / Z^b). None until the first block
        is committed; HCA never sets it (no overlap).
    compressed_kv : stacked C^Comp entries, shape (n_blocks, c).
    compressed_indexer_kv : stacked K^IComp entries, shape (n_blocks, c_I);
        stays None for HCA layers.
    num_blocks : count of compressed blocks committed so far.
    """

    tail_hidden: Optional[torch.Tensor] = None            # (tail_len, d)
    prev_hidden: Optional[torch.Tensor] = None            # (m, d), CSA only
    compressed_kv: Optional[torch.Tensor] = None          # (n_blocks, c)
    compressed_indexer_kv: Optional[torch.Tensor] = None  # (n_blocks, c_I)
    num_blocks: int = 0

    def reset(self) -> None:
        """Return the state to its freshly constructed (empty) condition."""
        self.tail_hidden = self.prev_hidden = None
        self.compressed_kv = self.compressed_indexer_kv = None
        self.num_blocks = 0
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _apply_partial_rope(
x: torch.Tensor, # (..., c)
positions: torch.Tensor, # (...,) int64
cos_sin_cache: torch.Tensor, # (max_pos, rope_dim)
nope_dim: int,
rope_dim: int,
) -> torch.Tensor:
"""GPT-J style RoPE on the last rope_dim dimensions only."""
if rope_dim == 0:
return x
half = rope_dim // 2
cos = cos_sin_cache[positions, :half].to(x.dtype) # (..., half)
sin = cos_sin_cache[positions, half:].to(x.dtype) # (..., half)
out = x.clone()
rope_part = out[..., nope_dim:] # (..., rope_dim)
even = rope_part[..., 0::2]
odd = rope_part[..., 1::2]
out[..., nope_dim:][..., 0::2] = even * cos - odd * sin
out[..., nope_dim:][..., 1::2] = even * sin + odd * cos
return out
def _proj(x: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
"""
Linear projection: x @ W.
x: (..., d) W: (d, out_dim) → (..., out_dim)
*** SWAP THIS FOR YOUR NVFP4 GEMM. ***
The W tensor would become your NVFP4 weight + scale_b + gsb triple,
and x would be quantised to NVFP4 activation before the call.
"""
return x.to(W.dtype) @ W
# ---------------------------------------------------------------------------
# CSA compressor
# ---------------------------------------------------------------------------
class CSACompressor:
    """
    Compressed Sparse Attention token-level compressor.

    Paper equations (11) and (12), Section 2.3.1:
        C^a = H · W^a_KV    Z^a = H · W^a_Z    (current block)
        C^b = H · W^b_KV    Z^b = H · W^b_Z    (prev-block overlap)
    For compressed block i (0-indexed):
        [S^a ; S^b] = softmax over the 2m positions (dim 0, independently
                      per channel) of [Z^a_{cur} + B^a ; Z^b_{prev} + B^b]
        C^Comp_i    = Σ_j S^a_j ⊙ C^a_j + Σ_j S^b_j ⊙ C^b_j
    When i=0: Z^b / C^b are padded with -inf / 0 so only Z^a contributes.

    The same compression is applied independently to produce indexer keys,
    using separate projections W^I_KV, W^I_Z, B^I_a, B^I_b. Partial RoPE is
    applied to the main KV entries only, never to the indexer keys.
    """

    def __init__(
        self,
        hidden_dim: int = 7168,       # d
        head_dim: int = 512,          # c
        compress_ratio: int = 4,      # m
        indexer_head_dim: int = 128,  # c_I
        num_indexer_heads: int = 64,  # n_I_h (not used in compressor itself)
        nope_dim: int = 448,          # c - rope_dim = 512 - 64
        rope_dim: int = 64,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.d = hidden_dim
        self.c = head_dim
        self.m = compress_ratio
        self.c_I = indexer_head_dim
        self.n_I_h = num_indexer_heads
        self.nope = nope_dim
        self.rope = rope_dim
        self.device = device
        self.dtype = dtype

        # ── Main KV projection weights ──────────────────────────────
        # W^a_{KV}, W^b_{KV}: (d, c)
        # W^a_Z,    W^b_Z:    (d, c)
        self.W_a_KV = self._param(hidden_dim, head_dim)
        self.W_b_KV = self._param(hidden_dim, head_dim)
        self.W_a_Z = self._param(hidden_dim, head_dim)
        self.W_b_Z = self._param(hidden_dim, head_dim)
        # Positional biases B^a, B^b: (m, c), added to the gate logits
        # before the softmax.
        self.B_a = self._param(compress_ratio, head_dim)
        self.B_b = self._param(compress_ratio, head_dim)

        # ── Indexer key projection weights ──────────────────────────
        # Same overlap structure, separate projections, output dim c_I.
        self.W_I_a_KV = self._param(hidden_dim, indexer_head_dim)
        self.W_I_b_KV = self._param(hidden_dim, indexer_head_dim)
        self.W_I_a_Z = self._param(hidden_dim, indexer_head_dim)
        self.W_I_b_Z = self._param(hidden_dim, indexer_head_dim)
        self.B_I_a = self._param(compress_ratio, indexer_head_dim)
        self.B_I_b = self._param(compress_ratio, indexer_head_dim)

    def _param(self, *shape) -> torch.Tensor:
        """Uninitialised placeholder — replace with checkpoint-loaded tensor."""
        return torch.empty(*shape, dtype=self.dtype, device=self.device)

    def load_weights(
        self,
        W_a_KV, W_b_KV, W_a_Z, W_b_Z, B_a, B_b,
        W_I_a_KV, W_I_b_KV, W_I_a_Z, W_I_b_Z, B_I_a, B_I_b,
    ):
        """Assign weights from checkpoint. All tensors moved to device/dtype."""
        def _cvt(t):
            return t.to(device=self.device, dtype=self.dtype)

        self.W_a_KV = _cvt(W_a_KV)
        self.W_b_KV = _cvt(W_b_KV)
        self.W_a_Z = _cvt(W_a_Z)
        self.W_b_Z = _cvt(W_b_Z)
        self.B_a = _cvt(B_a)
        self.B_b = _cvt(B_b)
        self.W_I_a_KV = _cvt(W_I_a_KV)
        self.W_I_b_KV = _cvt(W_I_b_KV)
        self.W_I_a_Z = _cvt(W_I_a_Z)
        self.W_I_b_Z = _cvt(W_I_b_Z)
        self.B_I_a = _cvt(B_I_a)
        self.B_I_b = _cvt(B_I_b)

    # ----------------------------------------------------------------
    # Core: compress one block of m hidden states
    # ----------------------------------------------------------------
    def _compress_block(
        self,
        cur_hidden: torch.Tensor,              # (m, d) current block
        prev_hidden: Optional[torch.Tensor],   # (m, d) or None if block 0
        cos_sin_cache: Optional[torch.Tensor],
        block_end_pos: int,                    # position of last token in block
        for_indexer: bool = False,
    ) -> torch.Tensor:
        """
        Compress one block into a single C^Comp entry.

        Returns: (c,) or (c_I,) compressed entry.

        The overlap computation (equations 11-12):
            Z_cat = [Z^a + B^a ; Z^b + B^b]   shape (2m, c), or (m, c) for block 0
            S     = softmax(Z_cat, dim=0)     normalise over the position axis
            C_out = (S[:m] * C^a).sum(0) + (S[m:] * C^b).sum(0)
        """
        m = self.m
        assert cur_hidden.shape[0] == m

        if for_indexer:
            W_a_KV, W_b_KV = self.W_I_a_KV, self.W_I_b_KV
            W_a_Z, W_b_Z = self.W_I_a_Z, self.W_I_b_Z
            B_a, B_b = self.B_I_a, self.B_I_b
        else:
            W_a_KV, W_b_KV = self.W_a_KV, self.W_b_KV
            W_a_Z, W_b_Z = self.W_a_Z, self.W_b_Z
            B_a, B_b = self.B_a, self.B_b

        # Current block projections: (m, c)
        C_a = _proj(cur_hidden, W_a_KV)  # KV candidates
        Z_a = _proj(cur_hidden, W_a_Z)   # gate logits

        if prev_hidden is None:
            # Block 0: no previous block → softmax over m entries only
            # (equivalent to padding Z^b with -inf and C^b with 0).
            Z_cat = Z_a + B_a                                     # (m, c)
            S = F.softmax(Z_cat.float(), dim=0).to(self.dtype)    # (m, c)
            C_out = (S * C_a).sum(dim=0)                          # (c,)
        else:
            # Blocks 1..: overlap with previous block
            C_b = _proj(prev_hidden, W_b_KV)  # (m, c)
            Z_b = _proj(prev_hidden, W_b_Z)   # (m, c)
            # Concatenate along the position axis → (2m, c)
            Z_cat = torch.cat([Z_a + B_a, Z_b + B_b], dim=0)
            S = F.softmax(Z_cat.float(), dim=0).to(self.dtype)    # (2m, c)
            S_a, S_b = S[:m], S[m:]                               # each (m, c)
            C_out = (S_a * C_a).sum(dim=0) + (S_b * C_b).sum(dim=0)  # (c,)

        # Partial RoPE on the last rope_dim dims using the block's end position
        if cos_sin_cache is not None and self.rope > 0 and not for_indexer:
            pos_t = torch.tensor([block_end_pos], dtype=torch.long,
                                 device=cur_hidden.device)
            C_out = _apply_partial_rope(
                C_out.unsqueeze(0), pos_t, cos_sin_cache,
                self.nope, self.rope,
            ).squeeze(0)

        return C_out  # (c,) or (c_I,)

    # ----------------------------------------------------------------
    # Prefill: all n tokens at once
    # ----------------------------------------------------------------
    def prefill(
        self,
        hidden: torch.Tensor,                   # (n, d)
        cos_sin_cache: Optional[torch.Tensor],  # (max_pos, rope_dim)
        start_pos: int = 0,
        state: Optional[CompressorState] = None,
    ) -> CompressorState:
        """
        Process n tokens in one shot (prefill / context ingestion).

        Args:
            hidden: (n, d) hidden states.
            cos_sin_cache: (max_pos, rope_dim) RoPE table, or None to skip RoPE.
            start_pos: absolute position of ``hidden[0]``. Any tail tokens
                carried in ``state`` are taken to occupy the positions
                directly before it (start_pos - tail_len ... start_pos - 1).
            state: state to extend; a fresh one is created when None.

        Returns:
            The updated CompressorState. Tokens that do not fill a complete
            block of m are kept (as a private copy) in ``state.tail_hidden``
            for future prefill / decode calls.

        Raises:
            ValueError: if ``hidden`` is not 2-D, or if RoPE is enabled and
                ``start_pos`` is smaller than the carried tail length (which
                would yield negative positions).
        """
        if state is None:
            state = CompressorState()
        if hidden.dim() != 2:
            raise ValueError(
                f"hidden must be 2-D (n, d), got shape {tuple(hidden.shape)}"
            )
        m = self.m

        # Prepend tail tokens from a previous call. They precede hidden[0],
        # so the absolute position of the combined buffer's first row moves
        # back by the tail length.
        tail_len = 0
        if state.tail_hidden is not None and state.tail_hidden.shape[0] > 0:
            tail_len = state.tail_hidden.shape[0]
            hidden = torch.cat([state.tail_hidden, hidden], dim=0)
        base_pos = start_pos - tail_len
        if base_pos < 0 and cos_sin_cache is not None and self.rope > 0:
            raise ValueError(
                f"start_pos={start_pos} is smaller than the carried tail length "
                f"{tail_len}; start_pos must be the absolute position of hidden[0]"
            )

        n = hidden.shape[0]
        n_complete_blocks = n // m
        n_tail = n % m

        kv_list = []
        indexer_kv_list = []
        prev_hidden = state.prev_hidden  # None on first ever call

        for i in range(n_complete_blocks):
            cur = hidden[i * m:(i + 1) * m]  # (m, d)
            # Absolute position of the last token in this block
            block_end = base_pos + i * m + (m - 1)
            kv_list.append(self._compress_block(
                cur, prev_hidden, cos_sin_cache, block_end, for_indexer=False
            ))
            indexer_kv_list.append(self._compress_block(
                cur, prev_hidden, None, block_end, for_indexer=True
            ))
            prev_hidden = cur  # this block becomes the "previous" for the next

        # Accumulate into state
        new_kv = torch.stack(kv_list, dim=0) if kv_list else None  # (n_blocks, c)
        new_I = torch.stack(indexer_kv_list, dim=0) if indexer_kv_list else None
        if state.compressed_kv is None:
            state.compressed_kv = new_kv
            state.compressed_indexer_kv = new_I
        elif new_kv is not None:
            state.compressed_kv = torch.cat([state.compressed_kv, new_kv], dim=0)
            state.compressed_indexer_kv = torch.cat(
                [state.compressed_indexer_kv, new_I], dim=0
            )

        state.num_blocks += n_complete_blocks
        if n_complete_blocks > 0:
            # Clone: a slice would alias the caller's buffer and keep the
            # whole (n, d) prefill tensor alive for the sequence lifetime.
            state.prev_hidden = prev_hidden.clone()
        state.tail_hidden = (
            hidden[n_complete_blocks * m:].clone() if n_tail > 0 else None
        )
        return state

    # ----------------------------------------------------------------
    # Decode: single new token, incremental
    # ----------------------------------------------------------------
    def decode_step(
        self,
        hidden_new: torch.Tensor,  # (1, d) or (d,)
        cos_sin_cache: Optional[torch.Tensor],
        current_pos: int,
        state: CompressorState,
    ) -> tuple[CompressorState, bool]:
        """
        Ingest one new token into the state.

        Args:
            hidden_new: the new token's hidden state, reshaped to (1, d).
            cos_sin_cache: RoPE table, or None to skip RoPE.
            current_pos: absolute position of this token; used as the block
                end position when the token completes a block.
            state: state to update in place.

        Returns:
            (updated_state, new_block_committed). ``new_block_committed`` is
            True when the token completes a block and a new compressed entry
            has been appended to ``state.compressed_kv``; the caller only needs
            to re-run the Lightning Indexer when True.
        """
        h = hidden_new.reshape(1, self.d)

        # Append to tail
        if state.tail_hidden is None:
            state.tail_hidden = h
        else:
            state.tail_hidden = torch.cat([state.tail_hidden, h], dim=0)

        tail_len = state.tail_hidden.shape[0]
        if tail_len < self.m:
            # Block not yet complete — nothing to compress
            return state, False

        # Tail is exactly m tokens — compress
        assert tail_len == self.m, f"tail_len={tail_len} should equal m={self.m}"
        cur = state.tail_hidden  # (m, d)
        block_end = current_pos  # last token in block

        c_kv = self._compress_block(
            cur, state.prev_hidden, cos_sin_cache, block_end, for_indexer=False
        )
        c_I = self._compress_block(
            cur, state.prev_hidden, None, block_end, for_indexer=True
        )

        # Append to accumulated KV
        c_kv_2d = c_kv.unsqueeze(0)  # (1, c)
        c_I_2d = c_I.unsqueeze(0)    # (1, c_I)
        if state.compressed_kv is None:
            state.compressed_kv = c_kv_2d
            state.compressed_indexer_kv = c_I_2d
        else:
            state.compressed_kv = torch.cat([state.compressed_kv, c_kv_2d], dim=0)
            state.compressed_indexer_kv = torch.cat(
                [state.compressed_indexer_kv, c_I_2d], dim=0
            )

        state.num_blocks += 1
        state.prev_hidden = cur  # save for next block's overlap
        state.tail_hidden = None  # clear tail
        return state, True
# ---------------------------------------------------------------------------
# HCA compressor
# ---------------------------------------------------------------------------
class HCACompressor:
    """
    Heavily Compressed Attention token-level compressor.

    Paper Section 2.3.2. Simpler than CSA:
      - No overlap: each block of m' tokens is self-contained.
      - No indexer: HCA uses dense MQA over all compressed entries,
        so no top-k selection is needed and there are no indexer keys.

    For compressed block i:
        C         = H · W_KV                                  (n, c)
        Z         = H · W_Z                                   (n, c)
        S_i       = softmax over the m' positions (dim 0, per
                    channel) of Z[m'i:m'(i+1)] + B            (m', c)
        C^Comp_i  = (S_i * C[m'i:m'(i+1)]).sum(0)             (c,)
    """

    def __init__(
        self,
        hidden_dim: int = 7168,
        head_dim: int = 512,
        compress_ratio: int = 128,  # m'
        nope_dim: int = 448,
        rope_dim: int = 64,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.d = hidden_dim
        self.c = head_dim
        self.m = compress_ratio  # m' in the paper
        self.nope = nope_dim
        self.rope = rope_dim
        self.device = device
        self.dtype = dtype

        # W_KV: (d, c)   W_Z: (d, c) — uninitialised until load_weights().
        self.W_KV = torch.empty(hidden_dim, head_dim, dtype=dtype, device=device)
        self.W_Z = torch.empty(hidden_dim, head_dim, dtype=dtype, device=device)
        # Positional bias B: (m', c)
        self.B = torch.empty(compress_ratio, head_dim, dtype=dtype, device=device)

    def load_weights(self, W_KV, W_Z, B):
        """Assign weights from checkpoint. All tensors moved to device/dtype."""
        def _cvt(t):
            return t.to(device=self.device, dtype=self.dtype)

        self.W_KV = _cvt(W_KV)
        self.W_Z = _cvt(W_Z)
        self.B = _cvt(B)

    def _compress_block(
        self,
        block_hidden: torch.Tensor,  # (m', d)
        cos_sin_cache: Optional[torch.Tensor],
        block_end_pos: int,
    ) -> torch.Tensor:
        """Compress one block of m' tokens → (c,), with partial RoPE at block_end_pos."""
        m = self.m
        assert block_hidden.shape[0] == m

        C = _proj(block_hidden, self.W_KV)  # (m', c)
        Z = _proj(block_hidden, self.W_Z)   # (m', c)
        S = F.softmax((Z + self.B).float(), dim=0).to(self.dtype)  # (m', c)
        C_out = (S * C).sum(dim=0)  # (c,)

        if cos_sin_cache is not None and self.rope > 0:
            pos_t = torch.tensor([block_end_pos], dtype=torch.long,
                                 device=block_hidden.device)
            C_out = _apply_partial_rope(
                C_out.unsqueeze(0), pos_t, cos_sin_cache,
                self.nope, self.rope,
            ).squeeze(0)
        return C_out

    def prefill(
        self,
        hidden: torch.Tensor,  # (n, d)
        cos_sin_cache: Optional[torch.Tensor],
        start_pos: int = 0,
        state: Optional[CompressorState] = None,
    ) -> CompressorState:
        """
        Process n tokens (prefill). The incomplete tail is stored for later decode.

        Args:
            hidden: (n, d) hidden states.
            cos_sin_cache: RoPE table, or None to skip RoPE.
            start_pos: absolute position of ``hidden[0]``. Any tail tokens
                carried in ``state`` are taken to occupy the positions
                directly before it.
            state: state to extend; a fresh one is created when None.

        Returns:
            The updated CompressorState.

        Raises:
            ValueError: if ``hidden`` is not 2-D, or if RoPE is enabled and
                ``start_pos`` is smaller than the carried tail length.
        """
        if state is None:
            state = CompressorState()
        if hidden.dim() != 2:
            raise ValueError(
                f"hidden must be 2-D (n, d), got shape {tuple(hidden.shape)}"
            )
        m = self.m

        # Prepend leftover tail tokens; they sit just before hidden[0], so the
        # combined buffer starts tail_len positions earlier than start_pos.
        tail_len = 0
        if state.tail_hidden is not None and state.tail_hidden.shape[0] > 0:
            tail_len = state.tail_hidden.shape[0]
            hidden = torch.cat([state.tail_hidden, hidden], dim=0)
        base_pos = start_pos - tail_len
        if base_pos < 0 and cos_sin_cache is not None and self.rope > 0:
            raise ValueError(
                f"start_pos={start_pos} is smaller than the carried tail length "
                f"{tail_len}; start_pos must be the absolute position of hidden[0]"
            )

        n = hidden.shape[0]
        n_complete = n // m
        n_tail = n % m

        kv_list = [
            self._compress_block(
                hidden[i * m:(i + 1) * m],
                cos_sin_cache,
                base_pos + i * m + (m - 1),  # absolute position of block's last token
            )
            for i in range(n_complete)
        ]

        new_kv = torch.stack(kv_list, dim=0) if kv_list else None
        if state.compressed_kv is None:
            state.compressed_kv = new_kv
        elif new_kv is not None:
            state.compressed_kv = torch.cat([state.compressed_kv, new_kv], dim=0)

        state.num_blocks += n_complete
        # Clone so the state does not alias (and pin) the caller's prefill buffer.
        state.tail_hidden = hidden[n_complete * m:].clone() if n_tail > 0 else None
        # HCA has no prev_hidden (no overlap)
        return state

    def decode_step(
        self,
        hidden_new: torch.Tensor,  # (1, d)
        cos_sin_cache: Optional[torch.Tensor],
        current_pos: int,
        state: CompressorState,
    ) -> tuple[CompressorState, bool]:
        """
        Ingest one token at absolute position ``current_pos``.

        Returns (state, new_block_committed); the flag is True when this token
        completes a block of m' and a new entry was appended to
        ``state.compressed_kv``.
        """
        h = hidden_new.reshape(1, self.d)
        if state.tail_hidden is None:
            state.tail_hidden = h
        else:
            state.tail_hidden = torch.cat([state.tail_hidden, h], dim=0)

        if state.tail_hidden.shape[0] < self.m:
            return state, False

        # Full block ready
        block = state.tail_hidden  # (m', d)
        c_kv = self._compress_block(block, cos_sin_cache, current_pos)
        c_kv_2d = c_kv.unsqueeze(0)
        if state.compressed_kv is None:
            state.compressed_kv = c_kv_2d
        else:
            state.compressed_kv = torch.cat([state.compressed_kv, c_kv_2d], dim=0)
        state.num_blocks += 1
        state.tail_hidden = None
        return state, True
# ---------------------------------------------------------------------------
# Quick smoke test
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(42)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16

    # ── V4-Pro dims ─────────────────────────────────────────────────
    D, C, M, C_I = 7168, 512, 4, 128
    M_HCA = 128
    NOPE, ROPE = 448, 64
    MAX_POS = 4096
    cos_sin_cache = torch.randn(MAX_POS, ROPE, dtype=dtype, device=device)

    def _rand_weight(*shape, scale=0.02):
        """Small finite random tensor used in place of checkpoint weights."""
        return torch.randn(*shape, dtype=dtype, device=device) * scale

    # ── CSA ─────────────────────────────────────────────────────────
    print("=== CSA compressor ===")
    csa = CSACompressor(D, C, M, C_I, num_indexer_heads=64,
                        nope_dim=NOPE, rope_dim=ROPE, device=device, dtype=dtype)
    # The constructor only allocates torch.empty placeholders, which may hold
    # NaN/Inf garbage; load finite weights so results (and the equivalence
    # check below) are meaningful and deterministic.
    csa.load_weights(
        W_a_KV=_rand_weight(D, C), W_b_KV=_rand_weight(D, C),
        W_a_Z=_rand_weight(D, C), W_b_Z=_rand_weight(D, C),
        B_a=_rand_weight(M, C), B_b=_rand_weight(M, C),
        W_I_a_KV=_rand_weight(D, C_I), W_I_b_KV=_rand_weight(D, C_I),
        W_I_a_Z=_rand_weight(D, C_I), W_I_b_Z=_rand_weight(D, C_I),
        B_I_a=_rand_weight(M, C_I), B_I_b=_rand_weight(M, C_I),
    )

    # Prefill 20 tokens (5 complete blocks, 0 tail)
    n_prefill = 20
    h_prefill = torch.randn(n_prefill, D, dtype=dtype, device=device)
    state = csa.prefill(h_prefill, cos_sin_cache, start_pos=0)
    print(f" After prefill {n_prefill} tokens:")
    print(f" num_blocks: {state.num_blocks}")  # 5
    print(f" compressed_kv: {state.compressed_kv.shape}")  # (5, 512)
    print(f" indexer_kv: {state.compressed_indexer_kv.shape}")  # (5, 128)
    print(f" tail_len: {0 if state.tail_hidden is None else state.tail_hidden.shape[0]}")

    # Prefill 6 more (1 complete block + 2 tail)
    h2 = torch.randn(6, D, dtype=dtype, device=device)
    state = csa.prefill(h2, cos_sin_cache, start_pos=n_prefill, state=state)
    print("\n After prefill 6 more:")
    print(f" num_blocks: {state.num_blocks}")  # 6
    print(f" compressed_kv: {state.compressed_kv.shape}")  # (6, 512)
    print(f" tail_len: {state.tail_hidden.shape[0]}")  # 2

    # Decode 2 tokens (fills tail → 1 new block)
    for tok_i in range(2):
        h_tok = torch.randn(1, D, dtype=dtype, device=device)
        pos = n_prefill + 6 + tok_i
        state, committed = csa.decode_step(h_tok, cos_sin_cache, pos, state)
        print(f" decode tok {tok_i}: committed={committed}")
    print(f" num_blocks now: {state.num_blocks}")  # 7

    # ── HCA ─────────────────────────────────────────────────────────
    print("\n=== HCA compressor ===")
    hca = HCACompressor(D, C, M_HCA, nope_dim=NOPE, rope_dim=ROPE,
                        device=device, dtype=dtype)
    hca.load_weights(
        W_KV=_rand_weight(D, C), W_Z=_rand_weight(D, C), B=_rand_weight(M_HCA, C),
    )
    n_hca = 384  # 3 complete blocks + 0 tail
    h_hca = torch.randn(n_hca, D, dtype=dtype, device=device)
    hca_state = hca.prefill(h_hca, cos_sin_cache, start_pos=0)
    print(f" After prefill {n_hca} tokens:")
    print(f" num_blocks: {hca_state.num_blocks}")  # 3
    print(f" compressed_kv: {hca_state.compressed_kv.shape}")  # (3, 512)

    # Decode 128 tokens → exactly 1 new HCA block
    for tok_i in range(M_HCA):
        h_tok = torch.randn(1, D, dtype=dtype, device=device)
        pos = n_hca + tok_i
        hca_state, committed = hca.decode_step(h_tok, cos_sin_cache, pos, hca_state)
    print(f" After decode {M_HCA} tokens:")
    print(f" num_blocks: {hca_state.num_blocks}")  # 4
    print(f" compressed_kv: {hca_state.compressed_kv.shape}")  # (4, 512)

    # ── Correctness sanity: prefill == incremental decode ───────────
    print("\n=== Equivalence check: prefill vs incremental decode ===")
    csa2 = CSACompressor(D, C, M, C_I, nope_dim=NOPE, rope_dim=ROPE,
                         device=device, dtype=dtype)
    # Copy weights
    for attr in ("W_a_KV", "W_b_KV", "W_a_Z", "W_b_Z", "B_a", "B_b",
                 "W_I_a_KV", "W_I_b_KV", "W_I_a_Z", "W_I_b_Z", "B_I_a", "B_I_b"):
        setattr(csa2, attr, getattr(csa, attr))

    n_check = 8
    h_check = torch.randn(n_check, D, dtype=dtype, device=device)

    # Batch prefill
    s_batch = csa2.prefill(h_check, cos_sin_cache, start_pos=0)

    # Token-by-token decode
    s_incr = CompressorState()
    for i in range(n_check):
        s_incr, _ = csa2.decode_step(h_check[i:i + 1], cos_sin_cache, i, s_incr)

    if s_batch.compressed_kv is not None and s_incr.compressed_kv is not None:
        max_diff = (s_batch.compressed_kv - s_incr.compressed_kv).abs().max().item()
        max_diff_I = (
            s_batch.compressed_indexer_kv - s_incr.compressed_indexer_kv
        ).abs().max().item()
        print(f" max |prefill - decode| on compressed_kv: {max_diff:.6f}")
        print(f" max |prefill - decode| on compressed_indexer_kv: {max_diff_I:.6f}")
        assert max_diff < 1e-3, "Mismatch between prefill and decode paths!"
        assert max_diff_I < 1e-3, "Indexer-key mismatch between prefill and decode paths!"
        print(" PASSED")
    else:
        print(" (no complete blocks produced in 8 tokens with m=4 — increase n_check)")

    print("\nAll checks done.")