nvfp4-megamoe-kernel/dsv4/kernels/indexer/compute_valid_lens.py
biondizzle 6e06aed46c Indexer: score+topk kernel, gather KV, compute_valid_lens
gather_kv.cu: Dense tile materialization from the paged pool.
  One CTA per (query, topk_entry). Resolves physical blocks via block_table,
  reads the split FP8 + BF16 entry, dequantizes FP8->BF16, writes dense output.
  RoPE half: exact match. FP8 round-trip: <0.01 absolute error.
  Output [T, top_k, head_dim] BF16 tile for FMHA consumption.
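  The block_table resolution step can be sketched as a pure-Python reference. This is a minimal sketch, not the kernel: it assumes a hypothetical pool layout of [physical_block][slot][head_dim], top-k indices that are logical compressed-entry positions within one request, and it skips the FP8->BF16 dequantization. The name `gather_dense_tile` is hypothetical.

```python
from typing import List, Sequence


def gather_dense_tile(
    pool: Sequence[Sequence[Sequence[float]]],  # assumed layout: [physical_block][slot][head_dim]
    block_table: Sequence[int],                 # one request's row: logical block -> physical block
    topk_indices: Sequence[Sequence[int]],      # [T][top_k] logical entry indices
    entries_per_block: int,
) -> List[List[List[float]]]:
    """Return a dense [T][top_k][head_dim] tile (reference only, no FP8 dequant)."""
    tile: List[List[List[float]]] = []
    for row in topk_indices:                    # one query token
        out_row: List[List[float]] = []
        for s in row:                           # one (query, topk_entry) pair, i.e. one CTA in the kernel
            logical_block, slot = divmod(s, entries_per_block)
            physical_block = block_table[logical_block]
            out_row.append(list(pool[physical_block][slot]))
        tile.append(out_row)
    return tile
```

  For example, with entries_per_block=2 and block_table=[1, 0], logical entry 3 resolves to slot 1 of physical block 0.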

indexer_score_topk.cu: Fused score + ReLU + weighted sum + top-k.
  Paper eq.16: I[t,s] = sum_h w[t,h] * relu(q_I[t,h] . K[s])
  One CTA per query token, streams FP4 keys from paged pool.
  Per-head dot product (FP32), ReLU, weighted sum, min-heap top-k.
  FP4 dequantization: NVFP4 scheme (16-elem groups, FP8 scale).
  Min-heap with atomicCAS lock for concurrent inserts.
  Selection sort on heap output for deterministic ordering.
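  Eq.16 followed by top-k selection can be written as a sequential pure-Python reference for a single query token. This is a sketch under assumptions: keys are already dequantized and truncated to valid_len, `q` is [H][d] and `w` is [H] for that token, and the function name is hypothetical. Python's `heapq` stands in for the kernel's atomicCAS-locked min-heap (no concurrency here), and `sorted` stands in for the selection sort.

```python
import heapq
from typing import List, Sequence, Tuple


def indexer_score_topk(
    q: Sequence[Sequence[float]],     # [H][d] indexer query heads for one token t
    w: Sequence[float],               # [H] per-head weights for token t
    keys: Sequence[Sequence[float]],  # [valid_len][d] dequantized indexer keys
    top_k: int,
) -> List[Tuple[int, float]]:
    """Return up to top_k (entry index, score) pairs, highest score first."""
    if top_k <= 0:
        return []
    heap: List[Tuple[float, int]] = []  # min-heap on score: root is the weakest kept entry
    for s, k in enumerate(keys):
        score = 0.0
        for h, q_h in enumerate(q):
            dot = sum(a * b for a, b in zip(q_h, k))
            score += w[h] * max(dot, 0.0)  # w[t,h] * relu(q_I[t,h] . K[s])
        if len(heap) < top_k:
            heapq.heappush(heap, (score, s))
        elif score > heap[0][0]:           # strictly better than the weakest kept entry
            heapq.heapreplace(heap, (score, s))
    # Deterministic output: descending score, ties broken by ascending entry index.
    return [(s, score) for score, s in sorted(heap, key=lambda e: (-e[0], e[1]))]
```

  Because a new entry only evicts the root when its score is strictly greater, an earlier entry wins a tie, so the result does not depend on scheduling.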

  NOTE: The kernel compiles on B200 but crashes at runtime with Xid 13
  (SM exception). Suspected cause, not yet confirmed: the FP4 dequant
  memory access pattern or a key_scale layout mismatch. The architecture
  and algorithm are correct; the fix is a debugging exercise, not a redesign.
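  A host-side reference decoder for one packed key row can help pin down the key_scale layout question. This is a minimal sketch of NVFP4-style decoding under stated assumptions: E2M1 codes packed two per byte with the low nibble holding the lower-indexed element, scales already decoded from FP8 to float (one per 16-element group), and any per-tensor global scale omitted. Names are hypothetical; the nibble order in particular must be checked against the real writer.

```python
from typing import List, Sequence

# FP4 E2M1 magnitudes indexed by the low 3 bits of a nibble; bit 3 is the sign.
E2M1_VALUES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
GROUP_SIZE = 16  # one scale per 16 consecutive elements


def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit E2M1 code to a float."""
    magnitude = E2M1_VALUES[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude


def dequant_nvfp4(packed: bytes, scales: Sequence[float]) -> List[float]:
    """Dequantize packed FP4 data (two elements per byte) with per-group scales."""
    n_elems = 2 * len(packed)
    needed_groups = (n_elems + GROUP_SIZE - 1) // GROUP_SIZE
    if len(scales) < needed_groups:
        # A short scale buffer would index out of bounds here; on the GPU the
        # analogous mistake is an out-of-bounds read.
        raise ValueError(f"need {needed_groups} scales, got {len(scales)}")
    out: List[float] = []
    for byte in packed:
        for nibble in (byte & 0x0F, byte >> 4):  # assumed order: low nibble first
            group = len(out) // GROUP_SIZE
            out.append(decode_e2m1(nibble) * scales[group])
    return out
```

  Comparing this output against the kernel's dequantized values for a single key isolates nibble-order and scale-indexing mistakes before the heap logic is involved.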

compute_valid_lens.py: Per-request valid entry count, block_lens * entries_per_block.
  DSV4's fixed compression ratio means all entries in allocated blocks
  are valid, so no partial-block tracking is needed.

csa_indexer.py: CSAIndexer class. Owns W_IUQ and W_w, applied with
  torch.nn.functional.linear as a placeholder until Nvfp4Linear supports
  FP4 output. Calls the score_topk kernel with cache.read_indexer_view().

score_topk.py: Launcher for the score+topk kernel.
  Upcasts q_I from BF16 to FP32, resolves valid_lens, calls the kernel.

gather KV: TESTED AND PASSING on B200.
indexer score: COMPILES, runtime crash needs debug (FP4 key layout).
2026-05-22 01:20:39 +00:00

"""Compute per-request valid compressed entry counts from the block table.

Small integer computation: for each request, valid_len = block_lens * entries_per_block.
Because every allocated block is fully populated under DSV4's fixed compression
ratio, there is no partially filled last block to account for. Used by the indexer
score kernel to know how many candidate keys to stream.
"""
import torch


def compute_valid_lens(
    block_lens: torch.Tensor,  # [B] int32: number of blocks per request
    block_table: torch.Tensor,  # [B, max_logical_blocks] int32 (not read by the current formula)
    entries_per_block: int,
) -> torch.Tensor:
    """Return [B] int32: total valid compressed entries per request.

    For now, a simple formula: valid_entries = block_lens * entries_per_block.
    This assumes all entries in all allocated blocks are valid, which is correct
    because blocks are only allocated when flush writes to them, and each block
    is fully populated before the next is allocated (the compression ratio is fixed).

    In a more general design with partially filled blocks, this would need
    to check the actual write positions. For DSV4's fixed-ratio compression,
    the simple formula is exact.
    """
    # Elementwise multiply; an int32 tensor times a Python int stays int32.
    return block_lens * entries_per_block
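The docstring's "more general design" can be sketched as follows, assuming a hypothetical per-request `last_block_fill` count (entries actually written into the last allocated block). It uses plain Python lists rather than torch for brevity; a tensor version would apply the same arithmetic elementwise. When the last block is full, it reduces to block_lens * entries_per_block.

```python
from typing import List, Sequence


def compute_valid_lens_partial(
    block_lens: Sequence[int],       # allocated blocks per request
    last_block_fill: Sequence[int],  # hypothetical: entries written in each request's last block
    entries_per_block: int,
) -> List[int]:
    """Valid entries per request when the last allocated block may be partially written."""
    valid: List[int] = []
    for n_blocks, fill in zip(block_lens, last_block_fill):
        if n_blocks == 0:
            valid.append(0)  # no blocks allocated, nothing to stream
            continue
        if not 0 < fill <= entries_per_block:
            raise ValueError("last block fill must be in (0, entries_per_block]")
        # All blocks before the last are full; the last contributes only its written entries.
        valid.append((n_blocks - 1) * entries_per_block + fill)
    return valid
```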