Files
nvfp4-megamoe-kernel/dsv4/kernels/compressor/__init__.py
biondizzle 9d88769f5f Wire indexer compute_index_scores_topk + fix compressor imports
- indexer/__init__.py: compute_index_scores_topk now calls
  run_indexer_score_topk with proper tensor reshaping
- compressor/__init__.py: added torch import, fixed csa_compress_tail
  and hca_compress_tail imports for flush.py
- Full flush pipeline now importable end-to-end
2026-05-30 21:19:06 +00:00

57 lines
2.1 KiB
Python

"""CSA/HCA compressor — Python API bridge.
Wraps the compression functions with the interface that
AttentionSubBlock and flush.py expect.
The compressor runs token-level softmax over m entries (CSA) or m' entries (HCA)
to produce compressed KV entries. The compressed entries are then written to the
paged pool by the flush_write kernel.
"""
import torch
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from dsv4.cache.handle import LayerCacheHandle
from dsv4.kernels.compressor.compress_tail import csa_compress_tail, hca_compress_tail
def csa_compress_and_store(
    kv_raw: torch.Tensor,  # (T, head_dim) BF16 — current KV (goes to tail)
    cache: "LayerCacheHandle",  # reads tail, writes compressed to paged pool
) -> None:
    """CSA: hook for compressing KV entries into the classical paged cache.

    This function is currently a deliberate no-op: it neither reads nor
    modifies ``kv_raw`` or ``cache``. It exists so callers have a stable
    entry point for the CSA path.

    The pipeline this hook stands in for is carried out by the flush path
    (see dsv4/cache/flush.py), not here:
        1. Check if tail has enough entries (tail_len >= m=4)
        2. If so, run compression (csa_compress_tail)
        3. Write compressed output to paged pool via flush_write
        4. Update tail buffer (a-stream becomes next b-stream)

    Args:
        kv_raw: Current raw KV entries; unused by this function.
        cache: Per-layer cache handle; unused by this function.

    Returns:
        None.
    """
    # NOTE: This function is called from AttentionSubBlock.forward, which
    # writes the raw KV to the tail buffer first (via cache.write_swa).
    # The actual compression + flush happens when tail_len >= m, and is
    # triggered separately by the flush pipeline (dsv4/cache/flush.py).
    #
    # No kernel module is imported here: nothing in this body launches a
    # kernel, so the hook stays callable even where the CUDA flush_write
    # extension is not available.
    return None
def hca_compress_and_store(
    kv_raw: torch.Tensor,  # (T, head_dim) BF16
    cache: "LayerCacheHandle",  # reads tail, writes compressed to paged pool
) -> None:
    """HCA: compress KV entries and store into the classical paged cache.

    Mirrors the CSA entry point, except that HCA has no b-stream, no
    overlap, and uses m'=128.

    The body intentionally does nothing and leaves both arguments
    untouched; see flush.py.

    Args:
        kv_raw: Raw KV entries; unused by this function.
        cache: Per-layer cache handle; unused by this function.

    Returns:
        None.
    """
    # Nothing happens at this layer — see flush.py.
    return None
# Public names of this package. The compress_tail functions are listed so
# they can be imported from here as well as from their own module.
__all__ = [
    "csa_compress_and_store",
    "hca_compress_and_store",
    "csa_compress_tail",
    "hca_compress_tail",
]