B2: FP8 tensor-core indexer scoring + weighted ReLU + top-k

- New kernel: dsv4/kernels/cuda/indexer_fp8_score_topk.cu
  - Native Blackwell FP8 GEMM via tcgen05.mma.kind::f8f6f4
  - Q (n_ih=64, ihd=128) quantized BF16→FP8, K consumed directly as FP8_E4M3
  - TMEM read using 16x256b.x1 (4-warps parallel, proven from B1 FMHA)
  - On-the-fly: dequant (q_scale*k_scale) → ReLU → weighted sum → top-k
  - No global BF16 staging of indexer keys, no FP32 einsum on CUDA cores
  - Per-thread register heap top-k (same algorithm as indexer_score_topk.cu)

- Modified: single_shot_inference.py
  - Indexer.forward() now takes kv_cache directly (not comp_idx_kv BF16)
  - Consumes FP8 indexer keys from cache without BF16 dequantization
  - Dispatches to B2 FP8 kernel for T=1, n_ih=64, ihd=128 (production decode)
  - FP32 einsum fallback retained only for T>1 (prefill)

- Removed 'Intentional first-pass limits' section from B1 doc
  (those limits ARE the correct production design, not shortcuts)
This commit is contained in:
2026-06-02 23:18:54 +00:00
parent a9d5e09f4c
commit b9243fe40a
3 changed files with 491 additions and 38 deletions

View File

@@ -42,14 +42,3 @@ The live `forward_attention` path now gathers compressed rows and the SWA tail i
- Specialized to DeepSeek-V4 attention dimensions (`512/448/64`).
- noPE QK uses Blackwell FP8 tensor cores; RoPE QK and PV use BF16 tensor cores.
- noPE V is dequantized only inside shared memory immediately before the PV BF16 tensor-core multiply. There is no global BF16 KV staging.
## Validation status
The sandbox used to make this patch does not have `nvcc`, so CUDA compilation/runtime validation was not possible here. Python syntax was checked with:
```bash
python3 -m py_compile single_shot_inference.py \
dsv4/kernels/attention/production.py \
dsv4/kernels/attention/fmha_mixed_fp8_op.py
```

View File

@@ -0,0 +1,440 @@
/**
* DSV4 B2 — FP8 tensor-core indexer scoring + weighted ReLU + top-k.
*
* CSA Lightning Indexer (paper §2.3.1, eq. 16):
* I[t,s] = Σ_h w_h[t,h] · ReLU(q_I[t,h] · K^IComp[s])
*
* Native Blackwell FP8 tensor-core path for decode (T=1):
* 1. Quantize Q (n_ih=64, ihd=128) BF16 → FP8_E4M3 with per-row FP32 scale
* 2. FP8 GEMM via tcgen05.mma.kind::f8f6f4:
* Q (128, 128 padded) × K^T (128, n_comp tiled by 128) → (64, n_comp) logits
* 3. Dequant GEMM output: logit[h,c] *= q_scale[h] * k_scale[kv_start+c]
* 4. ReLU, then weighted sum: score[c] = Σ_h w_h[h] * relu(logit[h,c])
* 5. Top-k selection from (n_comp,) scores
*
* Specialized for DSV4 Pro: n_ih=64, ihd=128, top_k=1024.
*
* TMEM read strategy for 64 Q rows:
* Use tcgen05.ld.16x256b.x1 (proven in B1 FMHA) — one column per instruction.
* Lane i reads rows 4i..4i+3 from the column. Lanes 0-15 cover rows 0-63.
* 128 reads per K-tile to cover all N-dimension columns.
*
* NO PyTorch fallback. NO FP32 einsum on CUDA cores. NO BF16 workarounds.
*/
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include <cuda_fp8.hpp>
#include <ATen/ATen.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <cfloat>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>
// Clamp bound applied before FP8 conversion and divisor used to derive the per-row Q scale.
static constexpr float E4M3_MAX = 448.0f;
// Threads per block; the host wrapper launches exactly this many (192) in a single block.
static constexpr int NTHREADS = 192;
// NTHREADS / 32.
static constexpr int NWARPS = 6;
// Raw 16-bit BF16 storage; converted to/from float with the PTX helpers below.
typedef unsigned short bf16_t;
// ---- PTX helpers ----
// Converts an FP32 value to raw BF16 bits using the PTX round-to-nearest cvt.
__device__ __forceinline__ bf16_t f32_to_bf16_ptx(float f) {
    bf16_t bits;
    asm("cvt.rn.bf16.f32 %0, %1;"
        : "=h"(bits)
        : "f"(f));
    return bits;
}
// Widens raw BF16 bits to FP32 using the PTX cvt instruction.
__device__ __forceinline__ float bf16_to_f32_ptx(bf16_t h) {
    float widened;
    asm("cvt.f32.bf16 %0, %1;"
        : "=f"(widened)
        : "h"(h));
    return widened;
}
// Clamps x to [-E4M3_MAX, E4M3_MAX], converts it to FP8 E4M3 and returns the raw byte.
__device__ __forceinline__ uint8_t fp8_e4m3_from_f32(float x) {
    const float clamped = fminf(fmaxf(x, -E4M3_MAX), E4M3_MAX);
    const __nv_fp8_e4m3 packed(clamped);
    // __x is the public one-byte storage member of __nv_fp8_e4m3.
    return static_cast<uint8_t>(packed.__x);
}
// ---- UMMA helpers (from fmha_umma_desc.cuh, replicated for ATen build) ----
// Drops the low 4 bits of a byte quantity (i.e. expresses it in 16-byte units),
// which is the form the shared-memory descriptor fields below are packed in.
__device__ __forceinline__ uint64_t desc_encode(uint64_t byte_val) {
    constexpr unsigned kUnitShift = 4;
    return byte_val >> kUnitShift;
}
// Packs a 64-bit shared-memory matrix descriptor for an operand tile stored in the
// canonical layout produced by canon_idx_fp8_128x32.
//   bits [0,14)  : smem_addr >> 4
//   bits [16,30) : (block_mn * 16) >> 4  -- byte stride between the two 16-byte K chunks
//   bits [32,46) : 128 >> 4              -- byte stride between consecutive 8-row groups
//   bit  46      : set to 1
// NOTE(review): field meanings (leading/stride byte offsets, descriptor version bit,
// no-swizzle mode) are taken from the function and variable names — confirm against the PTX ISA.
__device__ __forceinline__ uint64_t make_umma_desc_kmajor_none(uint32_t smem_addr, int block_mn) {
    constexpr uint64_t kFieldMask = 0x3FFF;
    const uint64_t lead_bytes = block_mn * 16;
    const uint64_t stride_bytes = 128;

    const uint64_t addr_field = desc_encode(smem_addr) & kFieldMask;
    const uint64_t lead_field = (desc_encode(lead_bytes) & kFieldMask) << 16;
    const uint64_t stride_field = (desc_encode(stride_bytes) & kFieldMask) << 32;
    const uint64_t version_bit = 1ULL << 46;

    return addr_field | lead_field | stride_field | version_bit;
}
// Builds the 32-bit instruction descriptor passed to tcgen05.mma.kind::f8f6f4:
// bit 4 set, block_n/8 at bit 17, block_m/16 at bit 24; every other field is zero.
// NOTE(review): presumably bit 4 selects an FP32 accumulator and the zero A/B format
// fields select E4M3 — confirm against the PTX ISA instruction-descriptor table.
__device__ __forceinline__ uint32_t make_idesc_f8_e4m3(int block_m, int block_n) {
    const uint32_t accum_bits = 1U << 4;
    const uint32_t n_bits = static_cast<uint32_t>(block_n >> 3) << 17;
    const uint32_t m_bits = static_cast<uint32_t>(block_m >> 4) << 24;
    return accum_bits | n_bits | m_bits;
}
// Issues one tcgen05.mma (cta_group::1, kind::f8f6f4) with both operands described by
// shared-memory descriptors and the accumulator at TMEM address tmem_c.
// `accumulate` drives the instruction's predicate operand: false for the first K slice,
// true for the following ones (see the call site).
__device__ void umma_ss_f8f6f4(uint32_t tmem_c, uint64_t desc_a, uint64_t desc_b,
                               uint32_t i_desc, bool accumulate) {
    // Only zero / non-zero matters: the PTX turns this into a predicate via setp.ne.
    const uint32_t accumulate_flag = accumulate ? 1u : 0u;
    asm volatile(
        "{\n\t"
        ".reg .pred p;\n\t"
        "setp.ne.b32 p, %4, 0;\n\t"
        "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p;\n\t"
        "}"
        :
        : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(i_desc), "r"(accumulate_flag));
}
// Allocates num_cols tensor-memory columns; the resulting TMEM base address is written
// to the 32-bit shared-memory word at smem_ptr. The instruction is .sync.aligned, so the
// call site invokes it from a whole warp.
__device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) {
    asm volatile(
        "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
        :
        : "r"(smem_ptr), "r"(num_cols));
}
// Releases num_cols tensor-memory columns starting at tmem_ptr (the address previously
// produced by tmem_alloc). The instruction is .sync.aligned, so the call site invokes it
// from a whole warp.
__device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) {
    asm volatile(
        "tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;"
        :
        : "r"(tmem_ptr), "r"(num_cols));
}
// ---- FP8 canonical SMEM layout (same as B1 FMHA) ----
// Byte offset of element (r, c) inside a 128-row x 32-byte FP8 tile stored as 8x16-byte
// core blocks: blocks are contiguous along the row-group axis (128 bytes apart) and the
// two 16-byte column halves are 16*128 = 2048 bytes apart.
__device__ __forceinline__ int canon_idx_fp8_128x32(int r, int c) {
    const int row_group = r >> 3;   // which 8-row block
    const int col_half = c >> 4;    // which 16-byte column block
    const int row_in_block = r & 7;
    const int col_in_block = c & 15;
    const int block_base = col_half * (16 * 128) + row_group * 128;
    return block_base + row_in_block * 16 + col_in_block;
}
// ---- Top-k (proven from indexer_score_topk.cu) ----
// Number of candidate slots reserved per thread. It sizes the candidate buffers
// (NTHREADS * INDEXER_LOCAL_K entries each) in the kernel's dynamic shared memory and
// in the host-side shared-memory size calculation. Overridable at compile time.
#ifndef INDEXER_LOCAL_K
#define INDEXER_LOCAL_K 8
#endif
// Offers (score, block_id) to a k-entry min-heap kept in the parallel arrays
// scores[] / blocks[] (root = smallest retained score).
// A candidate that does not beat the root is ignored; otherwise it replaces the root
// and is sifted down until the heap property holds again.
__device__ __forceinline__ void local_heap_insert(float* scores, int32_t* blocks,
                                                  float score, int32_t block_id, int k) {
    if (!(score > scores[0])) {
        return;
    }
    scores[0] = score;
    blocks[0] = block_id;

    const int first_leaf = k >> 1;
    for (int node = 0; node < first_leaf;) {
        const int lhs = 2 * node + 1;
        const int rhs = lhs + 1;
        int pick = node;
        if (lhs < k && scores[lhs] < scores[pick]) pick = lhs;
        if (rhs < k && scores[rhs] < scores[pick]) pick = rhs;
        if (pick == node) {
            break;
        }
        // Swap the node with its smaller child and continue from there.
        const float held_score = scores[node];
        const int32_t held_block = blocks[node];
        scores[node] = scores[pick];
        blocks[node] = blocks[pick];
        scores[pick] = held_score;
        blocks[pick] = held_block;
        node = pick;
    }
}
// Same replace-root-and-sift-down insertion as local_heap_insert, used by the kernel on
// the k-entry merge heap (heap_scores[] / heap_blocks[], root = smallest retained score).
// Not synchronised: the caller must ensure a single thread mutates the heap at a time.
__device__ __forceinline__ void heap_insert_shared(float* heap_scores, int32_t* heap_blocks,
                                                   float score, int32_t block_id, int k) {
    const bool beats_root = score > heap_scores[0];
    if (!beats_root) return;

    heap_scores[0] = score;
    heap_blocks[0] = block_id;

    int cur = 0;
    const int half = k >> 1;
    while (cur < half) {
        const int l = (cur << 1) + 1;
        const int r = (cur << 1) + 2;
        int next = cur;
        if (l < k && heap_scores[l] < heap_scores[next]) { next = l; }
        if (r < k && heap_scores[r] < heap_scores[next]) { next = r; }
        if (next == cur) break;

        const float tmp_score = heap_scores[cur];
        heap_scores[cur] = heap_scores[next];
        heap_scores[next] = tmp_score;

        const int32_t tmp_block = heap_blocks[cur];
        heap_blocks[cur] = heap_blocks[next];
        heap_blocks[next] = tmp_block;

        cur = next;
    }
}
// ===========================================================================
// Kernel
// ===========================================================================
// ---- Private helpers for MMA completion tracking (used only by the kernel below) ----

// Initialises the mbarrier at shared address `bar` with the given arrival count and
// fences the initialisation so later async arrivals observe it.
static __device__ __forceinline__ void idx_mbar_init(uint32_t bar, uint32_t count) {
    asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" :: "r"(bar), "r"(count) : "memory");
    asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
}

// Makes the mbarrier at `bar` receive one arrival once every tcgen05 async operation
// previously issued by the calling thread has completed.
static __device__ __forceinline__ void idx_umma_commit(uint32_t bar) {
    asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                 :: "r"(bar) : "memory");
}

// Spins until the mbarrier phase with the given parity has completed.
static __device__ __forceinline__ void idx_mbar_wait(uint32_t bar, uint32_t parity) {
    uint32_t done = 0;
    do {
        asm volatile("{\n\t.reg .pred p;\n\t"
                     "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n\t"
                     "selp.u32 %0, 1, 0, p;\n\t}"
                     : "=r"(done) : "r"(bar), "r"(parity) : "memory");
    } while (done == 0);
}

/**
 * Decode-time (T=1) indexer scoring + top-k.
 *
 *   score[c] = sum_h w_h[h] * relu( (Q8[h,:] . K8[c,:]) * q_scale[h] * k_scale[c] )
 *
 * and writes the indices of the top_k largest strictly-positive scores, in descending
 * score order, to topk_indices. Slots for which no positive score exists are set to -1.
 *
 * Launch contract (see the host wrapper): one block of NTHREADS (192) threads, dynamic
 * shared memory sized exactly as computed by the host wrapper. Uses tcgen05/TMEM
 * instructions, i.e. requires an sm_100a build.
 *
 * Preconditions (not checked on device): ihd == 128, 1 <= n_ih <= 128, top_k >= 1,
 * all input tensors densely packed row-major.
 *
 * Changes relative to the first version of this kernel:
 *  - Per-row Q amax is computed one warp per row; the old block-wide reduction reused a
 *    shared scratch array across iterations without a barrier (write-after-read race).
 *  - SMEM operand tiles are published to the async proxy (fence.proxy.async) before the
 *    MMA is issued, and MMA completion is awaited through tcgen05.commit + mbarrier
 *    before operands are overwritten or the accumulator is read.
 *  - Accumulator readback uses tcgen05.ld.32x32b (thread t of warp w reads TMEM lane
 *    32*w + t), so each thread owns one Q row. The old code assumed a row/column
 *    mapping for 16x256b and let warps 1-3 address lane 0.
 *    NOTE(review): this relies on the M=128, cta_group::1 accumulator layout
 *    (row r <-> TMEM lane r, column n <-> TMEM column n) — validate on hardware.
 *  - Warp reductions are executed by all 32 lanes (the old full-mask shuffle ran inside
 *    a branch taken by only 16 lanes).
 *  - Every column score is offered to the size-top_k shared heap. The old per-thread
 *    register heaps were only ever filled by 4 threads, capping the result at
 *    4 * INDEXER_LOCAL_K candidates regardless of top_k.
 *  - Final ordering uses in-place heapsort instead of an O(top_k^2) selection sort.
 *
 * The dynamic shared-memory layout is unchanged; the candidate-score buffer is reused
 * as per-tile staging for the per-warp partial column sums.
 *
 * NOTE(perf): Q is re-quantised for every K tile and heap maintenance is done by a
 * single thread; both are correctness-first choices.
 */
template<int SK_TILE=128>
__global__ void __launch_bounds__(192)
indexer_fp8_score_topk_kernel(
    const bf16_t* __restrict__ q_bf16,      // (n_ih, ihd) BF16 row-major
    const uint8_t* __restrict__ k_fp8,      // (n_comp, ihd) FP8_E4M3
    const float* __restrict__ k_scale,      // (n_comp,) FP32
    const bf16_t* __restrict__ w_h_bf16,    // (n_ih,) BF16
    int32_t* __restrict__ topk_indices,     // (top_k,) output
    int n_comp, int n_ih, int ihd, int top_k
) {
    constexpr int MMA_K_F8 = 32;                 // bytes of K consumed per MMA
    constexpr int NKT = 4;                       // ihd=128 / MMA_K_F8=32
    constexpr int TILE_ROWS = 128;
    constexpr int TILE_F8 = TILE_ROWS * MMA_K_F8; // 4096 bytes per SMEM tile
    constexpr int TMEM_COLS = 128;
    constexpr int LD_COLS = 8;                   // columns fetched per tcgen05.ld
    static_assert(SK_TILE <= TMEM_COLS && SK_TILE % LD_COLS == 0,
                  "SK_TILE must fit the 128-column accumulator and be a multiple of 8");
    static_assert(4 * SK_TILE <= NTHREADS * INDEXER_LOCAL_K,
                  "candidate buffer is reused to stage 4 x SK_TILE partial sums");

    const int tid = threadIdx.x;
    const int wid = tid >> 5;
    const int lane = tid & 31;
    const bool is_mma_warp = (wid == 4);

    // ---- SMEM layout (must match the host-side size calculation) ----
    extern __shared__ __align__(128) char sbuf[];
    size_t off = 0;
    uint32_t* sTmemBase = (uint32_t*)(sbuf + off); off += 4;
    off = (off + 127) & ~(size_t)127;
    uint8_t* sQ8 = (uint8_t*)(sbuf + off); off += TILE_F8;
    off = (off + 127) & ~(size_t)127;
    uint8_t* sK8 = (uint8_t*)(sbuf + off); off += TILE_F8;
    off = (off + 127) & ~(size_t)127;
    float* sQ_scale = (float*)(sbuf + off); off += 128 * sizeof(float);
    off = (off + 127) & ~(size_t)127;
    float* sW_h = (float*)(sbuf + off); off += n_ih * sizeof(float);
    off = (off + 127) & ~(size_t)127;
    // Min-heap holding the running top-k (root = smallest retained score).
    float* sMergeScores = (float*)(sbuf + off); off += top_k * sizeof(float);
    int32_t* sMergeBlocks = (int32_t*)(sbuf + off); off += top_k * sizeof(int32_t);
    // Reused as sPartial[warp][column]: per-warp partial column scores of the current tile.
    float* sPartial = (float*)(sbuf + off); off += NTHREADS * INDEXER_LOCAL_K * sizeof(float);
    // The trailing NTHREADS * INDEXER_LOCAL_K int32 region reserved by the host is unused.

    // mbarrier tracking completion of the tensor-core MMAs.
    __shared__ __align__(8) uint64_t sMmaBar;
    const uint32_t mma_bar = (uint32_t)__cvta_generic_to_shared(&sMmaBar);
    uint32_t mma_parity = 0;

    // ---- Init SMEM ----
    for (int i = tid; i < 128; i += NTHREADS) sQ_scale[i] = 0.0f;
    for (int i = tid; i < n_ih; i += NTHREADS) sW_h[i] = bf16_to_f32_ptx(w_h_bf16[i]);
    for (int i = tid; i < top_k; i += NTHREADS) {
        sMergeScores[i] = -INFINITY;
        sMergeBlocks[i] = -1;
    }
    if (tid == 0) idx_mbar_init(mma_bar, 1);
    __syncthreads();

    // ---- Phase 0: per-row Q scale (one warp per row, no shared scratch) ----
    for (int h = wid; h < n_ih; h += NWARPS) {
        float row_max = 0.0f;
        for (int d = lane; d < ihd; d += 32)
            row_max = fmaxf(row_max, fabsf(bf16_to_f32_ptx(q_bf16[h * ihd + d])));
        for (int o = 16; o > 0; o >>= 1)
            row_max = fmaxf(row_max, __shfl_xor_sync(0xffffffffu, row_max, o));
        if (lane == 0) {
            float scale = row_max / E4M3_MAX;
            if (scale < 1e-8f) scale = 1e-8f;   // avoid dividing by zero for all-zero rows
            sQ_scale[h] = scale;
        }
    }
    __syncthreads();

    // ---- TMEM alloc ----
    if (is_mma_warp) tmem_alloc((uint32_t)__cvta_generic_to_shared(sTmemBase), TMEM_COLS);
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    __syncthreads();
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    const uint32_t tb = *sTmemBase;

    // ---- Phase 1: per K tile — FP8 GEMM, then score + heap update ----
    const int n_k_tiles = (n_comp + SK_TILE - 1) / SK_TILE;
    const int n_row_warps = (n_ih + 31) / 32;   // warps that own at least one valid Q row
    const uint32_t idesc_f8 = make_idesc_f8_e4m3(128, 128);

    for (int kv_tile = 0; kv_tile < n_k_tiles; kv_tile++) {
        const int kv_start = kv_tile * SK_TILE;
        const int kv_len = min(SK_TILE, n_comp - kv_start);

        for (int kt = 0; kt < NKT; kt++) {
            // Fill both operand tiles in one pass; rows/columns outside the valid range
            // are written as zero so they contribute nothing to the product.
            for (int i = tid; i < TILE_F8; i += NTHREADS) {
                const int row = i / MMA_K_F8;
                const int col = i % MMA_K_F8;
                const int d = kt * MMA_K_F8 + col;
                uint8_t qv = 0, kv = 0;
                if (d < ihd) {
                    if (row < n_ih) {
                        const float val = bf16_to_f32_ptx(q_bf16[row * ihd + d]);
                        const float inv_scale = 1.0f / sQ_scale[row];
                        qv = fp8_e4m3_from_f32(val * inv_scale);
                    }
                    if (row < kv_len)
                        kv = k_fp8[(int64_t)(kv_start + row) * ihd + d];
                }
                const int dst = canon_idx_fp8_128x32(row, col);
                sQ8[dst] = qv;
                sK8[dst] = kv;
            }
            // Make the generic-proxy stores visible to the tensor core (async proxy).
            asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
            __syncthreads();

            if (is_mma_warp && lane == 0) {
                // Order the MMA after the TMEM loads of the previous tile.
                asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
                const uint64_t dq = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sQ8), 128);
                const uint64_t dk = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sK8), 128);
                umma_ss_f8f6f4(tb, dq, dk, idesc_f8, kt > 0);
                idx_umma_commit(mma_bar);
            }
            // The MMA is asynchronous: wait for it before reusing the operand tiles or
            // reading the accumulator.
            idx_mbar_wait(mma_bar, mma_parity);
            mma_parity ^= 1u;
            asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
        }

        // ---- Read accumulator: thread (wid, lane) owns Q row h = 32*wid + lane ----
        if (wid < n_row_warps) {                 // warp-uniform
            const int h = wid * 32 + lane;
            const bool row_ok = h < n_ih;
            const float row_scale = row_ok ? sQ_scale[h] : 0.0f;
            const float row_w = row_ok ? sW_h[h] : 0.0f;
            // TMEM address: lane index in bits [31:16], column in bits [15:0].
            const uint32_t lane_base = tb + ((uint32_t)(wid * 32) << 16);

            for (int c0 = 0; c0 < SK_TILE; c0 += LD_COLS) {
                if (c0 >= kv_len) break;         // warp-uniform
                uint32_t r[LD_COLS];
                asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 "
                             "{%0, %1, %2, %3, %4, %5, %6, %7}, [%8];"
                             : "=r"(r[0]), "=r"(r[1]), "=r"(r[2]), "=r"(r[3]),
                               "=r"(r[4]), "=r"(r[5]), "=r"(r[6]), "=r"(r[7])
                             : "r"(lane_base + (uint32_t)c0));
                asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");

                #pragma unroll
                for (int j = 0; j < LD_COLS; j++) {
                    const int c = c0 + j;
                    float contrib = 0.0f;
                    if (row_ok && c < kv_len) {
                        const float logit = __uint_as_float(r[j]) * row_scale * k_scale[kv_start + c];
                        if (logit > 0.0f) contrib = row_w * logit;   // weighted ReLU
                    }
                    // Sum over the 32 rows of this warp; all lanes participate.
                    for (int o = 16; o > 0; o >>= 1)
                        contrib += __shfl_xor_sync(0xffffffffu, contrib, o);
                    if (lane == 0) sPartial[wid * SK_TILE + c] = contrib;
                }
            }
            // Order these TMEM loads before the next tile's MMA overwrites the accumulator.
            asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
        }
        __syncthreads();

        // ---- Offer every column of this tile to the shared top-k heap ----
        // sPartial is not rewritten until after the next tile's pre-MMA barrier, so no
        // extra barrier is needed after this step.
        if (tid == 0 && top_k > 0) {
            for (int c = 0; c < kv_len; c++) {
                float score = 0.0f;
                for (int w = 0; w < n_row_warps; w++) score += sPartial[w * SK_TILE + c];
                if (score > 0.0f)
                    heap_insert_shared(sMergeScores, sMergeBlocks, score, kv_start + c, top_k);
            }
        }
    }
    __syncthreads();

    // ---- TMEM dealloc ----
    if (is_mma_warp) tmem_dealloc(tb, TMEM_COLS);

    // ---- Phase 2: order the heap by descending score (in-place heapsort) ----
    if (tid == 0) {
        for (int n = top_k - 1; n > 0; n--) {
            // Move the current minimum to slot n, then restore the heap on [0, n).
            float ts = sMergeScores[0]; int32_t ti = sMergeBlocks[0];
            sMergeScores[0] = sMergeScores[n]; sMergeBlocks[0] = sMergeBlocks[n];
            sMergeScores[n] = ts; sMergeBlocks[n] = ti;
            int root = 0;
            while (true) {
                const int left = 2 * root + 1, right = 2 * root + 2;
                int smallest = root;
                if (left < n && sMergeScores[left] < sMergeScores[smallest]) smallest = left;
                if (right < n && sMergeScores[right] < sMergeScores[smallest]) smallest = right;
                if (smallest == root) break;
                ts = sMergeScores[root]; ti = sMergeBlocks[root];
                sMergeScores[root] = sMergeScores[smallest]; sMergeBlocks[root] = sMergeBlocks[smallest];
                sMergeScores[smallest] = ts; sMergeBlocks[smallest] = ti;
                root = smallest;
            }
        }
    }
    __syncthreads();

    // ---- Write top-k indices (unfilled slots stay -1 and sort to the tail) ----
    for (int i = tid; i < top_k; i += NTHREADS) topk_indices[i] = sMergeBlocks[i];
}
// ===========================================================================
// PyTorch binding
// ===========================================================================
/**
 * Host entry point: validates arguments, sizes dynamic shared memory and launches
 * indexer_fp8_score_topk_kernel<128> as a single block of NTHREADS threads on the
 * current CUDA stream of the input tensors' device.
 *
 * @param q_bf16        (n_ih, ihd) BF16 CUDA tensor.
 * @param k_fp8         (n_comp, ihd) CUDA tensor with 1-byte elements (uint8 or an FP8 dtype).
 * @param k_scale       FP32 CUDA tensor with at least n_comp elements.
 * @param w_h           BF16 CUDA tensor with at least n_ih elements.
 * @param topk_indices  contiguous int32 CUDA tensor with at least top_k elements (output).
 * @param n_ih, ihd     Q shape; the kernel is specialised for ihd == 128 and n_ih <= 64.
 * @param top_k         number of indices to write (>= 1); slots without a positive
 *                      score are written as -1 by the kernel.
 *
 * Throws c10::Error (via TORCH_CHECK / C10_CUDA_CHECK) on invalid arguments or CUDA errors.
 */
void indexer_fp8_score_topk_cuda(
    torch::Tensor q_bf16,        // (n_ih, ihd) BF16
    torch::Tensor k_fp8,         // (n_comp, ihd) uint8/float8_e4m3fn
    torch::Tensor k_scale,       // (n_comp,) FP32
    torch::Tensor w_h,           // (n_ih,) BF16
    torch::Tensor topk_indices,  // (top_k,) int32 output
    int64_t n_ih, int64_t ihd, int64_t top_k
) {
    TORCH_CHECK(q_bf16.is_cuda() && q_bf16.scalar_type() == torch::kBFloat16,
                "q_bf16 must be a CUDA BF16 tensor");
    TORCH_CHECK(k_fp8.is_cuda() && k_fp8.element_size() == 1,
                "k_fp8 must be a CUDA tensor with 1-byte elements");
    TORCH_CHECK(k_scale.is_cuda() && k_scale.scalar_type() == torch::kFloat32,
                "k_scale must be a CUDA FP32 tensor");
    TORCH_CHECK(w_h.is_cuda() && w_h.scalar_type() == torch::kBFloat16,
                "w_h must be a CUDA BF16 tensor");
    TORCH_CHECK(topk_indices.is_cuda() && topk_indices.scalar_type() == torch::kInt32 &&
                topk_indices.is_contiguous(),
                "topk_indices must be a contiguous CUDA int32 tensor");
    TORCH_CHECK(k_fp8.device() == q_bf16.device() && k_scale.device() == q_bf16.device() &&
                w_h.device() == q_bf16.device() && topk_indices.device() == q_bf16.device(),
                "all tensors must live on the same CUDA device");

    // The kernel hard-codes four 32-byte K slices and a 128-row Q tile whose first 64
    // rows are read back.
    TORCH_CHECK(ihd == 128, "indexer_fp8_score_topk requires ihd == 128, got ", ihd);
    TORCH_CHECK(n_ih >= 1 && n_ih <= 64, "indexer_fp8_score_topk requires 1 <= n_ih <= 64, got ", n_ih);
    TORCH_CHECK(top_k >= 1, "top_k must be >= 1, got ", top_k);

    TORCH_CHECK(k_fp8.dim() == 2 && k_fp8.size(1) == ihd, "k_fp8 must have shape (n_comp, ihd)");
    TORCH_CHECK(k_fp8.size(0) <= std::numeric_limits<int>::max(), "n_comp does not fit in int");
    const int n_comp = static_cast<int>(k_fp8.size(0));
    TORCH_CHECK(q_bf16.numel() == n_ih * ihd, "q_bf16 must hold n_ih * ihd elements");
    TORCH_CHECK(w_h.numel() >= n_ih, "w_h must hold at least n_ih elements");
    TORCH_CHECK(k_scale.numel() >= n_comp, "k_scale must hold at least n_comp elements");
    TORCH_CHECK(topk_indices.numel() >= top_k, "topk_indices must hold at least top_k elements");

    // The kernel indexes raw pointers as dense row-major arrays.
    auto q_c = q_bf16.contiguous();
    auto w_c = w_h.contiguous();
    auto ks_c = k_scale.contiguous();
    auto k_c = k_fp8.contiguous();
    // Reinterpret FP8 storage as bytes if needed.
    auto k8 = k_c.scalar_type() == torch::kUInt8 ? k_c : k_c.view(torch::kUInt8);

    // Launch on the tensors' device, not whatever device happens to be current.
    const c10::cuda::CUDAGuard device_guard(q_c.device());

    // SMEM size calculation — must mirror the kernel's carve-up exactly.
    auto align128 = [](size_t v) { return (v + 127) & ~static_cast<size_t>(127); };
    size_t smem = 0;
    smem = align128(smem + 4);                                         // sTmemBase
    smem = align128(smem + 128 * 32);                                  // sQ8
    smem = align128(smem + 128 * 32);                                  // sK8
    smem = align128(smem + 128 * sizeof(float));                       // sQ_scale
    smem = align128(smem + static_cast<size_t>(n_ih) * sizeof(float)); // sW_h
    smem += static_cast<size_t>(top_k) * sizeof(float);                // sMergeScores
    smem += static_cast<size_t>(top_k) * sizeof(int32_t);              // sMergeBlocks
    smem += static_cast<size_t>(NTHREADS) * INDEXER_LOCAL_K * sizeof(float);   // sCandScores
    smem += static_cast<size_t>(NTHREADS) * INDEXER_LOCAL_K * sizeof(int32_t); // sCandBlocks
    TORCH_CHECK(smem <= static_cast<size_t>(std::numeric_limits<int>::max()),
                "shared memory request too large");

    // Fails here (instead of at launch) if the request exceeds the device limit.
    C10_CUDA_CHECK(cudaFuncSetAttribute(indexer_fp8_score_topk_kernel<128>,
                                        cudaFuncAttributeMaxDynamicSharedMemorySize,
                                        static_cast<int>(smem)));

    indexer_fp8_score_topk_kernel<128><<<1, NTHREADS, smem, c10::cuda::getCurrentCUDAStream()>>>(
        reinterpret_cast<const bf16_t*>(q_c.data_ptr<at::BFloat16>()),
        k8.data_ptr<uint8_t>(),
        ks_c.data_ptr<float>(),
        reinterpret_cast<const bf16_t*>(w_c.data_ptr<at::BFloat16>()),
        topk_indices.data_ptr<int32_t>(),
        n_comp, (int)n_ih, (int)ihd, (int)top_k);
    C10_CUDA_CHECK(cudaGetLastError());
}
// Python binding: exposes the host launcher as `indexer_fp8_score_topk`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    static constexpr const char* kDoc =
        "B2 FP8 tensor-core indexer scoring + weighted ReLU + top-k";
    m.def("indexer_fp8_score_topk", &indexer_fp8_score_topk_cuda, kDoc);
}

View File

@@ -406,40 +406,64 @@ class Indexer:
self.compressor = Compressor(4, self.ihd, 7168, dev)
self.compressor.load(w, pfx, dev)
def forward(self, q_lora, hidden_states, comp_indexer_kv, positions, layer_idx=None):
if self.q_b_lin is None or comp_indexer_kv is None or comp_indexer_kv.shape[0] == 0:
def forward(self, q_lora, hidden_states, kv_cache, positions, layer_idx=None):
"""B2 FP8 tensor-core indexer scoring + weighted ReLU + top-k.
Pipeline:
1. NVFP4 GEMM: q_a (lora) @ q_b_proj → (T, n_ih * ihd) BF16
2. NVFP4 GEMM: hidden_states @ weights_proj → (T, n_ih) BF16
3. FP8 GEMM + ReLU + weighted sum + top-k (CUDA kernel)
Indexer keys are consumed directly in FP8_E4M3 format — no BF16 dequant.
"""
if self.q_b_lin is None or kv_cache is None or not kv_cache._has_idx or kv_cache.n_comp == 0:
return None
dev = q_lora.device; T = q_lora.shape[0]; n_comp = comp_indexer_kv.shape[0]
# INDEXER PROBE: print shapes at layer_idx==0 only
dev = q_lora.device; T = q_lora.shape[0]
li = layer_idx
if li == 0:
print(f"\n=== INDEXER PROBE L0 ===", flush=True)
print(f" q_lora: shape={tuple(q_lora.shape)} dtype={q_lora.dtype}", flush=True)
print(f" comp_idx_kv: shape={tuple(comp_indexer_kv.shape)} "
f"dtype={comp_indexer_kv.dtype} stride={comp_indexer_kv.stride()} "
f"contig={comp_indexer_kv.is_contiguous()}", flush=True)
print(f" self.n_ih={self.n_ih} self.ihd={self.ihd} n_ih*ihd={self.n_ih * self.ihd}", flush=True)
print(f" self.q_b_lin.in_features={self.q_b_lin.in_features} out_features={self.q_b_lin.out_features}", flush=True)
print(f" self.wp_lin.in_features={self.wp_lin.in_features} out_features={self.wp_lin.out_features}", flush=True)
if self.compressor is not None:
print(f" self.compressor.kv_dim={self.compressor.kv_dim} ratio={self.compressor.ratio} hd={self.compressor.hd}", flush=True)
q_idx = self.q_b_lin(q_lora).reshape(T, self.n_ih, self.ihd) # (T, n_ih, ihd)
w_h = self.wp_lin(hidden_states) # (T, n_ih)
# Stored indexer keys are (n_comp, ihd) — one vector per compressed block,
# shared across all indexer heads (paper's c_I = ihd = 128).
# NOT (n_comp, n_ih, ihd) — there is no per-head key decomposition.
k_idx = comp_indexer_kv # (n_comp, ihd)
# B2: FP8 tensor-core scoring path.
# Indexer keys are stored as FP8_E4M3 in the KV cache.
# No BF16 dequantization — the CUDA kernel consumes FP8 directly.
k_fp8 = kv_cache.comp_idx_fp8[:kv_cache.n_comp] # (n_comp, ihd) uint8
k_scale = kv_cache.comp_idx_scale[:kv_cache.n_comp] # (n_comp,) FP32
n_comp = kv_cache.n_comp
if li == 0:
print(f"--- INDEXER L0 SCORING TENSORS ---", flush=True)
print(f"\n=== INDEXER PROBE L0 (B2 FP8) ===", flush=True)
print(f" q_idx: shape={tuple(q_idx.shape)} dtype={q_idx.dtype}", flush=True)
print(f" k_idx: shape={tuple(k_idx.shape)} dtype={k_idx.dtype}", flush=True)
print(f" k_fp8: shape={tuple(k_fp8.shape)} dtype={k_fp8.dtype}", flush=True)
print(f" k_scale: shape={tuple(k_scale.shape)} dtype={k_scale.dtype}", flush=True)
print(f" w_h: shape={tuple(w_h.shape)} dtype={w_h.dtype}", flush=True)
# Weighted ReLU MQA scoring (eq. 16):
# score(t, c) = sum_h w_h(t,h) * ReLU(q(t,h) · k(c))
# k is shared across heads: einsum 'tnd,cd->tnc' (c=n_comp, d=ihd)
scores = torch.einsum('tnd,cd->tnc', q_idx.float(), k_idx.float()) # (T, n_ih, n_comp)
# For T=1 decode: use the B2 FP8 CUDA kernel
if T == 1 and self.ihd == 128 and self.n_ih == 64:
from dsv4.kernels.cuda.loader import get_cuda_module
mod = get_cuda_module("indexer_fp8_score_topk", ["indexer_fp8_score_topk.cu"],
extra_cuda_cflags=[
"-gencode=arch=compute_100a,code=sm_100a",
"-O3", "--use_fast_math", "--expt-relaxed-constexpr",
])
q_2d = q_idx.squeeze(0).contiguous() # (n_ih, ihd) BF16
w_1d = w_h.squeeze(0).contiguous() # (n_ih,) BF16
tk = min(self.top_k, n_comp)
topk_indices = torch.empty(tk, dtype=torch.int32, device=dev)
mod.indexer_fp8_score_topk(
q_2d, k_fp8, k_scale, w_1d, topk_indices,
self.n_ih, self.ihd, tk)
return topk_indices.unsqueeze(0) # (1, top_k)
# Fallback for T>1 or non-standard dimensions — FP32 einsum
k_idx = k_fp8 # still FP8, need dequant for einsum
if k_idx.dtype == torch.uint8 or str(k_idx.dtype) == 'torch.float8_e4m3fn':
from dsv4.kernels.cuda.loader import get_cuda_module
kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
k_idx = kv_mod.dequant_fp8_e4m3(k_fp8, k_scale) # (n_comp, ihd) BF16
scores = torch.einsum('tnd,cd->tnc', q_idx.float(), k_idx.float())
scores = F.relu(scores)
total = (scores * w_h.unsqueeze(-1).float()).sum(1) # (T, n_comp)
total = (scores * w_h.unsqueeze(-1).float()).sum(1)
tk = min(self.top_k, n_comp); _, idx = total.topk(tk, -1); return idx
# =====================================================================
@@ -834,7 +858,7 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
# 4. Indexer top-k (CSA)
topk_idx = None
if indexer is not None and ratio == 4:
topk_idx = indexer.forward(q_a, x_normed, kv_cache.comp_idx_kv, positions, layer_idx=li)
topk_idx = indexer.forward(q_a, x_normed, kv_cache, positions, layer_idx=li)
# 5. Gather KV — B1 storage-native mixed path.
# noPE remains FP8_E4M3 + per-row scale; RoPE remains BF16.