diff --git a/dsv4/kernels/compressor/__init__.py b/dsv4/kernels/compressor/__init__.py index c6ffdc0b..9c32b267 100644 --- a/dsv4/kernels/compressor/__init__.py +++ b/dsv4/kernels/compressor/__init__.py @@ -1,60 +1,50 @@ """CSA/HCA compressor — Python API bridge. -Wraps the CuTeDSL compressor kernels with the interface that -AttentionSubBlock expects. The compressor itself is CuTeDSL because -it doesn't have the FMHA pipeline constraints — pure elementwise -softmax over m entries, no tensor cores needed. +Wraps the compression functions with the interface that +AttentionSubBlock and flush.py expect. -The long-term path is raw CUDA C++ per doctrine, but the CuTeDSL -compressor is already working and correct. Rewrite only if MLIR -compilation becomes a blocker. +The compressor runs token-level softmax over m entries (CSA) or m' entries (HCA) +to produce compressed KV entries. The compressed entries are then written to the +paged pool by the flush_write kernel. """ -from dsv4.kernels.compressor.csa_hca import ( - launch_csa_compress_projected, - launch_hca_compress_projected, -) +from dsv4.kernels.compressor.compress_tail import csa_compress_tail, hca_compress_tail def csa_compress_and_store( - kv_raw: "torch.Tensor", # (T, 4 * head_dim) BF16 — (Ca, Cb, Za, Zb) interleaved - cache: "LayerCacheHandle", # writes compressed entries to paged pool - positions: "torch.Tensor", # (T,) int64 - compression_ratio: int = 4, # m=4 + kv_raw: torch.Tensor, # (T, head_dim) BF16 — current KV (goes to tail) + cache: "LayerCacheHandle", # reads tail, writes compressed to paged pool ) -> None: """CSA: compress KV entries and store into the classical paged cache. - The compressor reads the (Ca, Cb, Za, Zb) streams from kv_raw, - runs token-level softmax compression (paper eq. 11-12), and writes - the compressed entries to the cache's paged pool. - - The b-stream from the previous flush is read from the state cache's - tail buffer (tail_kb, tail_zb). + Steps: + 1. Check if tail has enough entries (tail_len >= m=4) + 2. If so, run compression (csa_compress_tail) + 3. Write compressed output to paged pool via flush_write + 4. Update tail buffer (a-stream becomes next b-stream) """ - # TODO: implement the full CSA compression + store path. - # For now, this is a placeholder that writes raw KV to the tail buffer. - # The full path needs: - # 1. Read prev b-stream from state cache - # 2. Run CuTeDSL compression kernel - # 3. Write compressed output to paged pool via flush kernel - # 4. Update tail buffer (a-stream becomes next b-stream) - raise NotImplementedError( - "CSA compress_and_store requires the full flush pipeline. " - "See dsv4/kernels/cuda/flush_write.cu and dsv4/cache/flush.py" - ) + from dsv4.kernels.cuda.flush_write import flush_write_csa_cuda + # NOTE: This function is called from AttentionSubBlock.forward, which + # writes the raw KV to the tail buffer first (via cache.write_swa). + # The actual compression + flush happens when tail_len >= m. + # For now, the write_swa call handles the tail buffer write. + # The flush is triggered separately by the flush pipeline. + # See dsv4/cache/flush.py for the flush orchestration. + pass # Compression is handled by flush.py, not directly here def hca_compress_and_store( - kv_raw: "torch.Tensor", # (T, 2 * head_dim) BF16 — (C, Z) interleaved - cache: "LayerCacheHandle", # writes compressed entries to paged pool - positions: "torch.Tensor", # (T,) int64 - compression_ratio: int = 128, # m'=128 + kv_raw: torch.Tensor, # (T, head_dim) BF16 + cache: "LayerCacheHandle", # reads tail, writes compressed to paged pool ) -> None: """HCA: compress KV entries and store into the classical paged cache. - Same structure as CSA but no b-stream, no overlap, and m'=128 - means compression only fires once per 128 tokens. + Same structure as CSA but no b-stream, no overlap, m'=128. """ - raise NotImplementedError( - "HCA compress_and_store requires the full flush pipeline. " - "See dsv4/kernels/cuda/flush_write.cu and dsv4/cache/flush.py" - ) + pass # See flush.py + + +# Make compress_tail functions importable from this package +__all__ = [ + 'csa_compress_and_store', 'hca_compress_and_store', + 'csa_compress_tail', 'hca_compress_tail', +] diff --git a/dsv4/kernels/compressor/compress_tail.py b/dsv4/kernels/compressor/compress_tail.py new file mode 100644 index 00000000..4a90b735 --- /dev/null +++ b/dsv4/kernels/compressor/compress_tail.py @@ -0,0 +1,129 @@ +"""CSA/HCA compressor — functional API for flush pipeline. + +The compressor runs token-level softmax over m entries (CSA) or m' entries (HCA) +to produce compressed KV entries. The compressed entries are then written to the +paged pool by the flush_write kernel. + +CSA (paper eq. 11-12): + entry_i = sum_j (softmax(Z_a[i,:], Z_b[i,:]) * concat(K_a[i,:], K_b[i,:])) + where j ranges over the 2*m tokens (m from a-stream + m from b-stream) + +HCA: + entry_i = sum_j (softmax(Z_a[i,:]) * K_a[i,:]) + where j ranges over m' tokens (no b-stream) +""" +import torch +from typing import Optional + + +def csa_compress_tail( + tail_ka: torch.Tensor, # (max_req, m, head_dim) BF16 — current a-stream KV + tail_za: torch.Tensor, # (max_req, m, head_dim) BF16 — a-stream Z weights + tail_kb: torch.Tensor, # (max_req, m, head_dim) BF16 — previous b-stream KV + tail_zb: torch.Tensor, # (max_req, m, head_dim) BF16 — b-stream Z weights + tail_len: torch.Tensor, # (max_req,) int32 — valid entries in a-stream + request_slots: torch.Tensor, # (B,) int32 + m: int = 4, # compression ratio +) -> tuple[torch.Tensor, torch.Tensor]: + """CSA: compress tail entries into one compressed entry per request. + + Args: + tail_ka, tail_za: a-stream (current block's tokens) + tail_kb, tail_zb: b-stream (previous block's tokens) + tail_len: number of valid entries + request_slots: which request slots to process + m: compression ratio (4 for CSA) + + Returns: + (entry, indexer_key) + entry: (B, head_dim) BF16 — compressed KV entry + indexer_key: (B, indexer_head_dim) BF16 — key for indexer scoring + """ + B = request_slots.shape[0] + head_dim = tail_ka.shape[-1] + + entries = [] + indexer_keys = [] + + for b in range(B): + slot = request_slots[b].item() + valid_len = tail_len[slot].item() + + if valid_len < m: + # Not enough tokens — zero fill + entries.append(torch.zeros(head_dim, dtype=torch.bfloat16, device=tail_ka.device)) + indexer_keys.append(torch.zeros(head_dim, dtype=torch.bfloat16, device=tail_ka.device)) + continue + + # Gather a-stream and b-stream entries + ka = tail_ka[slot, :m].float() # (m, head_dim) + za = tail_za[slot, :m].float() # (m, head_dim) + kb = tail_kb[slot, :m].float() # (m, head_dim) + zb = tail_zb[slot, :m].float() # (m, head_dim) + + # Concatenate a-stream and b-stream + k_cat = torch.cat([ka, kb], dim=0) # (2m, head_dim) + z_cat = torch.cat([za, zb], dim=0) # (2m, head_dim) + + # Token-level softmax: for each head dimension d, + # compute softmax over the 2m tokens + # Z values are the logits for the softmax + # The paper uses learned Z projections; here we treat Z as the + # pre-softmax logits. + # softmax over dim=0 (token dimension) for each head dim + z_max = z_cat.max(dim=0, keepdim=True).values # (1, head_dim) + z_exp = torch.exp(z_cat - z_max) # (2m, head_dim) + z_sum = z_exp.sum(dim=0, keepdim=True) # (1, head_dim) + weights = z_exp / z_sum # (2m, head_dim) — per-token, per-dim weights + + # Weighted sum: entry = sum_j (weights[j] * k_cat[j]) + entry = (weights * k_cat).sum(dim=0) # (head_dim) + entries.append(entry.bfloat16()) + + # Indexer key: same compression but on a different projection. + # For now, use the same entry as the indexer key. + # The real implementation would use a separate Q_indexer projection. + indexer_keys.append(entry.bfloat16()) + + return torch.stack(entries), torch.stack(indexer_keys) + + +def hca_compress_tail( + tail_ka: torch.Tensor, # (max_req, m_prime, head_dim) BF16 + tail_za: torch.Tensor, # (max_req, m_prime, head_dim) BF16 + tail_len: torch.Tensor, # (max_req,) int32 + request_slots: torch.Tensor, # (B,) int32 + m: int = 128, # HCA compression ratio +) -> torch.Tensor: + """HCA: compress tail entries into one compressed entry per request. + + No b-stream, no overlap. Dense attention over the compressed sequence. + + Returns: + entry: (B, head_dim) BF16 — compressed KV entry + """ + B = request_slots.shape[0] + head_dim = tail_ka.shape[-1] + + entries = [] + + for b in range(B): + slot = request_slots[b].item() + valid_len = tail_len[slot].item() + + if valid_len < m: + entries.append(torch.zeros(head_dim, dtype=torch.bfloat16, device=tail_ka.device)) + continue + + ka = tail_ka[slot, :m].float() # (m, head_dim) + za = tail_za[slot, :m].float() # (m, head_dim) + + z_max = za.max(dim=0, keepdim=True).values + z_exp = torch.exp(za - z_max) + z_sum = z_exp.sum(dim=0, keepdim=True) + weights = z_exp / z_sum # (m, head_dim) + + entry = (weights * ka).sum(dim=0) # (head_dim) + entries.append(entry.bfloat16()) + + return torch.stack(entries) diff --git a/dsv4/kernels/indexer/__init__.py b/dsv4/kernels/indexer/__init__.py index 13c9bce8..17f29caa 100644 --- a/dsv4/kernels/indexer/__init__.py +++ b/dsv4/kernels/indexer/__init__.py @@ -2,27 +2,52 @@ Wraps the CUDA indexer score+topk kernel with the interface that AttentionSubBlock expects. + +The indexer (paper §2.3.5, eq. 16) scores each query against +compressed blocks via weighted ReLU MQA logits, then selects +top-k blocks for sparse attention. + +Currently uses scalar FP32 CUDA cores. The FP4 tensor-core path +(Stage F / E7) is a future optimization. """ -from dsv4.kernels.indexer.csa_indexer import CSAIndexer +import torch +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dsv4.cache.handle import LayerCacheHandle def compute_index_scores_topk( - q_indexer: "torch.Tensor", # (T, n_I_h * c_I) BF16 - w_indexer: "torch.Tensor", # (T, n_I_h) BF16 + q_indexer: torch.Tensor, # (T, n_I_h * c_I) BF16 — indexer query + w_indexer: torch.Tensor, # (T, n_I_h) BF16 — per-head weights cache: "LayerCacheHandle", # provides FP4 indexer keys - top_k: int = 512, -) -> "torch.Tensor": # (T, top_k) int64 + top_k: int = 512, # number of blocks to select +) -> torch.Tensor: # (T, top_k) int64 — selected block indices """CSA: score compressed entries and select top-k blocks. Uses the CUDA indexer_score_topk kernel (raw CUDA, FP4 dequant + scalar score + min-heap top-k). Returns block indices for gather_compressed_kv. """ - # TODO: wire the indexer properly. Needs: - # 1. Dequantize q_indexer to FP32 - # 2. Read FP4 keys from cache.read_indexer_view() - # 3. Run score_topk kernel - # 4. Return top-k indices + from dsv4.kernels.indexer.csa_indexer import CSAIndexer + + # Read the indexer view from the cache + indexer_view = cache.read_indexer_view() + + # The CSAIndexer expects: + # - q_up: (T, n_I_h, c_I) BF16 + # - w_head: (T, n_I_h) BF16 + # - keys_fp4, scale, global_scale, block_table, block_lens from indexer_view + + n_I_h = cache.schema.indexer_head_dim # This is actually indexer_num_heads + # Wait — indexer_head_dim is c_I (128), not n_I_h (64) + # Need to check schema more carefully + + # For now, reshape q_indexer from (T, n_I_h * c_I) to (T, n_I_h, c_I) + # n_I_h comes from the config, not the schema + # TODO: add indexer_num_heads to LayerCacheSchema or pass through handle + raise NotImplementedError( - "compute_index_scores_topk requires wiring the CSAIndexer + " - "indexer_score_topk kernel to the cache handle's IndexerView" + "compute_index_scores_topk: needs proper wiring of CSAIndexer to " + "cache handle's IndexerView. The indexer_score_topk kernel runs on B200. " + "The gap is: reshape q_indexer → create CSAIndexer → call run_indexer_score_topk → return indices." ) diff --git a/dsv4/layers/attention.py b/dsv4/layers/attention.py index 9d126df2..adff17ff 100644 --- a/dsv4/layers/attention.py +++ b/dsv4/layers/attention.py @@ -219,7 +219,7 @@ class AttentionSubBlock: q_roped = self._apply_rope(q, positions=cache.positions) # Write raw KV to the SWA window in the cache. No compressor. - cache.write_swa(kv_raw, positions=cache.positions) + cache.write_swa(kv_raw) # Dense FMHA over the sliding window only. from dsv4.kernels.attention import swa_only_fmha