nvfp4-megamoe-kernel/dsv4/kernels/indexer/csa_indexer.py
biondizzle 6e06aed46c Indexer: score+topk kernel, gather KV, compute_valid_lens
gather_kv.cu: Dense tile materialization from paged pool.
  One CTA per (query, topk_entry). Reads FP8+BF16 split via
  block_table resolution, dequantizes FP8->BF16, writes dense output.
  RoPE half: exact match. FP8 round-trip: <0.01 absolute error.
  Output [T, top_k, head_dim] BF16 tile for FMHA consumption.
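
The block_table resolution described above can be sketched on the host as follows. This is a minimal pure-Python illustration, not the CUDA kernel: the name `gather_dense`, the nested-list pool layout `pool[physical_block][slot]`, and the single-sequence `block_table` are assumptions made for the example, and the FP8/BF16 split and dequantization are deliberately left out.

```python
def gather_dense(pool, block_table, topk_idx, entries_per_block):
    """Materialize a dense [T][top_k][head_dim] tile from a paged pool.

    pool[physical_block][slot] -> list of head_dim floats (hypothetical layout)
    block_table[logical_block] -> physical block id
    topk_idx[t][j]             -> logical entry index selected for query t
    """
    out = []
    for row in topk_idx:                      # one row per query token
        tile = []
        for s in row:                         # one logical entry per top-k slot
            phys = block_table[s // entries_per_block]
            slot = s % entries_per_block
            tile.append(list(pool[phys][slot]))   # copy into the dense tile
        out.append(tile)
    return out


# Three physical blocks, two entries per block, head_dim = 2.
pool = [
    [[0.0, 0.1], [1.0, 1.1]],
    [[2.0, 2.1], [3.0, 3.1]],
    [[4.0, 4.1], [5.0, 5.1]],
]
block_table = [2, 0]          # logical block 0 -> physical 2, logical 1 -> physical 0
dense = gather_dense(pool, block_table, [[0, 3], [2, 1]], entries_per_block=2)
```

Each (query, top-k entry) pair in the two loops corresponds to the work one CTA does in the kernel description above.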

indexer_score_topk.cu: Fused score + ReLU + weighted sum + top-k.
  Paper eq.16: I[t,s] = sum_h w_h * relu(q_I . K)
  One CTA per query token, streams FP4 keys from paged pool.
  Per-head dot product (FP32), ReLU, weighted sum, min-heap top-k.
  FP4 dequantization: NVFP4 scheme (16-elem groups, FP8 scale).
  Min-heap with atomicCAS lock for concurrent inserts.
  Selection sort on heap output for deterministic ordering.
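
A small host-side reference for the scoring and selection steps above may help when checking kernel output. It is a sketch under assumptions: keys are already dequantized to floats, the function names are hypothetical, and it runs single-threaded, so the atomicCAS lock has no counterpart here. Ties are broken toward the lower index, which is one possible deterministic rule and not necessarily the kernel's.

```python
import heapq


def indexer_scores(q, w, keys):
    """I[s] = sum_h w[h] * relu(q[h] . keys[s]) for one query token.

    q: [H][D] per-head indexer queries, w: [H] head weights, keys: [S][D].
    """
    scores = []
    for k in keys:
        total = 0.0
        for q_h, w_h in zip(q, w):
            dot = sum(a * b for a, b in zip(q_h, k))
            total += w_h * max(dot, 0.0)      # ReLU, then weighted sum over heads
        scores.append(total)
    return scores


def topk_min_heap(scores, top_k):
    """Keep the top_k largest scores in a min-heap, then order the result."""
    heap = []                                  # entries are (score, -index)
    for idx, score in enumerate(scores):
        item = (score, -idx)                   # -index: ties favor lower indices
        if len(heap) < top_k:
            heapq.heappush(heap, item)
        else:
            heapq.heappushpop(heap, item)      # drops the smallest of heap + item
    ordered = sorted(heap, key=lambda e: (-e[0], -e[1]))
    return [-neg_idx for _, neg_idx in ordered]


q = [[1.0, 0.0], [0.0, 1.0]]
w = [0.5, 2.0]
keys = [[1.0, 1.0], [2.0, -1.0], [-1.0, 3.0], [0.0, 0.0]]
scores = indexer_scores(q, w, keys)
selected = topk_min_heap(scores, top_k=2)
```

The final sort plays the role of the selection sort on the heap output: the heap alone only guarantees membership in the top-k set, not an order.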

  NOTE: Kernel compiles on B200 but crashes at runtime with Xid 13
  (SM exception). Suspected cause: the FP4 dequant memory access
  pattern or a key_scale layout mismatch; this still needs debugging.
  Architecture and algorithm are correct; the fix is a debugging
  exercise, not a redesign.
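
Since the suspected fault is in FP4 dequantization or the scale layout, a host-side reference that makes the indexing explicit can be useful for comparison. The sketch below assumes two E2M1 codes per byte with the even element in the low nibble, one scale per 16 elements, and scales already decoded from FP8 to Python floats. The nibble order and the function names are assumptions for illustration, not the layout the kernel is confirmed to use.

```python
# Magnitudes of the eight E2M1 codes (bit 3 of the nibble is the sign).
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
GROUP = 16  # elements sharing one scale


def decode_e2m1(code):
    """Decode a 4-bit E2M1 code (0..15) to a float."""
    mag = E2M1_MAGNITUDES[code & 0x7]
    return -mag if code & 0x8 else mag


def dequant_fp4_row(packed, scales):
    """Dequantize one packed row: packed holds 2 codes per byte."""
    n = len(packed) * 2
    if n % GROUP != 0:
        raise ValueError("row length must be a multiple of the group size")
    if len(scales) != n // GROUP:
        raise ValueError("expected one scale per 16-element group")
    out = []
    for i in range(n):
        byte = packed[i // 2]
        # Assumption: even elements live in the low nibble, odd in the high.
        code = (byte >> 4) if (i % 2) else (byte & 0xF)
        out.append(decode_e2m1(code) * scales[i // GROUP])
    return out


row = dequant_fp4_row(bytes([0x21, 0xF9] + [0] * 6), scales=[2.0])
```

The explicit length checks mirror the bounds a kernel would need to respect: one byte per two elements and one scale per group.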

compute_valid_lens.py: Integer reduction from block_lens * entries_per_block.
  DSV4 fixed compression ratio means all entries in allocated blocks
  are valid — no partial-block tracking needed.
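
The reduction described above amounts to one multiply per sequence. A minimal sketch, assuming `block_lens` holds the number of allocated blocks for each sequence (the list-based signature is hypothetical):

```python
def compute_valid_lens(block_lens, entries_per_block):
    """Valid entry count per sequence when every allocated block is full."""
    return [n * entries_per_block for n in block_lens]


valid_lens = compute_valid_lens([0, 1, 3], entries_per_block=4)
```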

csa_indexer.py: CSAIndexer class. Owns W_IUQ and W_w (torch.nn.functional.linear
  placeholder until Nvfp4Linear with FP4 output). Calls score_topk kernel
  with cache.read_indexer_view().

score_topk.py: Launcher for the score+topk kernel.
  Upcasts q_I from BF16 to FP32, resolves valid_lens, calls the kernel.

gather KV: TESTED AND PASSING on B200.
indexer score: COMPILES, runtime crash needs debug (FP4 key layout).
2026-05-22 01:20:39 +00:00

72 lines
2.6 KiB
Python
"""CSA indexer — sparse top-k selection from compressed KV cache.

Paper §2.3.1, eq. 13-17:
    c_Q    = h_t · W_DQ     (shared with main queries)
    q^I_t  = c_Q · W_IUQ    (low-rank indexer queries)
    w^I_t  = h_t · W_w      (per-head weights)
    I[t,s] = Σ_h w^I_t,h · ReLU(q^I_t,h · K^IComp[s])
    Selected = TopK(I[t,:])

The indexer only exists in CSA layers. HCA and SWA layers don't have
an indexer (they do dense attention).
"""
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    from dsv4.model.config import DSV4Config
    from dsv4.cache.handle import LayerCacheHandle


class CSAIndexer:
    """Lightning indexer for CSA layers.

    Composed by AttentionSubBlock when layer is CSA. Owns W_IUQ and W_w.
    The shared c_Q comes from the main query path; this class does NOT
    own W_DQ.
    """

    def __init__(self, config: "DSV4Config"):
        self.config = config
        self._runner_id = None

    def __call__(
        self,
        c_Q: torch.Tensor,  # [T, d_c] BF16, shared latent
        h_t: torch.Tensor,  # [T, d] BF16, hidden states
        cache: "LayerCacheHandle",
    ) -> torch.Tensor:
        """Return top-k compressed-block indices per query token.

        Returns [T, csa_top_k] int32 indices into the compressed pool.
        """
        from dsv4.kernels.indexer.score_topk import run_indexer_score_topk

        # Kernel A: indexer query up-projection (c_Q -> q_I)
        # For now, use a simple torch linear; will swap to Nvfp4Linear
        # with FP4 output in Phase 2.
        if not hasattr(self, '_q_up_weight'):
            # Lazy init: weights would be loaded from checkpoint
            d_c = self.config.query_compression_dim
            n_ih = self.config.indexer_num_heads
            c_i = self.config.indexer_head_dim
            self._q_up_weight = torch.randn(
                d_c, n_ih * c_i, dtype=torch.bfloat16, device='cuda') * 0.02
            self._w_head_weight = torch.randn(
                self.config.hidden_size, n_ih, dtype=torch.bfloat16, device='cuda') * 0.02

        q_I = torch.nn.functional.linear(c_Q, self._q_up_weight.T)  # [T, n_ih * c_i] BF16
        w_h = torch.nn.functional.linear(h_t, self._w_head_weight.T).float()  # [T, n_ih] FP32

        view = cache.read_indexer_view()
        return run_indexer_score_topk(
            q_I=q_I,
            w_h=w_h,
            indexer_view=view,
            num_heads=self.config.indexer_num_heads,
            head_dim=self.config.indexer_head_dim,
            top_k=self.config.csa_top_k,
            entries_per_block=cache.paged.schema.entries_per_block,
        )