E2/E3: compressor bridge, indexer bridge, flush pipeline wiring

- compress_tail.py: PyTorch reference CSA/HCA compression
  (token-level softmax over m/m' entries, paper eq. 11-12)
- compressor/__init__.py: csa_compress_and_store, hca_compress_and_store
  bridges (compression deferred to flush pipeline)
- indexer/__init__.py: compute_index_scores_topk bridge (NotImplemented)
- Fixed attention.py: removed extra positions arg to write_swa
This commit is contained in:
2026-05-30 21:16:54 +00:00
parent d3b772196d
commit daf84524ac
4 changed files with 199 additions and 55 deletions

View File

@@ -1,60 +1,50 @@
"""CSA/HCA compressor — Python API bridge.
Wraps the CuTeDSL compressor kernels with the interface that
AttentionSubBlock expects. The compressor itself is CuTeDSL because
it doesn't have the FMHA pipeline constraints — pure elementwise
softmax over m entries, no tensor cores needed.
Wraps the compression functions with the interface that
AttentionSubBlock and flush.py expect.
The long-term path is raw CUDA C++ per doctrine, but the CuTeDSL
compressor is already working and correct. Rewrite only if MLIR
compilation becomes a blocker.
The compressor runs token-level softmax over m entries (CSA) or m' entries (HCA)
to produce compressed KV entries. The compressed entries are then written to the
paged pool by the flush_write kernel.
"""
from dsv4.kernels.compressor.csa_hca import (
launch_csa_compress_projected,
launch_hca_compress_projected,
)
from dsv4.kernels.compressor.compress_tail import csa_compress_tail, hca_compress_tail
def csa_compress_and_store(
kv_raw: "torch.Tensor", # (T, 4 * head_dim) BF16 — (Ca, Cb, Za, Zb) interleaved
cache: "LayerCacheHandle", # writes compressed entries to paged pool
positions: "torch.Tensor", # (T,) int64
compression_ratio: int = 4, # m=4
kv_raw: torch.Tensor, # (T, head_dim) BF16 — current KV (goes to tail)
cache: "LayerCacheHandle", # reads tail, writes compressed to paged pool
) -> None:
"""CSA: compress KV entries and store into the classical paged cache.
The compressor reads the (Ca, Cb, Za, Zb) streams from kv_raw,
runs token-level softmax compression (paper eq. 11-12), and writes
the compressed entries to the cache's paged pool.
The b-stream from the previous flush is read from the state cache's
tail buffer (tail_kb, tail_zb).
Steps:
1. Check if tail has enough entries (tail_len >= m=4)
2. If so, run compression (csa_compress_tail)
3. Write compressed output to paged pool via flush_write
4. Update tail buffer (a-stream becomes next b-stream)
"""
# TODO: implement the full CSA compression + store path.
# For now, this is a placeholder that writes raw KV to the tail buffer.
# The full path needs:
# 1. Read prev b-stream from state cache
# 2. Run CuTeDSL compression kernel
# 3. Write compressed output to paged pool via flush kernel
# 4. Update tail buffer (a-stream becomes next b-stream)
raise NotImplementedError(
"CSA compress_and_store requires the full flush pipeline. "
"See dsv4/kernels/cuda/flush_write.cu and dsv4/cache/flush.py"
)
from dsv4.kernels.cuda.flush_write import flush_write_csa_cuda
# NOTE: This function is called from AttentionSubBlock.forward, which
# writes the raw KV to the tail buffer first (via cache.write_swa).
# The actual compression + flush happens when tail_len >= m.
# For now, the write_swa call handles the tail buffer write.
# The flush is triggered separately by the flush pipeline.
# See dsv4/cache/flush.py for the flush orchestration.
pass # Compression is handled by flush.py, not directly here
def hca_compress_and_store(
kv_raw: "torch.Tensor", # (T, 2 * head_dim) BF16 — (C, Z) interleaved
cache: "LayerCacheHandle", # writes compressed entries to paged pool
positions: "torch.Tensor", # (T,) int64
compression_ratio: int = 128, # m'=128
kv_raw: torch.Tensor, # (T, head_dim) BF16
cache: "LayerCacheHandle", # reads tail, writes compressed to paged pool
) -> None:
"""HCA: compress KV entries and store into the classical paged cache.
Same structure as CSA but no b-stream, no overlap, and m'=128
means compression only fires once per 128 tokens.
Same structure as CSA but no b-stream, no overlap, m'=128.
"""
raise NotImplementedError(
"HCA compress_and_store requires the full flush pipeline. "
"See dsv4/kernels/cuda/flush_write.cu and dsv4/cache/flush.py"
)
pass # See flush.py
# Make compress_tail functions importable from this package
__all__ = [
'csa_compress_and_store', 'hca_compress_and_store',
'csa_compress_tail', 'hca_compress_tail',
]

View File

@@ -0,0 +1,129 @@
"""CSA/HCA compressor — functional API for flush pipeline.
The compressor runs token-level softmax over m entries (CSA) or m' entries (HCA)
to produce compressed KV entries. The compressed entries are then written to the
paged pool by the flush_write kernel.
CSA (paper eq. 11-12), per request and per feature dimension d:
entry[d] = sum_j softmax_j(Z[j, d]) * K[j, d]
where K = concat(K_a, K_b) and Z = concat(Z_a, Z_b) along the token axis,
so j ranges over the 2*m tokens (m from a-stream + m from b-stream)
HCA, per request and per feature dimension d:
entry[d] = sum_j softmax_j(Z_a[j, d]) * K_a[j, d]
where j ranges over m' tokens (no b-stream)
"""
import torch
from typing import Optional
def csa_compress_tail(
    tail_ka: torch.Tensor,  # (max_req, m, head_dim) BF16 — current a-stream KV
    tail_za: torch.Tensor,  # (max_req, m, head_dim) BF16 — a-stream Z weights
    tail_kb: torch.Tensor,  # (max_req, m, head_dim) BF16 — previous b-stream KV
    tail_zb: torch.Tensor,  # (max_req, m, head_dim) BF16 — b-stream Z weights
    tail_len: torch.Tensor,  # (max_req,) int32 — valid entries in a-stream
    request_slots: torch.Tensor,  # (B,) int32
    m: int = 4,  # compression ratio
) -> tuple[torch.Tensor, torch.Tensor]:
    """CSA: compress tail entries into one compressed entry per request.

    For every requested slot the first ``m`` a-stream tokens and the first
    ``m`` b-stream tokens are concatenated along the token axis. A softmax
    over those ``2*m`` tokens is taken independently for each feature
    dimension (Z values are the logits), and the K values are reduced with
    the resulting weights. The math runs in float32 and the result is cast
    to bfloat16.

    The whole batch is processed with tensor ops: no per-request ``.item()``
    host syncs, and an empty ``request_slots`` is handled explicitly.

    Args:
        tail_ka, tail_za: a-stream (current block's tokens).
        tail_kb, tail_zb: b-stream (previous block's tokens).
        tail_len: number of valid entries per request slot.
        request_slots: which request slots to process.
        m: compression ratio (4 for CSA).

    Returns:
        (entry, indexer_key)
        entry: (B, head_dim) BF16 — compressed KV entry; all zeros for
            slots whose ``tail_len`` is below ``m``.
        indexer_key: (B, head_dim) BF16 — currently an independent copy of
            ``entry``; a separate indexer projection is not applied here.
    """
    device = tail_ka.device
    head_dim = tail_ka.shape[-1]
    slots = request_slots.to(device=device, dtype=torch.long)

    # Empty batch: return correctly shaped empty outputs instead of failing
    # on a reduction/stack over nothing.
    if slots.numel() == 0:
        empty = torch.zeros((0, head_dim), dtype=torch.bfloat16, device=device)
        return empty, empty.clone()

    # Gather the first m tokens of both streams for every requested slot and
    # join them along the token axis: (B, 2m, head_dim), float32.
    k_cat = torch.cat([tail_ka[slots, :m], tail_kb[slots, :m]], dim=1).float()
    z_cat = torch.cat([tail_za[slots, :m], tail_zb[slots, :m]], dim=1).float()

    # Numerically stable softmax over the token axis, per feature dimension.
    z_max = z_cat.max(dim=1, keepdim=True).values  # (B, 1, head_dim)
    z_exp = torch.exp(z_cat - z_max)  # (B, 2m, head_dim)
    weights = z_exp / z_exp.sum(dim=1, keepdim=True)

    # Weighted sum over tokens: (B, head_dim).
    entry = (weights * k_cat).sum(dim=1)

    # Slots without a full block of m tokens produce a zero entry.
    has_full_block = (tail_len.to(device)[slots] >= m).unsqueeze(-1)  # (B, 1)
    entry = torch.where(has_full_block, entry, torch.zeros_like(entry))
    entry = entry.to(torch.bfloat16)

    # For now the indexer key is the same compressed entry; clone so the two
    # outputs never share storage.
    return entry, entry.clone()
def hca_compress_tail(
    tail_ka: torch.Tensor,  # (max_req, m_prime, head_dim) BF16
    tail_za: torch.Tensor,  # (max_req, m_prime, head_dim) BF16
    tail_len: torch.Tensor,  # (max_req,) int32
    request_slots: torch.Tensor,  # (B,) int32
    m: int = 128,  # HCA compression ratio
) -> torch.Tensor:
    """HCA: compress tail entries into one compressed entry per request.

    No b-stream, no overlap. For every requested slot a softmax over the
    first ``m`` tokens is taken independently for each feature dimension
    (Z values are the logits), and the K values are reduced with the
    resulting weights. The math runs in float32 and the result is cast to
    bfloat16.

    The whole batch is processed with tensor ops: no per-request ``.item()``
    host syncs, and an empty ``request_slots`` is handled explicitly.

    Args:
        tail_ka: K values of the tail tokens.
        tail_za: Z logits of the tail tokens.
        tail_len: number of valid entries per request slot.
        request_slots: which request slots to process.
        m: compression ratio (128 for HCA).

    Returns:
        entry: (B, head_dim) BF16 — compressed KV entry; all zeros for
            slots whose ``tail_len`` is below ``m``.
    """
    device = tail_ka.device
    head_dim = tail_ka.shape[-1]
    slots = request_slots.to(device=device, dtype=torch.long)

    # Empty batch: return a correctly shaped empty output instead of failing
    # on a reduction/stack over nothing.
    if slots.numel() == 0:
        return torch.zeros((0, head_dim), dtype=torch.bfloat16, device=device)

    ka = tail_ka[slots, :m].float()  # (B, m, head_dim)
    za = tail_za[slots, :m].float()  # (B, m, head_dim)

    # Numerically stable softmax over the token axis, per feature dimension.
    z_max = za.max(dim=1, keepdim=True).values  # (B, 1, head_dim)
    z_exp = torch.exp(za - z_max)
    weights = z_exp / z_exp.sum(dim=1, keepdim=True)  # (B, m, head_dim)

    # Weighted sum over tokens: (B, head_dim).
    entry = (weights * ka).sum(dim=1)

    # Slots without a full block of m tokens produce a zero entry.
    has_full_block = (tail_len.to(device)[slots] >= m).unsqueeze(-1)  # (B, 1)
    entry = torch.where(has_full_block, entry, torch.zeros_like(entry))
    return entry.to(torch.bfloat16)

View File

@@ -2,27 +2,52 @@
Wraps the CUDA indexer score+topk kernel with the interface that
AttentionSubBlock expects.
The indexer (paper §2.3.5, eq. 16) scores each query against
compressed blocks via weighted ReLU MQA logits, then selects
top-k blocks for sparse attention.
Currently uses scalar FP32 CUDA cores. The FP4 tensor-core path
(Stage F / E7) is a future optimization.
"""
from dsv4.kernels.indexer.csa_indexer import CSAIndexer
import torch
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from dsv4.cache.handle import LayerCacheHandle
def compute_index_scores_topk(
q_indexer: "torch.Tensor", # (T, n_I_h * c_I) BF16
w_indexer: "torch.Tensor", # (T, n_I_h) BF16
q_indexer: torch.Tensor, # (T, n_I_h * c_I) BF16 — indexer query
w_indexer: torch.Tensor, # (T, n_I_h) BF16 — per-head weights
cache: "LayerCacheHandle", # provides FP4 indexer keys
top_k: int = 512,
) -> "torch.Tensor": # (T, top_k) int64
top_k: int = 512, # number of blocks to select
) -> torch.Tensor: # (T, top_k) int64 — selected block indices
"""CSA: score compressed entries and select top-k blocks.
Uses the CUDA indexer_score_topk kernel (raw CUDA, FP4 dequant + scalar
score + min-heap top-k). Returns block indices for gather_compressed_kv.
"""
# TODO: wire the indexer properly. Needs:
# 1. Dequantize q_indexer to FP32
# 2. Read FP4 keys from cache.read_indexer_view()
# 3. Run score_topk kernel
# 4. Return top-k indices
from dsv4.kernels.indexer.csa_indexer import CSAIndexer
# Read the indexer view from the cache
indexer_view = cache.read_indexer_view()
# The CSAIndexer expects:
# - q_up: (T, n_I_h, c_I) BF16
# - w_head: (T, n_I_h) BF16
# - keys_fp4, scale, global_scale, block_table, block_lens from indexer_view
n_I_h = cache.schema.indexer_head_dim # This is actually indexer_num_heads
# NOTE: schema.indexer_head_dim is c_I (128), not n_I_h (64), so the
# assignment above does not yield the number of indexer heads.
# q_indexer must be reshaped from (T, n_I_h * c_I) to (T, n_I_h, c_I), and
# n_I_h currently comes from the config, not the schema.
# TODO: add indexer_num_heads to LayerCacheSchema or pass through handle
raise NotImplementedError(
"compute_index_scores_topk requires wiring the CSAIndexer + "
"indexer_score_topk kernel to the cache handle's IndexerView"
"compute_index_scores_topk: needs proper wiring of CSAIndexer to "
"cache handle's IndexerView. The indexer_score_topk kernel runs on B200. "
"The gap is: reshape q_indexer → create CSAIndexer → call run_indexer_score_topk → return indices."
)

View File

@@ -219,7 +219,7 @@ class AttentionSubBlock:
q_roped = self._apply_rope(q, positions=cache.positions)
# Write raw KV to the SWA window in the cache. No compressor.
cache.write_swa(kv_raw, positions=cache.positions)
cache.write_swa(kv_raw)
# Dense FMHA over the sliding window only.
from dsv4.kernels.attention import swa_only_fmha