Files
nvfp4-megamoe-kernel/dsv4/_archive/kernels/compressor/compress_tail.py
biondizzle f3b551956d Cleanup Step 2: Archive Lineage P code, fix broken imports
- Move dead dsv4/ modules to dsv4/_archive/ (52 files)
  - model/{dsv4,mtp,layer,layer_schedule}
  - layers/{embedding,attention,ffn,norm} (kept linear,mhc,router,moe,shared_expert,grouped_linear - live)
  - cache/*, kernels/cache/*, kernels/indexer/{csa_indexer,score_topk,compute_valid_lens}
  - kernels/router/{nvfp4_fused_router,dense_router_decode_kernel,dense_router_prefill}
  - ops/{topk,topk_select,rope,router}, loader/{hf_checkpoint,layout_convert}
  - reference/{attention,compressor,csa_attention,moe_pipeline}
  - kernels/compressor/{compress_tail,csa_hca}
- Restore dsv4/ops/{router,custom_ops}.py (needed by live layers)
- Fix dsv4/kernels/{indexer,compressor,attention}/__init__.py (removed broken imports)
- Remove preload_all() from loader.py (dead, referenced nonexistent .cu file)
- Fix loader.py docstring (fused_amax_quantize_nvfp4 → quantize_nvfp4_from_buffer)
- Move broken tests to tests/e2e_archive/
  - test_fused_router, production_values_test, e2e/{one_layer,model_construction,csa_hca}
- vLLM has 0 imports of dsv4 (Step 0 confirmed)
2026-06-02 19:27:07 +00:00

130 lines
4.9 KiB
Python

"""CSA/HCA compressor — functional API for flush pipeline.
The compressor applies a token-level softmax over the tokens of a block
(2*m tokens for CSA, m' tokens for HCA) and pools them into one compressed KV
entry per request. The compressed entries are then written to the paged pool
by the flush_write kernel.

CSA (paper eq. 11-12), for each head dimension d:
    entry[d] = sum_j softmax_j(Z[j, d]) * K[j, d]
    where Z = concat(Z_a, Z_b), K = concat(K_a, K_b), and j ranges over the
    2*m tokens (m from the a-stream + m from the b-stream).

HCA, for each head dimension d:
    entry[d] = sum_j softmax_j(Z_a[j, d]) * K_a[j, d]
    where j ranges over the m' tokens (no b-stream).
"""
import torch
from typing import Optional
def csa_compress_tail(
    tail_ka: torch.Tensor,  # (max_req, m, head_dim) BF16 — current a-stream KV
    tail_za: torch.Tensor,  # (max_req, m, head_dim) BF16 — a-stream Z weights
    tail_kb: torch.Tensor,  # (max_req, m, head_dim) BF16 — previous b-stream KV
    tail_zb: torch.Tensor,  # (max_req, m, head_dim) BF16 — b-stream Z weights
    tail_len: torch.Tensor,  # (max_req,) int32 — valid entries in a-stream
    request_slots: torch.Tensor,  # (B,) int32
    m: int = 4,  # compression ratio
) -> tuple[torch.Tensor, torch.Tensor]:
    """CSA: compress tail entries into one compressed entry per request.

    For every requested slot, the first ``m`` a-stream tokens and the first
    ``m`` b-stream tokens are concatenated along the token axis. The Z values
    are treated as pre-softmax logits: a softmax is taken over the ``2*m``
    tokens independently for each head dimension, and the K values are
    averaged with those weights. The math runs in float32 and the result is
    cast to bfloat16.

    Slots whose ``tail_len`` is smaller than ``m`` do not have a full block
    yet; their output row is all zeros.

    Args:
        tail_ka, tail_za: a-stream (current block's tokens) keys and logits.
        tail_kb, tail_zb: b-stream (previous block's tokens) keys and logits.
        tail_len: number of valid a-stream entries per request slot.
        request_slots: which request slots to process; may be empty.
        m: compression ratio (4 for CSA). Must be >= 1.

    Returns:
        (entry, indexer_key)
        entry: (B, head_dim) BF16 — compressed KV entry.
        indexer_key: (B, head_dim) BF16 — key for indexer scoring. This is
            currently an independent copy of ``entry``; a separate indexer
            projection is not implemented here.

    Raises:
        ValueError: if ``m`` is less than 1.
    """
    if m < 1:
        raise ValueError(f"m must be a positive integer, got {m}")

    device = tail_ka.device
    # Index tensors are moved next to the data so mixed CPU/GPU bookkeeping
    # tensors keep working without per-request .item() host syncs.
    slots = request_slots.to(device=device, dtype=torch.long)  # (B,)
    ready = tail_len.to(device)[slots] >= m  # (B,) — full block available?

    # Gather the first m tokens of each stream and join them on the token axis.
    k_cat = torch.cat(
        [tail_ka[slots, :m].float(), tail_kb[slots, :m].float()], dim=1
    )  # (B, 2m, head_dim)
    z_cat = torch.cat(
        [tail_za[slots, :m].float(), tail_zb[slots, :m].float()], dim=1
    )  # (B, 2m, head_dim)

    # Token-level softmax: normalise over the 2m tokens for each head dim.
    weights = torch.softmax(z_cat, dim=1)  # (B, 2m, head_dim)
    pooled = (weights * k_cat).sum(dim=1)  # (B, head_dim)

    # Zero-fill incomplete slots. masked_fill overwrites (rather than scales),
    # so NaN/Inf garbage in a not-yet-valid tail cannot leak into the output.
    pooled = pooled.masked_fill(~ready.unsqueeze(-1), 0.0)

    entry = pooled.to(torch.bfloat16)
    # The indexer key reuses the compressed entry for now; clone so the two
    # returned tensors never alias each other.
    return entry, entry.clone()
def hca_compress_tail(
    tail_ka: torch.Tensor,  # (max_req, m_prime, head_dim) BF16
    tail_za: torch.Tensor,  # (max_req, m_prime, head_dim) BF16
    tail_len: torch.Tensor,  # (max_req,) int32
    request_slots: torch.Tensor,  # (B,) int32
    m: int = 128,  # HCA compression ratio
) -> torch.Tensor:
    """HCA: compress tail entries into one compressed entry per request.

    No b-stream, no overlap. For every requested slot, the first ``m`` tokens
    are pooled: the Z values are treated as pre-softmax logits, a softmax is
    taken over the ``m`` tokens independently for each head dimension, and the
    K values are averaged with those weights. The math runs in float32 and
    the result is cast to bfloat16.

    Slots whose ``tail_len`` is smaller than ``m`` do not have a full block
    yet; their output row is all zeros.

    Args:
        tail_ka, tail_za: keys and logits of the current block's tokens.
        tail_len: number of valid entries per request slot.
        request_slots: which request slots to process; may be empty.
        m: compression ratio (128 for HCA). Must be >= 1.

    Returns:
        entry: (B, head_dim) BF16 — compressed KV entry.

    Raises:
        ValueError: if ``m`` is less than 1.
    """
    if m < 1:
        raise ValueError(f"m must be a positive integer, got {m}")

    device = tail_ka.device
    # Index tensors are moved next to the data so mixed CPU/GPU bookkeeping
    # tensors keep working without per-request .item() host syncs.
    slots = request_slots.to(device=device, dtype=torch.long)  # (B,)
    ready = tail_len.to(device)[slots] >= m  # (B,) — full block available?

    keys = tail_ka[slots, :m].float()  # (B, m, head_dim)
    logits = tail_za[slots, :m].float()  # (B, m, head_dim)

    # Token-level softmax: normalise over the m tokens for each head dim.
    weights = torch.softmax(logits, dim=1)  # (B, m, head_dim)
    pooled = (weights * keys).sum(dim=1)  # (B, head_dim)

    # Zero-fill incomplete slots. masked_fill overwrites (rather than scales),
    # so NaN/Inf garbage in a not-yet-valid tail cannot leak into the output.
    pooled = pooled.masked_fill(~ready.unsqueeze(-1), 0.0)
    return pooled.to(torch.bfloat16)