- indexer/__init__.py: compute_index_scores_topk now calls run_indexer_score_topk with proper tensor reshaping.
- compressor/__init__.py: added the torch import and fixed the csa_compress_tail and hca_compress_tail imports used by flush.py.
- The full flush pipeline is now importable end-to-end.
64 lines
2.4 KiB
Python
64 lines
2.4 KiB
Python
"""CSA indexer — Python API bridge.
|
|
|
|
Wraps the CUDA indexer score+top-k kernel with the interface that
AttentionSubBlock expects.

The indexer (paper §2.3.5, eq. 16) scores each query against the
compressed blocks via weighted ReLU MQA logits, then selects the
top-k blocks for sparse attention.

Scoring currently runs on scalar FP32 CUDA cores after FP4
dequantization; the FP4 tensor-core path (Stage F / E7) is a future
optimization.
"""
|
|
import torch
|
|
from typing import TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
|
|
from dsv4.cache.handle import LayerCacheHandle
|
|
|
|
|
|
def compute_index_scores_topk(
    q_indexer: torch.Tensor,  # (T, n_I_h * c_I) BF16 — indexer query; (T, n_I_h, c_I) also accepted
    w_indexer: torch.Tensor,  # (T, n_I_h) — per-head weights; cast to FP32 for the kernel
    cache: "LayerCacheHandle",  # provides FP4 indexer keys
    top_k: int = 512,  # number of indices to select per query
) -> torch.Tensor:  # (T, top_k) int64 — selected indices
    """CSA: score compressed entries and select the top-k per query.

    Uses the CUDA indexer_score_topk kernel (raw CUDA: FP4 dequant + scalar
    score + min-heap top-k). Returns entry indices for gather_compressed_kv.

    The number of indexer heads ``n_I_h`` is taken from ``w_indexer.shape[1]``
    instead of being hard-coded. The bridge therefore works for any head count,
    as long as the query and weight shapes agree.

    Args:
        q_indexer: Indexer queries, ``(T, n_I_h * c_I)``. A
            ``(T, n_I_h, c_I)`` tensor is flattened to that layout.
        w_indexer: Per-head weights, ``(T, n_I_h)``. Converted to FP32
            (no copy if it already is FP32), because the kernel takes FP32
            ``w_h``.
        cache: Layer cache handle. Supplies the indexer view
            (``read_indexer_view()``), ``schema.indexer_head_dim`` (c_I) and
            ``schema.entries_per_block``.
        top_k: Number of indices to select per query; must be positive.

    Returns:
        ``(T, top_k)`` int64 tensor of selected indices.

    Raises:
        ValueError: If ``top_k`` is not positive, ``w_indexer`` is not 2-D, or
            ``q_indexer``'s shape is inconsistent with ``w_indexer`` and c_I.
    """
    from dsv4.kernels.indexer.score_topk import run_indexer_score_topk

    if top_k <= 0:
        raise ValueError(f"top_k must be positive, got {top_k}")
    if w_indexer.dim() != 2:
        raise ValueError(
            f"w_indexer must be 2-D (T, n_I_h), got shape {tuple(w_indexer.shape)}"
        )

    # c_I: per-head indexer dimension, read from the cache schema.
    c_I = cache.schema.indexer_head_dim
    # n_I_h: number of indexer heads. The schema does not expose it, so derive it
    # from the weight tensor instead of hard-coding 64 (the DSV4 value).
    num_tokens, n_I_h = w_indexer.shape
    flat_dim = n_I_h * c_I

    accepted = ((num_tokens, flat_dim), (num_tokens, n_I_h, c_I))
    if tuple(q_indexer.shape) not in accepted:
        raise ValueError(
            f"q_indexer has shape {tuple(q_indexer.shape)}; expected "
            f"{accepted[0]} or {accepted[1]} to match w_indexer "
            f"{tuple(w_indexer.shape)} and indexer_head_dim={c_I}"
        )

    # The kernel takes q_I as [T, n_I_h * c_I] BF16 and w_h as [T, n_I_h] FP32.
    # reshape is a view for an already-flat query. .contiguous() is a no-op for
    # contiguous tensors and otherwise protects the raw kernel, which presumably
    # assumes densely packed rows (NOTE(review): confirm).
    q_I = q_indexer.reshape(num_tokens, flat_dim).contiguous()
    w_h = w_indexer.float().contiguous()

    # NOTE(review): the schema also has ``indexer_entries_per_block``. Confirm the
    # kernel wants ``entries_per_block`` here rather than that field.
    entries_per_block = cache.schema.entries_per_block

    indices = run_indexer_score_topk(
        q_I=q_I,
        w_h=w_h,
        indexer_view=cache.read_indexer_view(),
        num_heads=n_I_h,
        head_dim=c_I,
        top_k=top_k,
        entries_per_block=entries_per_block,
    )

    # The kernel yields (T, top_k) int32 indices; widen them to int64 for
    # gather_compressed_kv. .to() returns the same tensor if it is already int64.
    return indices.to(torch.int64)
|