- indexer/__init__.py: `compute_index_scores_topk` now calls `run_indexer_score_topk` with correct tensor reshaping.
- compressor/__init__.py: added the `torch` import and fixed the `csa_compress_tail` / `hca_compress_tail` imports required by flush.py.
- The full flush pipeline is now importable end-to-end.
57 lines
2.1 KiB
Python
57 lines
2.1 KiB
Python
"""CSA/HCA compressor — Python API bridge.
|
|
|
|
Wraps the compression functions with the interface that
|
|
AttentionSubBlock and flush.py expect.
|
|
|
|
The compressor runs token-level softmax over m entries (CSA) or m' entries (HCA)
|
|
to produce compressed KV entries. The compressed entries are then written to the
|
|
paged pool by the flush_write kernel.
|
|
"""
|
|
import torch
|
|
from typing import TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
|
|
from dsv4.cache.handle import LayerCacheHandle
|
|
|
|
from dsv4.kernels.compressor.compress_tail import csa_compress_tail, hca_compress_tail
|
|
|
|
|
|
def csa_compress_and_store(
    kv_raw: torch.Tensor,  # (T, head_dim) BF16 — current KV (goes to tail)
    cache: "LayerCacheHandle",  # per-layer cache handle (not accessed here)
) -> None:
    """CSA hook called from ``AttentionSubBlock.forward``; currently a no-op.

    The CSA compress-and-store pipeline is:

        1. Check whether the tail has enough entries (tail_len >= m=4).
        2. If so, run compression (``csa_compress_tail``).
        3. Write the compressed output to the paged pool via flush_write.
        4. Update the tail buffer (a-stream becomes the next b-stream).

    None of these steps run in this function. The caller writes the raw KV
    into the tail buffer itself (via ``cache.write_swa``). Compression and
    flush are triggered separately by the flush pipeline once
    ``tail_len >= m``; see dsv4/cache/flush.py for that orchestration.

    Args:
        kv_raw: Current raw KV entries. Not used by this function.
        cache: Layer cache handle. Not used by this function.

    Returns:
        None.
    """
    # A previously present local import of the flush_write CUDA kernel was
    # never used here; importing it on every call only risked an ImportError
    # on hosts without the compiled extension, so it has been removed.
    return None
|
|
|
|
|
|
def hca_compress_and_store(
    kv_raw: torch.Tensor,  # (T, head_dim) BF16 — not accessed here
    cache: "LayerCacheHandle",  # per-layer cache handle — not accessed here
) -> None:
    """HCA counterpart of ``csa_compress_and_store``; performs no work itself.

    HCA compression follows the same structure as CSA, but with no b-stream,
    no overlap, and m'=128 entries per compressed entry. The compression and
    the write into the paged pool are carried out by the flush pipeline in
    flush.py, so this function returns immediately.

    Returns:
        None.
    """
    return None
|
|
|
|
|
|
# Public names of this package: the two compress-and-store entry points,
# plus the compress_tail kernels re-exported so flush.py can import them
# from here.
__all__ = [
    "csa_compress_and_store",
    "hca_compress_and_store",
    "csa_compress_tail",
    "hca_compress_tail",
]
|