Wire indexer compute_index_scores_topk + fix compressor imports

- indexer/__init__.py: compute_index_scores_topk now calls
  run_indexer_score_topk with proper tensor reshaping
- compressor/__init__.py: added torch import, fixed csa_compress_tail
  and hca_compress_tail imports for flush.py
- Full flush pipeline now importable end-to-end
This commit is contained in:
2026-05-30 21:19:06 +00:00
parent daf84524ac
commit 9d88769f5f
2 changed files with 35 additions and 19 deletions

View File

@@ -7,6 +7,12 @@ The compressor runs token-level softmax over m entries (CSA) or m' entries (HCA)
to produce compressed KV entries. The compressed entries are then written to the
paged pool by the flush_write kernel.
"""
import torch
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from dsv4.cache.handle import LayerCacheHandle
from dsv4.kernels.compressor.compress_tail import csa_compress_tail, hca_compress_tail

View File

@@ -7,8 +7,8 @@ The indexer (paper §2.3.5, eq. 16) scores each query against
compressed blocks via weighted ReLU MQA logits, then selects
top-k blocks for sparse attention.
Currently uses scalar FP32 CUDA cores. The FP4 tensor-core path
(Stage F / E7) is a future optimization.
Currently uses scalar FP32 CUDA cores after FP4 dequant.
The FP4 tensor-core path (Stage F / E7) is a future optimization.
"""
import torch
from typing import TYPE_CHECKING
@@ -19,35 +19,45 @@ if TYPE_CHECKING:
def compute_index_scores_topk(
q_indexer: torch.Tensor, # (T, n_I_h * c_I) BF16 — indexer query
w_indexer: torch.Tensor, # (T, n_I_h) BF16 — per-head weights
w_indexer: torch.Tensor, # (T, n_I_h) FP32 — per-head weights
cache: "LayerCacheHandle", # provides FP4 indexer keys
top_k: int = 512, # number of blocks to select
) -> torch.Tensor: # (T, top_k) int64 — selected block indices
"""CSA: score compressed entries and select top-k blocks.
Uses the CUDA indexer_score_topk kernel (raw CUDA, FP4 dequant + scalar
score + min-heap top-k). Returns block indices for gather_compressed_kv.
score + min-heap top-k). Returns entry indices for gather_compressed_kv.
"""
from dsv4.kernels.indexer.csa_indexer import CSAIndexer
from dsv4.kernels.indexer.score_topk import run_indexer_score_topk
# Read the indexer view from the cache
indexer_view = cache.read_indexer_view()
# The CSAIndexer expects:
# - q_up: (T, n_I_h, c_I) BF16
# - w_head: (T, n_I_h) BF16
# - keys_fp4, scale, global_scale, block_table, block_lens from indexer_view
# c_I is the indexer head dimension from schema
n_I_h = cache.schema.indexer_entries_per_block # This is entries, not heads
c_I = cache.schema.indexer_head_dim # 128
n_I_h = cache.schema.indexer_head_dim # This is actually indexer_num_heads
# Wait — indexer_head_dim is c_I (128), not n_I_h (64)
# Need to check schema more carefully
# n_I_h (number of indexer heads) comes from the config, not the schema.
# We need to pass it through the handle or compute it.
# For DSV4: n_I_h = 64 (same for Flash and Pro)
# TODO: add indexer_num_heads to schema or handle
n_I_h = 64 # config.indexer_num_heads, hardcoded for now
# For now, reshape q_indexer from (T, n_I_h * c_I) to (T, n_I_h, c_I)
# n_I_h comes from the config, not the schema
# TODO: add indexer_num_heads to LayerCacheSchema or pass through handle
# Reshape q_indexer from (T, n_I_h * c_I) to (T, n_I_h * c_I) — already flat
# The kernel expects q_I: [T, n_I_h * c_I] BF16
# and w_h: [T, n_I_h] FP32
raise NotImplementedError(
"compute_index_scores_topk: needs proper wiring of CSAIndexer to "
"cache handle's IndexerView. The indexer_score_topk kernel runs on B200. "
"The gap is: reshape q_indexer → create CSAIndexer → call run_indexer_score_topk → return indices."
entries_per_block = cache.schema.entries_per_block
indices = run_indexer_score_topk(
q_I=q_indexer,
w_h=w_indexer.float() if w_indexer.dtype != torch.float32 else w_indexer,
indexer_view=indexer_view,
num_heads=n_I_h,
head_dim=c_I,
top_k=top_k,
entries_per_block=entries_per_block,
)
# indices: (T, top_k) int32 → convert to int64 for gather_compressed_kv
return indices.to(torch.int64)