gather_kv.cu: Dense tile materialization from the paged pool.
  One CTA per (query, topk_entry). Reads the FP8+BF16 split via
  block_table resolution, dequantizes FP8->BF16, and writes dense output.
  RoPE half: exact match. FP8 round-trip: <0.01 absolute error.
  Output: [T, top_k, head_dim] BF16 tile for FMHA consumption.

indexer_score_topk.cu: Fused score + ReLU + weighted sum + top-k.
  Paper eq. 16: I[t,s] = sum_h w_h * relu(q_I . K)
  One CTA per query token; streams FP4 keys from the paged pool.
  Per-head dot product (FP32), ReLU, weighted sum, min-heap top-k.
  FP4 dequantization: NVFP4 scheme (16-elem groups, FP8 scale).
  Min-heap with an atomicCAS lock for concurrent inserts.
  Selection sort on the heap output for deterministic ordering.
  NOTE: The kernel compiles on B200 but crashes at runtime with Xid 13
  (SM exception). Suspected root cause: the FP4 dequant memory access
  pattern or a key_scale layout mismatch; this still needs debugging.
  Architecture and algorithm are correct; the fix is a debugging
  exercise, not a redesign.

compute_valid_lens.py: Integer reduction from block_lens * entries_per_block.
  The fixed DSV4 compression ratio means all entries in allocated blocks
  are valid, so no partial-block tracking is needed.

csa_indexer.py: CSAIndexer class. Owns W_IUQ and W_w
  (torch.nn.functional.linear placeholder until Nvfp4Linear with FP4
  output is available). Calls the score_topk kernel with
  cache.read_indexer_view().

score_topk.py: Launcher for the score+topk kernel. Upcasts q_I from
  BF16->FP32, resolves valid_lens, and calls the kernel.

Status:
  gather KV: TESTED AND PASSING on B200.
  indexer score: COMPILES; runtime crash needs debugging (FP4 key layout).
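Since the runtime crash is suspected to come from the FP4 key layout, a host-side reference for the dequantization can help when comparing against kernel output. The following is a minimal sketch, not the kernel's actual code. It assumes the standard E2M1 4-bit value table, two codes packed per byte with the low nibble first (an assumption; the real layout is not shown here), and one scale per 16-element group that has already been decoded from FP8 to a Python float. Any further scaling not mentioned in the commit message is omitted. The names `decode_e2m1` and `dequant_fp4_row` are hypothetical.

```python
# E2M1 magnitudes indexed by the low 3 bits of a 4-bit code; bit 3 is the sign.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
GROUP_SIZE = 16  # number of elements that share one scale


def decode_e2m1(code):
    """Decode a single 4-bit E2M1 code (0..15) to a float."""
    magnitude = E2M1_MAGNITUDES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude


def dequant_fp4_row(packed, scales):
    """Dequantize one key row.

    packed: bytes holding two 4-bit codes per byte.
            ASSUMPTION: the low nibble is the even-indexed element.
    scales: one already-decoded float scale per 16-element group.
    """
    n_elems = len(packed) * 2
    if n_elems != len(scales) * GROUP_SIZE:
        raise ValueError("scale count does not match packed length")
    out = []
    for i in range(n_elems):
        byte = packed[i // 2]
        code = (byte >> 4) if i % 2 else (byte & 0xF)
        out.append(decode_e2m1(code) * scales[i // GROUP_SIZE])
    return out
```

If the kernel's nibble order or scale indexing differs from this sketch, the mismatch shows up as swapped neighbours or values scaled by the wrong group, which is one way to narrow down a layout bug.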
72 lines
2.6 KiB
Python
"""CSA indexer — sparse top-k selection from compressed KV cache.

Paper §2.3.1, eq. 13–17:
    c_Q      = h_t · W_DQ       (shared with main queries)
    q^I_t    = c_Q · W_IUQ      (low-rank indexer queries)
    w^I_t    = h_t · W_w        (per-head weights)
    I[t,s]   = Σ_h w^I_t,h · ReLU(q^I_t,h · K^IComp[s])
    Selected = TopK(I[t,:])

The indexer only exists in CSA layers. HCA and SWA layers don't have
an indexer (they do dense attention).
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    from dsv4.model.config import DSV4Config
    from dsv4.cache.handle import LayerCacheHandle


class CSAIndexer:
    """Lightning indexer for CSA layers.

    Composed by AttentionSubBlock when the layer is CSA. Owns W_IUQ and W_w.
    The shared c_Q comes from the main query path; this class does NOT
    own W_DQ.
    """

    def __init__(self, config: "DSV4Config"):
        self.config = config
        self._runner_id = None

    def __call__(
        self,
        c_Q: torch.Tensor,  # [T, d_c] BF16, shared latent
        h_t: torch.Tensor,  # [T, d] BF16, hidden states
        cache: "LayerCacheHandle",
    ) -> torch.Tensor:
        """Return top-k compressed-block indices per query token.

        Returns [T, csa_top_k] int32 indices into the compressed pool.
        """
        # Deferred import of the kernel launcher.
        from dsv4.kernels.indexer.score_topk import run_indexer_score_topk

        # Kernel A: indexer query up-projection (c_Q -> q_I).
        # For now, use a plain torch linear; will swap to Nvfp4Linear
        # with FP4 output in Phase 2.
        if not hasattr(self, "_q_up_weight"):
            # Lazy init with random placeholders; real weights would be
            # loaded from the checkpoint.
            d_c = self.config.query_compression_dim
            n_ih = self.config.indexer_num_heads
            c_i = self.config.indexer_head_dim
            # Allocate on the inputs' device instead of a hard-coded 'cuda'.
            self._q_up_weight = torch.randn(
                d_c, n_ih * c_i, dtype=torch.bfloat16, device=c_Q.device) * 0.02
            self._w_head_weight = torch.randn(
                self.config.hidden_size, n_ih, dtype=torch.bfloat16, device=h_t.device) * 0.02

        # F.linear computes x @ W.T, so passing W.T yields x @ W.
        q_I = torch.nn.functional.linear(c_Q, self._q_up_weight.T)  # [T, n_ih * c_i] BF16
        w_h = torch.nn.functional.linear(h_t, self._w_head_weight.T).float()  # [T, n_ih] FP32

        view = cache.read_indexer_view()
        return run_indexer_score_topk(
            q_I=q_I,
            w_h=w_h,
            indexer_view=view,
            num_heads=self.config.indexer_num_heads,
            head_dim=self.config.indexer_head_dim,
            top_k=self.config.csa_top_k,
            entries_per_block=cache.paged.schema.entries_per_block,
        )
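For checking the output of `run_indexer_score_topk` while the kernel is being debugged, a slow host-side reference of the scoring rule I[t,s] = Σ_h w_h · ReLU(q_h · K[s]) followed by top-k can be useful. The sketch below is an illustration under stated assumptions, not the kernel's implementation: it handles a single query token, takes keys that are already dequantized to floats, receives `valid_len` as a plain integer, and breaks score ties by preferring the lower index (the kernel's tie-break rule is not documented here). The function name `reference_score_topk` is hypothetical.

```python
import heapq


def reference_score_topk(q, w, keys, top_k, valid_len):
    """Slow reference of score + ReLU + weighted sum + top-k for one query.

    q:         per-head query vectors, shape [n_heads][head_dim]
    w:         per-head weights, shape [n_heads]
    keys:      already-dequantized key vectors, shape [n_keys][head_dim]
    valid_len: number of leading keys that are valid (ASSUMED precomputed)
    """
    heap = []  # min-heap of (score, -index); the root is the weakest kept entry
    for s in range(min(valid_len, len(keys))):
        score = 0.0
        for h in range(len(q)):
            dot = sum(a * b for a, b in zip(q[h], keys[s]))
            score += w[h] * max(dot, 0.0)  # ReLU, then per-head weighting
        item = (score, -s)
        if len(heap) < top_k:
            heapq.heappush(heap, item)
        elif item > heap[0]:
            heapq.heapreplace(heap, item)
    # Deterministic ordering: score descending, ties broken by lower index.
    ordered = sorted(heap, key=lambda it: (-it[0], -it[1]))
    return [-neg_s for _, neg_s in ordered]
```

Storing `-index` alongside the score lets the heap comparison and the final sort share one tie-break rule, so repeated runs return the same order regardless of how equal scores are encountered.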