diff --git a/dsv4/kernels/compressor/__init__.py b/dsv4/kernels/compressor/__init__.py index 9c32b267..0c4e776a 100644 --- a/dsv4/kernels/compressor/__init__.py +++ b/dsv4/kernels/compressor/__init__.py @@ -7,6 +7,12 @@ The compressor runs token-level softmax over m entries (CSA) or m' entries (HCA) to produce compressed KV entries. The compressed entries are then written to the paged pool by the flush_write kernel. """ +import torch +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dsv4.cache.handle import LayerCacheHandle + from dsv4.kernels.compressor.compress_tail import csa_compress_tail, hca_compress_tail diff --git a/dsv4/kernels/indexer/__init__.py b/dsv4/kernels/indexer/__init__.py index 17f29caa..303ed0d0 100644 --- a/dsv4/kernels/indexer/__init__.py +++ b/dsv4/kernels/indexer/__init__.py @@ -7,8 +7,8 @@ The indexer (paper §2.3.5, eq. 16) scores each query against compressed blocks via weighted ReLU MQA logits, then selects top-k blocks for sparse attention. -Currently uses scalar FP32 CUDA cores. The FP4 tensor-core path -(Stage F / E7) is a future optimization. +Currently uses scalar FP32 CUDA cores after FP4 dequant. +The FP4 tensor-core path (Stage F / E7) is a future optimization. """ import torch from typing import TYPE_CHECKING @@ -19,35 +19,45 @@ if TYPE_CHECKING: def compute_index_scores_topk( q_indexer: torch.Tensor, # (T, n_I_h * c_I) BF16 — indexer query - w_indexer: torch.Tensor, # (T, n_I_h) BF16 — per-head weights + w_indexer: torch.Tensor, # (T, n_I_h) FP32 — per-head weights cache: "LayerCacheHandle", # provides FP4 indexer keys top_k: int = 512, # number of blocks to select ) -> torch.Tensor: # (T, top_k) int64 — selected block indices """CSA: score compressed entries and select top-k blocks. Uses the CUDA indexer_score_topk kernel (raw CUDA, FP4 dequant + scalar - score + min-heap top-k). Returns block indices for gather_compressed_kv. + score + min-heap top-k). Returns entry indices for gather_compressed_kv. """ - from dsv4.kernels.indexer.csa_indexer import CSAIndexer + from dsv4.kernels.indexer.score_topk import run_indexer_score_topk # Read the indexer view from the cache indexer_view = cache.read_indexer_view() - # The CSAIndexer expects: - # - q_up: (T, n_I_h, c_I) BF16 - # - w_head: (T, n_I_h) BF16 - # - keys_fp4, scale, global_scale, block_table, block_lens from indexer_view + # c_I is the indexer head dimension from schema + n_I_h = cache.schema.indexer_entries_per_block # This is entries, not heads + c_I = cache.schema.indexer_head_dim # 128 - n_I_h = cache.schema.indexer_head_dim # This is actually indexer_num_heads - # Wait — indexer_head_dim is c_I (128), not n_I_h (64) - # Need to check schema more carefully + # n_I_h (number of indexer heads) comes from the config, not the schema. + # We need to pass it through the handle or compute it. + # For DSV4: n_I_h = 64 (same for Flash and Pro) + # TODO: add indexer_num_heads to schema or handle + n_I_h = 64 # config.indexer_num_heads, hardcoded for now - # For now, reshape q_indexer from (T, n_I_h * c_I) to (T, n_I_h, c_I) - # n_I_h comes from the config, not the schema - # TODO: add indexer_num_heads to LayerCacheSchema or pass through handle + # Reshape q_indexer from (T, n_I_h * c_I) to (T, n_I_h * c_I) — already flat + # The kernel expects q_I: [T, n_I_h * c_I] BF16 + # and w_h: [T, n_I_h] FP32 - raise NotImplementedError( - "compute_index_scores_topk: needs proper wiring of CSAIndexer to " - "cache handle's IndexerView. The indexer_score_topk kernel runs on B200. " - "The gap is: reshape q_indexer → create CSAIndexer → call run_indexer_score_topk → return indices." + entries_per_block = cache.schema.entries_per_block + + indices = run_indexer_score_topk( + q_I=q_indexer, + w_h=w_indexer.float() if w_indexer.dtype != torch.float32 else w_indexer, + indexer_view=indexer_view, + num_heads=n_I_h, + head_dim=c_I, + top_k=top_k, + entries_per_block=entries_per_block, ) + + # indices: (T, top_k) int32 → convert to int64 for gather_compressed_kv + return indices.to(torch.int64)