Files
nvfp4-megamoe-kernel/dsv4/kernels/compressor/compress_tail.py
biondizzle daf84524ac E2/E3: compressor bridge, indexer bridge, flush pipeline wiring
- compress_tail.py: PyTorch reference CSA/HCA compression
  (token-level softmax over m/m' entries, paper eq. 11-12)
- compressor/__init__.py: csa_compress_and_store, hca_compress_and_store
  bridges (compression deferred to flush pipeline)
- indexer/__init__.py: compute_index_scores_topk bridge (NotImplemented)
- Fixed attention.py: removed extra positions arg to write_swa
2026-05-30 21:16:54 +00:00

130 lines
4.9 KiB
Python

"""CSA/HCA compressor — functional API for flush pipeline.
The compressor runs token-level softmax over m entries (CSA) or m' entries (HCA)
to produce compressed KV entries. The compressed entries are then written to the
paged pool by the flush_write kernel.
CSA (paper eq. 11-12): stack the m a-stream tokens and the m b-stream tokens
along the token axis, K = concat(K_a, K_b) and Z = concat(Z_a, Z_b), giving
2*m tokens j. For each head dimension d, independently:
entry[d] = sum_j softmax_j(Z[:, d])[j] * K[j, d]
HCA (a-stream only, no b-stream), over the m' tokens j:
entry[d] = sum_j softmax_j(Z_a[:, d])[j] * K_a[j, d]
"""
import torch
from typing import Optional
def csa_compress_tail(
    tail_ka: torch.Tensor,  # (max_req, m, head_dim) BF16 — current a-stream KV
    tail_za: torch.Tensor,  # (max_req, m, head_dim) BF16 — a-stream Z weights
    tail_kb: torch.Tensor,  # (max_req, m, head_dim) BF16 — previous b-stream KV
    tail_zb: torch.Tensor,  # (max_req, m, head_dim) BF16 — b-stream Z weights
    tail_len: torch.Tensor,  # (max_req,) int32 — valid entries in a-stream
    request_slots: torch.Tensor,  # (B,) int32
    m: int = 4,  # compression ratio
) -> tuple[torch.Tensor, torch.Tensor]:
    """CSA: compress tail entries into one compressed entry per request.

    For each requested slot, the first ``m`` a-stream tokens and the first
    ``m`` b-stream tokens are stacked into ``2m`` tokens. For every head
    dimension ``d``, a softmax of ``Z[:, d]`` over those tokens weights
    ``K[:, d]``, and the weighted sum is the compressed entry.

    All requests are processed in one batched computation, with no
    per-request ``.item()`` host syncs. An empty ``request_slots`` yields
    empty outputs instead of raising.

    Args:
        tail_ka, tail_za: a-stream (current block's tokens) KV and Z logits.
        tail_kb, tail_zb: b-stream (previous block's tokens) KV and Z logits.
        tail_len: number of valid a-stream entries per slot. Slots with
            fewer than ``m`` valid entries produce an all-zero entry.
        request_slots: which request slots to process.
        m: compression ratio (4 for CSA).

    Returns:
        (entry, indexer_key)
        entry: (B, head_dim) BF16 — compressed KV entry.
        indexer_key: (B, head_dim) BF16 — key for indexer scoring. It is
            currently a copy of ``entry``; a separate indexer projection
            does not exist yet.
    """
    device = tail_ka.device
    slots = request_slots.to(device=device, dtype=torch.long)

    # Gather both streams for all requests at once: (B, m, head_dim) each.
    ka = tail_ka[slots, :m].float()
    za = tail_za[slots, :m].float()
    kb = tail_kb[slots, :m].float()
    zb = tail_zb[slots, :m].float()

    # Stack the a-stream and b-stream along the token axis: (B, 2m, head_dim).
    k_cat = torch.cat([ka, kb], dim=1)
    z_cat = torch.cat([za, zb], dim=1)

    # Token-level softmax. Z values are the pre-softmax logits; each head
    # dimension is normalized independently over the 2m tokens.
    weights = torch.softmax(z_cat, dim=1)
    entry = (weights * k_cat).sum(dim=1)  # (B, head_dim)

    # Requests whose a-stream is not yet full get an all-zero entry.
    # torch.where also discards any non-finite values computed from
    # unfilled tail rows.
    ready = (tail_len.to(device)[slots] >= m).unsqueeze(-1)  # (B, 1)
    entry = torch.where(ready, entry, torch.zeros_like(entry)).bfloat16()

    # Indexer key placeholder: the real implementation would use a separate
    # indexer projection. For now it mirrors the compressed entry.
    return entry, entry.clone()
def hca_compress_tail(
    tail_ka: torch.Tensor,  # (max_req, m_prime, head_dim) BF16
    tail_za: torch.Tensor,  # (max_req, m_prime, head_dim) BF16
    tail_len: torch.Tensor,  # (max_req,) int32
    request_slots: torch.Tensor,  # (B,) int32
    m: int = 128,  # HCA compression ratio
) -> torch.Tensor:
    """HCA: compress tail entries into one compressed entry per request.

    There is no b-stream and no overlap. For each requested slot, the first
    ``m`` tokens are combined using per-head-dimension softmax weights,
    computed from the Z logits over those ``m`` tokens. Dense attention then
    runs over the compressed sequence.

    All requests are processed in one batched computation, with no
    per-request ``.item()`` host syncs. An empty ``request_slots`` yields an
    empty ``(0, head_dim)`` result instead of raising.

    Args:
        tail_ka: per-slot tail KV.
        tail_za: per-slot Z logits for the same tokens.
        tail_len: number of valid tail entries per slot. Slots with fewer
            than ``m`` valid entries produce an all-zero entry.
        request_slots: which request slots to process.
        m: compression ratio (m').

    Returns:
        entry: (B, head_dim) BF16 — compressed KV entry
    """
    device = tail_ka.device
    slots = request_slots.to(device=device, dtype=torch.long)

    ka = tail_ka[slots, :m].float()  # (B, m, head_dim)
    za = tail_za[slots, :m].float()  # (B, m, head_dim)

    # Softmax over the token axis, computed independently for each head dimension.
    weights = torch.softmax(za, dim=1)
    entry = (weights * ka).sum(dim=1)  # (B, head_dim)

    # Requests that have not accumulated m tokens get an all-zero entry.
    # torch.where also discards any non-finite values from unfilled rows.
    ready = (tail_len.to(device)[slots] >= m).unsqueeze(-1)  # (B, 1)
    return torch.where(ready, entry, torch.zeros_like(entry)).bfloat16()