gather_kv.cu: Dense tile materialization from the paged pool. One CTA per (query, top-k entry). Resolves each entry through block_table, reads the split FP8 + BF16 storage, dequantizes FP8 -> BF16, and writes a dense [T, top_k, head_dim] BF16 output tile for FMHA consumption. RoPE half: exact match. FP8 round-trip: below 0.01 absolute error.

indexer_score_topk.cu: Fused score + ReLU + weighted sum + top-k, implementing paper eq. 16: I[t,s] = sum_h w_h * relu(q_I . K). One CTA per query token streams FP4 keys from the paged pool and computes per-head FP32 dot products, ReLU, the weighted sum, and a min-heap top-k. FP4 dequantization uses the NVFP4 scheme (16-element groups, FP8 scale). An atomicCAS lock guards concurrent heap inserts, and a selection sort over the heap output gives a deterministic ordering.

NOTE: The kernel compiles on B200 but crashes at runtime with Xid 13 (SM exception). Suspected root cause, still to be debugged: the FP4 dequant memory access pattern or a key_scale layout mismatch. The architecture and algorithm are correct; the fix is a debugging exercise, not a redesign.

compute_valid_lens.py: Computes valid lengths with integer arithmetic as block_lens * entries_per_block. DSV4's fixed compression ratio means every entry in an allocated block is valid, so no partial-block tracking is needed.

csa_indexer.py: CSAIndexer class. Owns W_IUQ and W_w (applied via torch.nn.functional.linear, a placeholder until Nvfp4Linear with FP4 output is in place). Calls the score_topk kernel with cache.read_indexer_view().

score_topk.py: Launcher for the score+top-k kernel. Upcasts q_I from BF16 to FP32, resolves valid_lens, and calls the kernel.

Status:
- gather KV: TESTED AND PASSING on B200.
- indexer score: COMPILES; runtime crash needs debugging (FP4 key layout).
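A plain-Python reference model of eq. 16 gives you a ground truth for checking the kernel on small inputs. This is a minimal sketch. It assumes the keys are already dequantized to floats and laid out as [S][n_heads][head_dim]. `reference_score_topk` is a hypothetical helper, not part of the repository. Like the kernel, it keeps a size-k min-heap. The tie-break rule (on equal scores the smaller index wins) is an assumption, not documented kernel behavior.

```python
import heapq
from typing import List, Sequence


def reference_score_topk(
    q: Sequence[Sequence[float]],                # [n_heads][head_dim] one query token
    w: Sequence[float],                          # [n_heads] per-head weights
    keys: Sequence[Sequence[Sequence[float]]],   # [S][n_heads][head_dim] dequantized keys
    valid_len: int,                              # only entries [0, valid_len) are scored
    top_k: int,
) -> List[int]:
    """Eq. 16 for one query token, then top-k via a size-k min-heap."""
    heap: List[tuple] = []  # (score, -index); the root is the weakest kept entry
    for s in range(valid_len):
        score = 0.0
        for h in range(len(w)):
            dot = sum(a * b for a, b in zip(q[h], keys[s][h]))
            score += w[h] * max(dot, 0.0)  # ReLU before the weighted sum
        item = (score, -s)  # ASSUMPTION: on equal scores the smaller index ranks higher
        if len(heap) < top_k:
            heapq.heappush(heap, item)
        elif item > heap[0]:
            heapq.heapreplace(heap, item)  # evict the current weakest entry
    # Deterministic output order: score descending, then index ascending.
    ranked = sorted(heap, reverse=True)
    out = [-neg_idx for _, neg_idx in ranked]
    return out + [-1] * (top_k - len(out))  # pad unfilled slots with -1
```

The min-heap costs O(log k) per scored entry and never holds more than k items, so the final sort only touches k elements. That is the same reason a per-CTA heap followed by a small sort is a reasonable design on the GPU.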
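The crash is suspected to come from the FP4 key layout, so a host-side decoder that spells out every layout decision can serve as a cross-check. Below is a minimal sketch of NVFP4 dequantization: E2M1 4-bit codes, one FP8 E4M3 scale per 16-element group, and a global scale. The helper names are hypothetical. Two points are explicit assumptions that must be checked against how the cache writer packs keys_fp4, scale and global_scale:
- the nibble order within each byte;
- whether global_scale multiplies or is stored as an inverse.

```python
from typing import List, Sequence

# E2M1 magnitudes indexed by the low 3 bits of a 4-bit code; bit 3 is the sign.
_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def e2m1_to_float(code: int) -> float:
    mag = _E2M1[code & 0x7]
    return -mag if code & 0x8 else mag


def e4m3_to_float(byte: int) -> float:
    """Decode one FP8 E4M3 (fn variant: no infinities) scale byte."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:
        return float("nan")
    if exp == 0:  # subnormal: no implicit leading 1, exponent fixed at 1 - bias
        return sign * (mant / 8.0) * 2.0 ** -6
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)  # bias 7


def dequant_nvfp4_row(
    packed: Sequence[int],   # bytes, two FP4 codes per byte
    scales: Sequence[int],   # one E4M3 byte per group of `group_size` elements
    global_scale: float,
    group_size: int = 16,
) -> List[float]:
    out: List[float] = []
    for i, byte in enumerate(packed):
        # ASSUMPTION: element 2*i is the low nibble, element 2*i+1 the high nibble.
        for j, code in enumerate((byte & 0xF, byte >> 4)):
            elem = 2 * i + j
            scale = e4m3_to_float(scales[elem // group_size])
            # ASSUMPTION: global_scale multiplies (some stacks store its inverse).
            out.append(e2m1_to_float(code) * scale * global_scale)
    return out
```

To localize an out-of-bounds access, decode a few rows on the host this way and check the group index arithmetic (`elem // group_size`) against the kernel's indexing of key_scale.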
"""Python launcher for the indexer score+topk kernel.

Provides run_indexer_score_topk(), which takes BF16 indexer queries
(upcast to FP32 before launch) and an IndexerView from the cache, runs
the fused score + ReLU + weighted sum + top-k kernel, and returns
[T, top_k] compressed entry indices.

Phase 1: FP32 dot products (correct, testable baseline).
Phase 2: swap in FP4 tcgen05 MMA (an optimization on the known-correct base).
"""

from __future__ import annotations

import os
from typing import TYPE_CHECKING

import torch
import torch.utils.cpp_extension  # submodule not guaranteed to be loaded by `import torch`

if TYPE_CHECKING:
    from dsv4.cache.handle import IndexerView

# JIT-compiled extension module, built on first use and cached afterwards.
_kernel_module = None


def _get_kernel_module():
    """Build (once) and return the indexer_score_topk CUDA extension."""
    global _kernel_module
    if _kernel_module is not None:
        return _kernel_module
    kernel_dir = os.path.join(os.path.dirname(__file__), "..", "cuda")
    _kernel_module = torch.utils.cpp_extension.load(
        name="indexer_score_topk",
        sources=[os.path.join(kernel_dir, "indexer_score_topk.cu")],
        # Target B200 (sm_100a, the architecture-specific variant of sm_100).
        extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"],
        verbose=False,
    )
    return _kernel_module


def run_indexer_score_topk(
    q_I: torch.Tensor,  # [T, n_heads * head_dim] BF16: indexer queries
    w_h: torch.Tensor,  # [T, n_heads] FP32: per-head weights
    indexer_view: "IndexerView",
    num_heads: int,
    head_dim: int,
    top_k: int,
    entries_per_block: int,
) -> torch.Tensor:
    """Return [T, top_k] int32 indices of the selected compressed entries.

    The kernel computes:

        I[t,s] = Σ_h w_h[t,h] · ReLU(q_I[t,h] · K^IComp[s,h])
        topk_indices = argtopk(I[t,:], k=top_k)

    q_I is passed as BF16 and upcast to FP32 before the launch. The
    indexer keys are stored as FP4 in the cache and dequantized inside
    the kernel. The output is initialized to -1, so any slot the kernel
    does not fill reads -1.
    """
    mod = _get_kernel_module()
    T = q_I.shape[0]

    # Upcast q_I from BF16 to FP32 and reshape to [T, n_heads, head_dim].
    q_I_f32 = q_I.float().reshape(T, num_heads, head_dim).contiguous()

    # Valid compressed-entry count per request. With DSV4's fixed
    # compression ratio, every entry in an allocated block is valid.
    valid_lens = indexer_view.block_lens * entries_per_block  # [B] int32

    # The kernel needs one valid length per query token, but block_lens is
    # [B] (one value per request). In decode each request contributes one
    # token, so T == B and the values already line up with the tokens.
    if valid_lens.shape[0] != T:
        # Prefill (T > B): tokens must be mapped to their owning request.
        # For now, broadcast the first request's valid length to every token,
        # which is only correct when the batch holds a single request.
        # TODO: proper per-token valid_lens from request_ids mapping.
        valid_lens = valid_lens[:1].expand(T).contiguous()

    out = torch.full((T, top_k), -1, dtype=torch.int32, device=q_I.device)

    mod.indexer_score_topk_fp32(
        q_I_f32, w_h,
        indexer_view.keys_fp4, indexer_view.scale, indexer_view.global_scale,
        indexer_view.block_table, valid_lens,
        out,
        num_heads, head_dim, top_k, entries_per_block,
    )
    return out
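The prefill TODO in the launcher (one valid length per query token) amounts to repeating each request's length once for every query token that request contributes. This is a minimal sketch. It assumes a hypothetical `tokens_per_request` list, not part of IndexerView as shown, that gives each request's token count in the same order as block_lens.

```python
from typing import List, Sequence


def per_token_valid_lens(
    block_lens: Sequence[int],          # [B] allocated blocks per request
    tokens_per_request: Sequence[int],  # [B] HYPOTHETICAL: query tokens per request
    entries_per_block: int,
) -> List[int]:
    """Expand per-request valid lengths to one value per query token."""
    if len(block_lens) != len(tokens_per_request):
        raise ValueError("block_lens and tokens_per_request must both have length B")
    out: List[int] = []
    for n_blocks, n_tokens in zip(block_lens, tokens_per_request):
        # Fixed compression ratio: every entry in an allocated block is valid.
        out.extend([n_blocks * entries_per_block] * n_tokens)
    return out
```

On device, the same expansion is `torch.repeat_interleave(valid_lens, repeats)`, where `repeats` is a [B] integer tensor on the same device. This only reproduces the request-level mapping the TODO describes. If prefill tokens must also be masked causally, each token would need its own, smaller length.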