Files
nvfp4-megamoe-kernel/dsv4/kernels/indexer/__init__.py
biondizzle 9d88769f5f Wire indexer compute_index_scores_topk + fix compressor imports
- indexer/__init__.py: compute_index_scores_topk now calls
  run_indexer_score_topk with proper tensor reshaping
- compressor/__init__.py: added torch import, fixed csa_compress_tail
  and hca_compress_tail imports for flush.py
- Full flush pipeline now importable end-to-end
2026-05-30 21:19:06 +00:00

64 lines
2.4 KiB
Python

"""CSA indexer — Python API bridge.
Wraps the CUDA indexer score+topk kernel with the interface that
AttentionSubBlock expects.
The indexer (paper §2.3.5, eq. 16) scores each query against
compressed blocks via weighted ReLU MQA logits, then selects
top-k blocks for sparse attention.
Currently uses scalar FP32 CUDA cores after FP4 dequant.
The FP4 tensor-core path (Stage F / E7) is a future optimization.
"""
import torch
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from dsv4.cache.handle import LayerCacheHandle
def compute_index_scores_topk(
    q_indexer: torch.Tensor,  # (T, n_I_h * c_I) BF16 — indexer query
    w_indexer: torch.Tensor,  # (T, n_I_h) FP32 — per-head weights
    cache: "LayerCacheHandle",  # provides FP4 indexer keys
    top_k: int = 512,  # number of blocks to select
) -> torch.Tensor:  # (T, top_k) int64 — selected block indices
    """CSA: score compressed entries and select top-k blocks.

    Uses the CUDA indexer_score_topk kernel (raw CUDA, FP4 dequant + scalar
    score + min-heap top-k). Returns entry indices for gather_compressed_kv.

    The number of indexer heads ``n_I_h`` is taken from the trailing
    dimension of ``w_indexer`` rather than hard-coded (DSV4 Flash and Pro
    both use 64). This keeps the value passed to the kernel consistent with
    the tensors it actually receives.

    Args:
        q_indexer: Indexer queries, already flat as ``(T, n_I_h * c_I)``;
            the kernel expects BF16.
        w_indexer: Per-head weights ``(T, n_I_h)``; cast to FP32 if needed.
        cache: Layer cache handle. Supplies the indexer view and the schema
            fields ``indexer_head_dim`` (``c_I``) and ``entries_per_block``.
        top_k: Number of entries to select per token.

    Returns:
        ``(T, top_k)`` int64 tensor of selected entry indices.

    Raises:
        ValueError: If ``w_indexer`` is not 2-D, or if ``q_indexer`` does not
            have shape ``(T, n_I_h * c_I)`` for the ``T`` and ``n_I_h``
            implied by ``w_indexer``.
    """
    # c_I: per-head indexer dimension, as recorded in the cache schema.
    c_I = cache.schema.indexer_head_dim

    # Validate shapes before launching: the kernel receives raw tensors, so a
    # mismatch would otherwise surface as garbage scores or an opaque CUDA
    # error instead of a clear message.
    if w_indexer.dim() != 2:
        raise ValueError(
            f"w_indexer must be 2-D (T, n_I_h), got shape {tuple(w_indexer.shape)}"
        )
    num_tokens, n_I_h = w_indexer.shape
    expected_q_shape = (num_tokens, n_I_h * c_I)
    if tuple(q_indexer.shape) != expected_q_shape:
        raise ValueError(
            f"q_indexer must have shape {expected_q_shape} "
            f"(T, n_I_h * c_I), got {tuple(q_indexer.shape)}"
        )

    # Imported lazily so that importing this package does not require the
    # compiled kernel extension.
    from dsv4.kernels.indexer.score_topk import run_indexer_score_topk

    indexer_view = cache.read_indexer_view()
    entries_per_block = cache.schema.entries_per_block

    indices = run_indexer_score_topk(
        q_I=q_indexer,
        # Tensor.float() returns the tensor itself when it is already float32.
        w_h=w_indexer.float(),
        indexer_view=indexer_view,
        num_heads=n_I_h,
        head_dim=c_I,
        top_k=top_k,
        entries_per_block=entries_per_block,
    )
    # indices: (T, top_k) int32 from the kernel → int64 for gather_compressed_kv.
    return indices.to(torch.int64)