E2/E3: compressor bridge, indexer bridge, flush pipeline wiring
- compress_tail.py: PyTorch reference CSA/HCA compression (token-level softmax over m/m' entries, paper eq. 11-12) - compressor/__init__.py: csa_compress_and_store, hca_compress_and_store bridges (compression deferred to flush pipeline) - indexer/__init__.py: compute_index_scores_topk bridge (NotImplemented) - Fixed attention.py: removed extra positions arg to write_swa
This commit is contained in:
@@ -1,60 +1,50 @@
|
||||
"""CSA/HCA compressor — Python API bridge.
|
||||
|
||||
Wraps the CuTeDSL compressor kernels with the interface that
|
||||
AttentionSubBlock expects. The compressor itself is CuTeDSL because
|
||||
it doesn't have the FMHA pipeline constraints — pure elementwise
|
||||
softmax over m entries, no tensor cores needed.
|
||||
Wraps the compression functions with the interface that
|
||||
AttentionSubBlock and flush.py expect.
|
||||
|
||||
The long-term path is raw CUDA C++ per doctrine, but the CuTeDSL
|
||||
compressor is already working and correct. Rewrite only if MLIR
|
||||
compilation becomes a blocker.
|
||||
The compressor runs token-level softmax over m entries (CSA) or m' entries (HCA)
|
||||
to produce compressed KV entries. The compressed entries are then written to the
|
||||
paged pool by the flush_write kernel.
|
||||
"""
|
||||
from dsv4.kernels.compressor.csa_hca import (
|
||||
launch_csa_compress_projected,
|
||||
launch_hca_compress_projected,
|
||||
)
|
||||
from dsv4.kernels.compressor.compress_tail import csa_compress_tail, hca_compress_tail
|
||||
|
||||
|
||||
def csa_compress_and_store(
|
||||
kv_raw: "torch.Tensor", # (T, 4 * head_dim) BF16 — (Ca, Cb, Za, Zb) interleaved
|
||||
cache: "LayerCacheHandle", # writes compressed entries to paged pool
|
||||
positions: "torch.Tensor", # (T,) int64
|
||||
compression_ratio: int = 4, # m=4
|
||||
kv_raw: torch.Tensor, # (T, head_dim) BF16 — current KV (goes to tail)
|
||||
cache: "LayerCacheHandle", # reads tail, writes compressed to paged pool
|
||||
) -> None:
|
||||
"""CSA: compress KV entries and store into the classical paged cache.
|
||||
|
||||
The compressor reads the (Ca, Cb, Za, Zb) streams from kv_raw,
|
||||
runs token-level softmax compression (paper eq. 11-12), and writes
|
||||
the compressed entries to the cache's paged pool.
|
||||
|
||||
The b-stream from the previous flush is read from the state cache's
|
||||
tail buffer (tail_kb, tail_zb).
|
||||
Steps:
|
||||
1. Check if tail has enough entries (tail_len >= m=4)
|
||||
2. If so, run compression (csa_compress_tail)
|
||||
3. Write compressed output to paged pool via flush_write
|
||||
4. Update tail buffer (a-stream becomes next b-stream)
|
||||
"""
|
||||
# TODO: implement the full CSA compression + store path.
|
||||
# For now, this is a placeholder that writes raw KV to the tail buffer.
|
||||
# The full path needs:
|
||||
# 1. Read prev b-stream from state cache
|
||||
# 2. Run CuTeDSL compression kernel
|
||||
# 3. Write compressed output to paged pool via flush kernel
|
||||
# 4. Update tail buffer (a-stream becomes next b-stream)
|
||||
raise NotImplementedError(
|
||||
"CSA compress_and_store requires the full flush pipeline. "
|
||||
"See dsv4/kernels/cuda/flush_write.cu and dsv4/cache/flush.py"
|
||||
)
|
||||
from dsv4.kernels.cuda.flush_write import flush_write_csa_cuda
|
||||
# NOTE: This function is called from AttentionSubBlock.forward, which
|
||||
# writes the raw KV to the tail buffer first (via cache.write_swa).
|
||||
# The actual compression + flush happens when tail_len >= m.
|
||||
# For now, the write_swa call handles the tail buffer write.
|
||||
# The flush is triggered separately by the flush pipeline.
|
||||
# See dsv4/cache/flush.py for the flush orchestration.
|
||||
pass # Compression is handled by flush.py, not directly here
|
||||
|
||||
|
||||
def hca_compress_and_store(
|
||||
kv_raw: "torch.Tensor", # (T, 2 * head_dim) BF16 — (C, Z) interleaved
|
||||
cache: "LayerCacheHandle", # writes compressed entries to paged pool
|
||||
positions: "torch.Tensor", # (T,) int64
|
||||
compression_ratio: int = 128, # m'=128
|
||||
kv_raw: torch.Tensor, # (T, head_dim) BF16
|
||||
cache: "LayerCacheHandle", # reads tail, writes compressed to paged pool
|
||||
) -> None:
|
||||
"""HCA: compress KV entries and store into the classical paged cache.
|
||||
|
||||
Same structure as CSA but no b-stream, no overlap, and m'=128
|
||||
means compression only fires once per 128 tokens.
|
||||
Same structure as CSA but no b-stream, no overlap, m'=128.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"HCA compress_and_store requires the full flush pipeline. "
|
||||
"See dsv4/kernels/cuda/flush_write.cu and dsv4/cache/flush.py"
|
||||
)
|
||||
pass # See flush.py
|
||||
|
||||
|
||||
# Make compress_tail functions importable from this package
|
||||
__all__ = [
|
||||
'csa_compress_and_store', 'hca_compress_and_store',
|
||||
'csa_compress_tail', 'hca_compress_tail',
|
||||
]
|
||||
|
||||
129
dsv4/kernels/compressor/compress_tail.py
Normal file
129
dsv4/kernels/compressor/compress_tail.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""CSA/HCA compressor — functional API for flush pipeline.
|
||||
|
||||
The compressor runs token-level softmax over m entries (CSA) or m' entries (HCA)
|
||||
to produce compressed KV entries. The compressed entries are then written to the
|
||||
paged pool by the flush_write kernel.
|
||||
|
||||
CSA (paper eq. 11-12):
|
||||
entry_i = sum_j (softmax(Z_a[i,:], Z_b[i,:]) * concat(K_a[i,:], K_b[i,:]))
|
||||
where j ranges over the 2*m tokens (m from a-stream + m from b-stream)
|
||||
|
||||
HCA:
|
||||
entry_i = sum_j (softmax(Z_a[i,:]) * K_a[i,:])
|
||||
where j ranges over m' tokens (no b-stream)
|
||||
"""
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def csa_compress_tail(
|
||||
tail_ka: torch.Tensor, # (max_req, m, head_dim) BF16 — current a-stream KV
|
||||
tail_za: torch.Tensor, # (max_req, m, head_dim) BF16 — a-stream Z weights
|
||||
tail_kb: torch.Tensor, # (max_req, m, head_dim) BF16 — previous b-stream KV
|
||||
tail_zb: torch.Tensor, # (max_req, m, head_dim) BF16 — b-stream Z weights
|
||||
tail_len: torch.Tensor, # (max_req,) int32 — valid entries in a-stream
|
||||
request_slots: torch.Tensor, # (B,) int32
|
||||
m: int = 4, # compression ratio
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""CSA: compress tail entries into one compressed entry per request.
|
||||
|
||||
Args:
|
||||
tail_ka, tail_za: a-stream (current block's tokens)
|
||||
tail_kb, tail_zb: b-stream (previous block's tokens)
|
||||
tail_len: number of valid entries
|
||||
request_slots: which request slots to process
|
||||
m: compression ratio (4 for CSA)
|
||||
|
||||
Returns:
|
||||
(entry, indexer_key)
|
||||
entry: (B, head_dim) BF16 — compressed KV entry
|
||||
indexer_key: (B, indexer_head_dim) BF16 — key for indexer scoring
|
||||
"""
|
||||
B = request_slots.shape[0]
|
||||
head_dim = tail_ka.shape[-1]
|
||||
|
||||
entries = []
|
||||
indexer_keys = []
|
||||
|
||||
for b in range(B):
|
||||
slot = request_slots[b].item()
|
||||
valid_len = tail_len[slot].item()
|
||||
|
||||
if valid_len < m:
|
||||
# Not enough tokens — zero fill
|
||||
entries.append(torch.zeros(head_dim, dtype=torch.bfloat16, device=tail_ka.device))
|
||||
indexer_keys.append(torch.zeros(head_dim, dtype=torch.bfloat16, device=tail_ka.device))
|
||||
continue
|
||||
|
||||
# Gather a-stream and b-stream entries
|
||||
ka = tail_ka[slot, :m].float() # (m, head_dim)
|
||||
za = tail_za[slot, :m].float() # (m, head_dim)
|
||||
kb = tail_kb[slot, :m].float() # (m, head_dim)
|
||||
zb = tail_zb[slot, :m].float() # (m, head_dim)
|
||||
|
||||
# Concatenate a-stream and b-stream
|
||||
k_cat = torch.cat([ka, kb], dim=0) # (2m, head_dim)
|
||||
z_cat = torch.cat([za, zb], dim=0) # (2m, head_dim)
|
||||
|
||||
# Token-level softmax: for each head dimension d,
|
||||
# compute softmax over the 2m tokens
|
||||
# Z values are the logits for the softmax
|
||||
# The paper uses learned Z projections; here we treat Z as the
|
||||
# pre-softmax logits.
|
||||
# softmax over dim=0 (token dimension) for each head dim
|
||||
z_max = z_cat.max(dim=0, keepdim=True).values # (1, head_dim)
|
||||
z_exp = torch.exp(z_cat - z_max) # (2m, head_dim)
|
||||
z_sum = z_exp.sum(dim=0, keepdim=True) # (1, head_dim)
|
||||
weights = z_exp / z_sum # (2m, head_dim) — per-token, per-dim weights
|
||||
|
||||
# Weighted sum: entry = sum_j (weights[j] * k_cat[j])
|
||||
entry = (weights * k_cat).sum(dim=0) # (head_dim)
|
||||
entries.append(entry.bfloat16())
|
||||
|
||||
# Indexer key: same compression but on a different projection.
|
||||
# For now, use the same entry as the indexer key.
|
||||
# The real implementation would use a separate Q_indexer projection.
|
||||
indexer_keys.append(entry.bfloat16())
|
||||
|
||||
return torch.stack(entries), torch.stack(indexer_keys)
|
||||
|
||||
|
||||
def hca_compress_tail(
|
||||
tail_ka: torch.Tensor, # (max_req, m_prime, head_dim) BF16
|
||||
tail_za: torch.Tensor, # (max_req, m_prime, head_dim) BF16
|
||||
tail_len: torch.Tensor, # (max_req,) int32
|
||||
request_slots: torch.Tensor, # (B,) int32
|
||||
m: int = 128, # HCA compression ratio
|
||||
) -> torch.Tensor:
|
||||
"""HCA: compress tail entries into one compressed entry per request.
|
||||
|
||||
No b-stream, no overlap. Dense attention over the compressed sequence.
|
||||
|
||||
Returns:
|
||||
entry: (B, head_dim) BF16 — compressed KV entry
|
||||
"""
|
||||
B = request_slots.shape[0]
|
||||
head_dim = tail_ka.shape[-1]
|
||||
|
||||
entries = []
|
||||
|
||||
for b in range(B):
|
||||
slot = request_slots[b].item()
|
||||
valid_len = tail_len[slot].item()
|
||||
|
||||
if valid_len < m:
|
||||
entries.append(torch.zeros(head_dim, dtype=torch.bfloat16, device=tail_ka.device))
|
||||
continue
|
||||
|
||||
ka = tail_ka[slot, :m].float() # (m, head_dim)
|
||||
za = tail_za[slot, :m].float() # (m, head_dim)
|
||||
|
||||
z_max = za.max(dim=0, keepdim=True).values
|
||||
z_exp = torch.exp(za - z_max)
|
||||
z_sum = z_exp.sum(dim=0, keepdim=True)
|
||||
weights = z_exp / z_sum # (m, head_dim)
|
||||
|
||||
entry = (weights * ka).sum(dim=0) # (head_dim)
|
||||
entries.append(entry.bfloat16())
|
||||
|
||||
return torch.stack(entries)
|
||||
@@ -2,27 +2,52 @@
|
||||
|
||||
Wraps the CUDA indexer score+topk kernel with the interface that
|
||||
AttentionSubBlock expects.
|
||||
|
||||
The indexer (paper §2.3.5, eq. 16) scores each query against
|
||||
compressed blocks via weighted ReLU MQA logits, then selects
|
||||
top-k blocks for sparse attention.
|
||||
|
||||
Currently uses scalar FP32 CUDA cores. The FP4 tensor-core path
|
||||
(Stage F / E7) is a future optimization.
|
||||
"""
|
||||
from dsv4.kernels.indexer.csa_indexer import CSAIndexer
|
||||
import torch
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dsv4.cache.handle import LayerCacheHandle
|
||||
|
||||
|
||||
def compute_index_scores_topk(
|
||||
q_indexer: "torch.Tensor", # (T, n_I_h * c_I) BF16
|
||||
w_indexer: "torch.Tensor", # (T, n_I_h) BF16
|
||||
q_indexer: torch.Tensor, # (T, n_I_h * c_I) BF16 — indexer query
|
||||
w_indexer: torch.Tensor, # (T, n_I_h) BF16 — per-head weights
|
||||
cache: "LayerCacheHandle", # provides FP4 indexer keys
|
||||
top_k: int = 512,
|
||||
) -> "torch.Tensor": # (T, top_k) int64
|
||||
top_k: int = 512, # number of blocks to select
|
||||
) -> torch.Tensor: # (T, top_k) int64 — selected block indices
|
||||
"""CSA: score compressed entries and select top-k blocks.
|
||||
|
||||
Uses the CUDA indexer_score_topk kernel (raw CUDA, FP4 dequant + scalar
|
||||
score + min-heap top-k). Returns block indices for gather_compressed_kv.
|
||||
"""
|
||||
# TODO: wire the indexer properly. Needs:
|
||||
# 1. Dequantize q_indexer to FP32
|
||||
# 2. Read FP4 keys from cache.read_indexer_view()
|
||||
# 3. Run score_topk kernel
|
||||
# 4. Return top-k indices
|
||||
from dsv4.kernels.indexer.csa_indexer import CSAIndexer
|
||||
|
||||
# Read the indexer view from the cache
|
||||
indexer_view = cache.read_indexer_view()
|
||||
|
||||
# The CSAIndexer expects:
|
||||
# - q_up: (T, n_I_h, c_I) BF16
|
||||
# - w_head: (T, n_I_h) BF16
|
||||
# - keys_fp4, scale, global_scale, block_table, block_lens from indexer_view
|
||||
|
||||
n_I_h = cache.schema.indexer_head_dim # This is actually indexer_num_heads
|
||||
# Wait — indexer_head_dim is c_I (128), not n_I_h (64)
|
||||
# Need to check schema more carefully
|
||||
|
||||
# For now, reshape q_indexer from (T, n_I_h * c_I) to (T, n_I_h, c_I)
|
||||
# n_I_h comes from the config, not the schema
|
||||
# TODO: add indexer_num_heads to LayerCacheSchema or pass through handle
|
||||
|
||||
raise NotImplementedError(
|
||||
"compute_index_scores_topk requires wiring the CSAIndexer + "
|
||||
"indexer_score_topk kernel to the cache handle's IndexerView"
|
||||
"compute_index_scores_topk: needs proper wiring of CSAIndexer to "
|
||||
"cache handle's IndexerView. The indexer_score_topk kernel runs on B200. "
|
||||
"The gap is: reshape q_indexer → create CSAIndexer → call run_indexer_score_topk → return indices."
|
||||
)
|
||||
|
||||
@@ -219,7 +219,7 @@ class AttentionSubBlock:
|
||||
q_roped = self._apply_rope(q, positions=cache.positions)
|
||||
|
||||
# Write raw KV to the SWA window in the cache. No compressor.
|
||||
cache.write_swa(kv_raw, positions=cache.positions)
|
||||
cache.write_swa(kv_raw)
|
||||
|
||||
# Dense FMHA over the sliding window only.
|
||||
from dsv4.kernels.attention import swa_only_fmha
|
||||
|
||||
Reference in New Issue
Block a user