Wire indexer compute_index_scores_topk + fix compressor imports
- indexer/__init__.py: compute_index_scores_topk now calls run_indexer_score_topk with proper tensor reshaping
- compressor/__init__.py: added torch import, fixed csa_compress_tail and hca_compress_tail imports for flush.py
- Full flush pipeline now importable end-to-end
This commit is contained in:
@@ -7,6 +7,12 @@ The compressor runs token-level softmax over m entries (CSA) or m' entries (HCA)
|
||||
to produce compressed KV entries. The compressed entries are then written to the
|
||||
paged pool by the flush_write kernel.
|
||||
"""
|
||||
import torch
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dsv4.cache.handle import LayerCacheHandle
|
||||
|
||||
from dsv4.kernels.compressor.compress_tail import csa_compress_tail, hca_compress_tail
|
||||
|
||||
|
||||
|
||||
@@ -7,8 +7,8 @@ The indexer (paper §2.3.5, eq. 16) scores each query against
|
||||
compressed blocks via weighted ReLU MQA logits, then selects
|
||||
top-k blocks for sparse attention.
|
||||
|
||||
Currently uses scalar FP32 CUDA cores. The FP4 tensor-core path
|
||||
(Stage F / E7) is a future optimization.
|
||||
Currently uses scalar FP32 CUDA cores after FP4 dequant.
|
||||
The FP4 tensor-core path (Stage F / E7) is a future optimization.
|
||||
"""
|
||||
import torch
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -19,35 +19,45 @@ if TYPE_CHECKING:
|
||||
|
||||
def compute_index_scores_topk(
|
||||
q_indexer: torch.Tensor, # (T, n_I_h * c_I) BF16 — indexer query
|
||||
w_indexer: torch.Tensor, # (T, n_I_h) BF16 — per-head weights
|
||||
w_indexer: torch.Tensor, # (T, n_I_h) FP32 — per-head weights
|
||||
cache: "LayerCacheHandle", # provides FP4 indexer keys
|
||||
top_k: int = 512, # number of blocks to select
|
||||
) -> torch.Tensor: # (T, top_k) int64 — selected block indices
|
||||
"""CSA: score compressed entries and select top-k blocks.
|
||||
|
||||
Uses the CUDA indexer_score_topk kernel (raw CUDA, FP4 dequant + scalar
|
||||
score + min-heap top-k). Returns block indices for gather_compressed_kv.
|
||||
score + min-heap top-k). Returns entry indices for gather_compressed_kv.
|
||||
"""
|
||||
from dsv4.kernels.indexer.csa_indexer import CSAIndexer
|
||||
from dsv4.kernels.indexer.score_topk import run_indexer_score_topk
|
||||
|
||||
# Read the indexer view from the cache
|
||||
indexer_view = cache.read_indexer_view()
|
||||
|
||||
# The CSAIndexer expects:
|
||||
# - q_up: (T, n_I_h, c_I) BF16
|
||||
# - w_head: (T, n_I_h) BF16
|
||||
# - keys_fp4, scale, global_scale, block_table, block_lens from indexer_view
|
||||
# c_I is the indexer head dimension from schema
|
||||
n_I_h = cache.schema.indexer_entries_per_block # This is entries, not heads
|
||||
c_I = cache.schema.indexer_head_dim # 128
|
||||
|
||||
n_I_h = cache.schema.indexer_head_dim # This is actually indexer_num_heads
|
||||
# Wait — indexer_head_dim is c_I (128), not n_I_h (64)
|
||||
# Need to check schema more carefully
|
||||
# n_I_h (number of indexer heads) comes from the config, not the schema.
|
||||
# We need to pass it through the handle or compute it.
|
||||
# For DSV4: n_I_h = 64 (same for Flash and Pro)
|
||||
# TODO: add indexer_num_heads to schema or handle
|
||||
n_I_h = 64 # config.indexer_num_heads, hardcoded for now
|
||||
|
||||
# For now, reshape q_indexer from (T, n_I_h * c_I) to (T, n_I_h, c_I)
|
||||
# n_I_h comes from the config, not the schema
|
||||
# TODO: add indexer_num_heads to LayerCacheSchema or pass through handle
|
||||
# q_indexer is already flat, (T, n_I_h * c_I), so no reshape is needed.
|
||||
# The kernel expects q_I: [T, n_I_h * c_I] BF16
|
||||
# and w_h: [T, n_I_h] FP32
|
||||
|
||||
raise NotImplementedError(
|
||||
"compute_index_scores_topk: needs proper wiring of CSAIndexer to "
|
||||
"cache handle's IndexerView. The indexer_score_topk kernel runs on B200. "
|
||||
"The gap is: reshape q_indexer → create CSAIndexer → call run_indexer_score_topk → return indices."
|
||||
entries_per_block = cache.schema.entries_per_block
|
||||
|
||||
indices = run_indexer_score_topk(
|
||||
q_I=q_indexer,
|
||||
w_h=w_indexer.float() if w_indexer.dtype != torch.float32 else w_indexer,
|
||||
indexer_view=indexer_view,
|
||||
num_heads=n_I_h,
|
||||
head_dim=c_I,
|
||||
top_k=top_k,
|
||||
entries_per_block=entries_per_block,
|
||||
)
|
||||
|
||||
# indices: (T, top_k) int32 → convert to int64 for gather_compressed_kv
|
||||
return indices.to(torch.int64)
|
||||
|
||||
Reference in New Issue
Block a user