"""CSA/HCA compressor — Python API bridge. Wraps the compression functions with the interface that AttentionSubBlock and flush.py expect. The compressor runs token-level softmax over m entries (CSA) or m' entries (HCA) to produce compressed KV entries. The compressed entries are then written to the paged pool by the flush_write kernel. """ import torch from typing import TYPE_CHECKING if TYPE_CHECKING: from dsv4.cache.handle import LayerCacheHandle from dsv4.kernels.compressor.compress_tail import csa_compress_tail, hca_compress_tail def csa_compress_and_store( kv_raw: torch.Tensor, # (T, head_dim) BF16 — current KV (goes to tail) cache: "LayerCacheHandle", # reads tail, writes compressed to paged pool ) -> None: """CSA: compress KV entries and store into the classical paged cache. Steps: 1. Check if tail has enough entries (tail_len >= m=4) 2. If so, run compression (csa_compress_tail) 3. Write compressed output to paged pool via flush_write 4. Update tail buffer (a-stream becomes next b-stream) """ from dsv4.kernels.cuda.flush_write import flush_write_csa_cuda # NOTE: This function is called from AttentionSubBlock.forward, which # writes the raw KV to the tail buffer first (via cache.write_swa). # The actual compression + flush happens when tail_len >= m. # For now, the write_swa call handles the tail buffer write. # The flush is triggered separately by the flush pipeline. # See dsv4/cache/flush.py for the flush orchestration. pass # Compression is handled by flush.py, not directly here def hca_compress_and_store( kv_raw: torch.Tensor, # (T, head_dim) BF16 cache: "LayerCacheHandle", # reads tail, writes compressed to paged pool ) -> None: """HCA: compress KV entries and store into the classical paged cache. Same structure as CSA but no b-stream, no overlap, m'=128. """ pass # See flush.py # Make compress_tail functions importable from this package __all__ = [ 'csa_compress_and_store', 'hca_compress_and_store', 'csa_compress_tail', 'hca_compress_tail', ]