B2: FP8 tensor-core indexer scoring + weighted ReLU + top-k
- New kernel: dsv4/kernels/cuda/indexer_fp8_score_topk.cu
  - Native Blackwell FP8 GEMM via tcgen05.mma.kind::f8f6f4
  - Q (n_ih=64, ihd=128) quantized BF16→FP8, K consumed directly as FP8_E4M3
  - TMEM read using 16x256b.x1 (4-warps parallel, proven from B1 FMHA)
  - On-the-fly: dequant (q_scale*k_scale) → ReLU → weighted sum → top-k
  - No global BF16 staging of indexer keys, no FP32 einsum on CUDA cores
  - Per-thread register heap top-k (same algorithm as indexer_score_topk.cu)
- Modified: single_shot_inference.py
  - Indexer.forward() now takes kv_cache directly (not comp_idx_kv BF16)
  - Consumes FP8 indexer keys from cache without BF16 dequantization
  - Dispatches to B2 FP8 kernel for T=1, n_ih=64, ihd=128 (production decode)
  - FP32 einsum fallback retained only for T>1 (prefill)
- Removed 'Intentional first-pass limits' section from B1 doc (those limits ARE the correct production design, not shortcuts)
This commit is contained in:
@@ -42,14 +42,3 @@ The live `forward_attention` path now gathers compressed rows and the SWA tail i
|
||||
- Specialized to DeepSeek-V4 attention dimensions (`512/448/64`).
|
||||
- noPE QK uses Blackwell FP8 tensor cores; RoPE QK and PV use BF16 tensor cores.
|
||||
- noPE V is dequantized only inside shared memory immediately before the PV BF16 tensor-core multiply. There is no global BF16 KV staging.
|
||||
|
||||
## Validation status
|
||||
|
||||
The sandbox used to make this patch does not have `nvcc`, so CUDA compilation/runtime validation was not possible here. Python syntax was checked with:
|
||||
|
||||
```bash
|
||||
python3 -m py_compile single_shot_inference.py \
|
||||
dsv4/kernels/attention/production.py \
|
||||
dsv4/kernels/attention/fmha_mixed_fp8_op.py
|
||||
```
|
||||
|
||||
|
||||
440
dsv4/kernels/cuda/indexer_fp8_score_topk.cu
Normal file
440
dsv4/kernels/cuda/indexer_fp8_score_topk.cu
Normal file
@@ -0,0 +1,440 @@
|
||||
/**
|
||||
* DSV4 B2 — FP8 tensor-core indexer scoring + weighted ReLU + top-k.
|
||||
*
|
||||
* CSA Lightning Indexer (paper §2.3.1, eq. 16):
|
||||
* I[t,s] = Σ_h w_h[t,h] · ReLU(q_I[t,h] · K^IComp[s])
|
||||
*
|
||||
* Native Blackwell FP8 tensor-core path for decode (T=1):
|
||||
* 1. Quantize Q (n_ih=64, ihd=128) BF16 → FP8_E4M3 with per-row FP32 scale
|
||||
* 2. FP8 GEMM via tcgen05.mma.kind::f8f6f4:
|
||||
* Q (128, 128 padded) × K^T (128, n_comp tiled by 128) → (64, n_comp) logits
|
||||
* 3. Dequant GEMM output: logit[h,c] *= q_scale[h] * k_scale[kv_start+c]
|
||||
* 4. ReLU, then weighted sum: score[c] = Σ_h w_h[h] * relu(logit[h,c])
|
||||
* 5. Top-k selection from (n_comp,) scores
|
||||
*
|
||||
* Specialized for DSV4 Pro: n_ih=64, ihd=128, top_k=1024.
|
||||
*
|
||||
* TMEM read strategy for 64 Q rows:
|
||||
* Use tcgen05.ld.16x256b.x1 (proven in B1 FMHA) — one column per instruction.
|
||||
* Lane i reads rows 4i..4i+3 from the column. Lanes 0-15 cover rows 0-63.
|
||||
* 128 reads per K-tile to cover all N-dimension columns.
|
||||
*
|
||||
* NO PyTorch fallback. NO FP32 einsum on CUDA cores. NO BF16 workarounds.
|
||||
*/
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_fp8.hpp>
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
#include <cstdint>
|
||||
#include <cfloat>
|
||||
#include <cmath>
|
||||
|
||||
static constexpr float E4M3_MAX = 448.0f;
|
||||
static constexpr int NTHREADS = 192;
|
||||
static constexpr int NWARPS = 6;
|
||||
typedef unsigned short bf16_t;
|
||||
|
||||
// ---- PTX helpers ----
|
||||
/**
 * Convert an FP32 value to raw BF16 bits with the PTX `cvt.rn.bf16.f32`
 * instruction (the `.rn` modifier selects round-to-nearest-even).
 *
 * @param f  FP32 input.
 * @return   The 16-bit BF16 encoding, carried in the file's `bf16_t` alias.
 */
__device__ __forceinline__ bf16_t f32_to_bf16_ptx(float f) {
    bf16_t bits;
    asm("cvt.rn.bf16.f32 %0, %1;"
        : "=h"(bits)
        : "f"(f));
    return bits;
}
|
||||
/**
 * Widen raw BF16 bits to FP32 with the PTX `cvt.f32.bf16` instruction.
 *
 * @param h  16-bit BF16 encoding (the file's `bf16_t` alias).
 * @return   The same value as an FP32 number.
 */
__device__ __forceinline__ float bf16_to_f32_ptx(bf16_t h) {
    float widened;
    asm("cvt.f32.bf16 %0, %1;"
        : "=f"(widened)
        : "h"(h));
    return widened;
}
|
||||
/**
 * Quantize one FP32 value to an FP8 E4M3 byte.
 *
 * The input is first clamped to [-E4M3_MAX, +E4M3_MAX] and then converted with
 * the saturating-finite E4M3 conversion from <cuda_fp8.h>, which is the same
 * conversion the `__nv_fp8_e4m3(float)` constructor performs.
 *
 * @param x  FP32 value, expected to be pre-divided by the row scale.
 * @return   The raw 8-bit E4M3 encoding.
 */
__device__ __forceinline__ uint8_t fp8_e4m3_from_f32(float x) {
    const float clamped = fminf(fmaxf(x, -E4M3_MAX), E4M3_MAX);
    const __nv_fp8_storage_t raw = __nv_cvt_float_to_fp8(clamped, __NV_SATFINITE, __NV_E4M3);
    return static_cast<uint8_t>(raw);
}
|
||||
|
||||
// ---- UMMA helpers (from fmha_umma_desc.cuh, replicated for ATen build) ----
|
||||
/**
 * Encode a byte quantity for a UMMA shared-memory descriptor field by dropping
 * its low four bits, i.e. expressing it in 16-byte units.
 */
__device__ __forceinline__ uint64_t desc_encode(uint64_t byte_val) {
    constexpr int kLog2Granule = 4;  // 16-byte units
    return byte_val >> kLog2Granule;
}
|
||||
|
||||
/**
 * Build the 64-bit shared-memory matrix descriptor handed to `tcgen05.mma`
 * for an operand tile stored in this file's canonical (non-swizzled) layout.
 *
 * Bit packing performed here (every byte quantity is stored in 16-byte units
 * and masked to 14 bits):
 *   bits  0..13  tile start address
 *   bits 16..29  "LBO" = block_mn * 16 bytes
 *   bits 32..45  "SBO" = 128 bytes
 *   bit  46      constant 1
 * NOTE(review): the LBO/SBO names and the meaning of bit 46 follow the original
 * author's naming; confirm against the PTX ISA tcgen05 descriptor tables.
 *
 * @param smem_addr  Shared-window address of the tile.
 * @param block_mn   Number of rows (M or N extent) in the tile.
 * @return           The packed descriptor.
 */
__device__ __forceinline__ uint64_t make_umma_desc_kmajor_none(uint32_t smem_addr, int block_mn) {
    constexpr uint64_t kField14 = 0x3FFF;

    const uint64_t lead_bytes   = block_mn * 16;
    const uint64_t stride_bytes = 128;

    const uint64_t addr_field   = (static_cast<uint64_t>(smem_addr) >> 4) & kField14;
    const uint64_t lead_field   = (lead_bytes >> 4) & kField14;
    const uint64_t stride_field = (stride_bytes >> 4) & kField14;

    return addr_field
         | (lead_field << 16)
         | (stride_field << 32)
         | (1ULL << 46);
}
|
||||
|
||||
/**
 * Build the 32-bit instruction descriptor used with `tcgen05.mma.kind::f8f6f4`.
 *
 * Packs: constant 1 at bit 4, block_n / 8 starting at bit 17, and
 * block_m / 16 starting at bit 24; every other bit is zero.
 * NOTE(review): presumably bit 4 selects FP32 accumulation and the zeroed
 * operand-format fields select E4M3 for A and B — confirm against the PTX ISA
 * instruction-descriptor table.
 *
 * @param block_m  MMA M extent.
 * @param block_n  MMA N extent.
 */
__device__ __forceinline__ uint32_t make_idesc_f8_e4m3(int block_m, int block_n) {
    const uint32_t accum_field = 1U << 4;
    const uint32_t n_field     = static_cast<uint32_t>(block_n >> 3) << 17;
    const uint32_t m_field     = static_cast<uint32_t>(block_m >> 4) << 24;
    return accum_field | n_field | m_field;
}
|
||||
|
||||
/**
 * Issue one `tcgen05.mma.cta_group::1.kind::f8f6f4` with both operands read
 * from shared memory through descriptors and the result held in tensor memory.
 *
 * @param tmem_c      Tensor-memory address of the accumulator tile.
 * @param desc_a      Shared-memory descriptor of operand A.
 * @param desc_b      Shared-memory descriptor of operand B.
 * @param i_desc      Instruction descriptor (shape / formats).
 * @param accumulate  When true the predicate operand is set, so the MMA adds
 *                    to the existing accumulator; when false it is cleared.
 *
 * The instruction is only issued here; this helper does not wait for it.
 */
__device__ void umma_ss_f8f6f4(uint32_t tmem_c, uint64_t desc_a, uint64_t desc_b,
                               uint32_t i_desc, bool accumulate) {
    // Any non-zero word turns the predicate on; the original constant is kept.
    uint32_t accumulate_word = 0u;
    if (accumulate) {
        accumulate_word = 0x3F800000u;
    }
    asm volatile("{\n\t.reg .pred p;\n\tsetp.ne.b32 p, %4, 0;\n\t"
                 "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p;\n\t}"
                 :
                 : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(i_desc), "r"(accumulate_word));
}
|
||||
|
||||
/**
 * Allocate tensor-memory columns with `tcgen05.alloc.cta_group::1`.
 *
 * The instruction carries `.sync.aligned`, so it must be executed by every
 * lane of the calling warp. The allocated base address is written to the
 * shared-memory word addressed by `smem_ptr`.
 *
 * @param smem_ptr  Shared-window address of a 32-bit result slot.
 * @param num_cols  Number of tensor-memory columns requested.
 */
__device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) {
    asm volatile(
        "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
        :
        : "r"(smem_ptr), "r"(num_cols));
}
|
||||
/**
 * Release tensor-memory columns with `tcgen05.dealloc.cta_group::1`.
 *
 * The instruction carries `.sync.aligned`, so it must be executed by every
 * lane of the calling warp.
 *
 * @param tmem_ptr  Base address previously produced by tmem_alloc().
 * @param num_cols  Column count; should match the value given to tmem_alloc().
 */
__device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) {
    asm volatile(
        "tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;"
        :
        : "r"(tmem_ptr), "r"(num_cols));
}
|
||||
|
||||
// ---- FP8 canonical SMEM layout (same as B1 FMHA) ----
|
||||
/**
 * Byte offset of element (r, c) inside a 128-row x 32-byte FP8 tile stored in
 * the canonical layout used by this file.
 *
 * The tile is cut into 8-row x 16-byte chunks of 128 contiguous bytes each.
 * Chunks are laid out row-group after row-group, and the second 16-byte column
 * group starts 16 * 128 = 2048 bytes after the first. Because
 * (r >> 3) * 128 + (r & 7) * 16 equals r * 16, the offset reduces to the
 * closed form below.
 *
 * @param r  Row inside the tile (0..127).
 * @param c  Byte column inside the tile (0..31).
 */
__device__ __forceinline__ int canon_idx_fp8_128x32(int r, int c) {
    const int k_group   = c >> 4;   // which 16-byte column group
    const int k_in_core = c & 15;   // byte inside that group
    return k_group * (16 * 128) + r * 16 + k_in_core;
}
|
||||
|
||||
// ---- Top-k (proven from indexer_score_topk.cu) ----
|
||||
#ifndef INDEXER_LOCAL_K
|
||||
#define INDEXER_LOCAL_K 8
|
||||
#endif
|
||||
|
||||
/**
 * Offer one (score, block_id) pair to a fixed-size min-heap that tracks the k
 * largest scores seen so far.
 *
 * scores[0] is the smallest retained score. A candidate that is not strictly
 * greater than it is rejected; otherwise it replaces the root and is sifted
 * down until the min-heap property holds again.
 *
 * @param scores    Heap of k scores.
 * @param blocks    Ids stored in parallel with `scores`.
 * @param score     Candidate score.
 * @param block_id  Candidate id.
 * @param k         Heap capacity.
 */
__device__ __forceinline__ void local_heap_insert(float* scores, int32_t* blocks,
                                                  float score, int32_t block_id, int k) {
    if (!(score > scores[0])) {
        return;
    }
    scores[0] = score;
    blocks[0] = block_id;

    const int first_leaf = k >> 1;
    for (int node = 0; node < first_leaf; ) {
        // node < k/2 guarantees the left child exists; only the right needs a check.
        const int lhs = 2 * node + 1;
        const int rhs = lhs + 1;

        int pick = node;
        if (scores[lhs] < scores[pick]) pick = lhs;
        if (rhs < k && scores[rhs] < scores[pick]) pick = rhs;
        if (pick == node) {
            return;
        }

        const float   held_score = scores[node];
        const int32_t held_block = blocks[node];
        scores[node] = scores[pick];
        blocks[node] = blocks[pick];
        scores[pick] = held_score;
        blocks[pick] = held_block;

        node = pick;
    }
}
|
||||
|
||||
/**
 * Offer one (score, block_id) pair to the block-wide top-k min-heap.
 *
 * heap_scores[0] is the smallest retained score. Candidates that are not
 * strictly greater than it are ignored; an accepted candidate overwrites the
 * root and is sifted down. The function performs no synchronization itself,
 * so the caller must make sure only one thread updates a heap at a time.
 *
 * @param heap_scores  Heap of k scores.
 * @param heap_blocks  Ids stored in parallel with `heap_scores`.
 * @param score        Candidate score.
 * @param block_id     Candidate id.
 * @param k            Heap capacity.
 */
__device__ __forceinline__ void heap_insert_shared(float* heap_scores, int32_t* heap_blocks,
                                                   float score, int32_t block_id, int k) {
    if (!(score > heap_scores[0])) {
        return;
    }
    heap_scores[0] = score;
    heap_blocks[0] = block_id;

    int cur = 0;
    const int inner_nodes = k >> 1;   // nodes with at least one child
    while (cur < inner_nodes) {
        const int left_child  = (cur << 1) + 1;   // always < k while cur < k/2
        const int right_child = left_child + 1;

        int lowest = cur;
        if (heap_scores[left_child] < heap_scores[lowest]) {
            lowest = left_child;
        }
        if (right_child < k && heap_scores[right_child] < heap_scores[lowest]) {
            lowest = right_child;
        }
        if (lowest == cur) {
            break;
        }

        const float   tmp_score = heap_scores[cur];
        const int32_t tmp_block = heap_blocks[cur];
        heap_scores[cur]    = heap_scores[lowest];
        heap_blocks[cur]    = heap_blocks[lowest];
        heap_scores[lowest] = tmp_score;
        heap_blocks[lowest] = tmp_block;

        cur = lowest;
    }
}
|
||||
|
||||
// ===========================================================================
|
||||
// Kernel
|
||||
// ===========================================================================
|
||||
|
||||
/** 32-bit shared-window address of a generic pointer, as PTX operands expect. */
static __device__ __forceinline__ uint32_t idx_smem_addr(const void* ptr) {
    return (uint32_t)__cvta_generic_to_shared(ptr);
}

/** Initialise an mbarrier that completes a phase after `arrivals` arrivals. */
static __device__ __forceinline__ void idx_mbar_init(uint32_t bar, uint32_t arrivals) {
    asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;"
                 :: "r"(bar), "r"(arrivals) : "memory");
}

/** Make the mbarrier arrive once every previously issued tcgen05.mma has completed. */
static __device__ __forceinline__ void idx_umma_commit(uint32_t bar) {
    asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                 :: "r"(bar) : "memory");
}

/** Spin until the mbarrier phase with the given parity has completed. */
static __device__ __forceinline__ void idx_mbar_wait(uint32_t bar, uint32_t parity) {
    uint32_t done;
    do {
        asm volatile("{\n\t.reg .pred p;\n\t"
                     "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n\t"
                     "selp.u32 %0, 1, 0, p;\n\t}"
                     : "=r"(done) : "r"(bar), "r"(parity) : "memory");
    } while (done == 0u);
}

/**
 * B2 indexer kernel: FP8 tensor-core scoring + weighted ReLU + exact top-k.
 *
 *   score[c] = sum_h w_h[h] * ReLU( (q[h] . k[c]) * q_scale[h] * k_scale[c] )
 *
 * Launch contract
 *   - grid = 1 block, block = NTHREADS (192) threads, i.e. NWARPS (6) warps.
 *   - dynamic shared memory laid out exactly as carved up below; the host
 *     launcher must size it with the same sequence of additions/alignments.
 *   - requires tcgen05 / tensor memory (built for sm_100a).
 *   - n_ih in [1, 128], ihd == 128, top_k >= 1. If these (or the block size)
 *     are violated the kernel writes -1 to every output slot and returns
 *     before touching shared or tensor memory.
 *
 * Output
 *   topk_indices[0..top_k) holds key indices ordered by descending score.
 *   When fewer than top_k keys exist the tail is filled with -1. All keys are
 *   candidates, including those whose score is zero or negative.
 *
 * Structure
 *   Phase 0  one warp per Q row computes the per-row FP8 scale (amax / 448).
 *   Phase 1  for each tile of SK_TILE keys: four K=32 MMAs accumulate the
 *            128 x 128 FP32 logits tile in tensor memory. Each MMA is committed
 *            to an mbarrier and waited on before its SMEM operands are reused.
 *            Warp w then reads TMEM lanes 32w..32w+31 (= Q rows) one column at
 *            a time with tcgen05.ld.32x32b, applies dequant + ReLU + w_h, and
 *            warp-reduces to a per-column partial. Thread 0 adds the partials
 *            and feeds every column into the shared top-k min-heap.
 *   Phase 2  thread 0 heapsorts the min-heap in place (descending order) and
 *            all threads write the indices out.
 *
 * NOTE(review): the Q tile is re-quantized for every key tile because only one
 * 4 KiB Q staging tile exists in the shared-memory layout; caching all four
 * K-slices would need a matching change to the launcher's size computation.
 * NOTE(review): this rewrite follows the PTX ISA rules for tcgen05 (per-warp
 * TMEM lane quarters, 32x32b fragment layout, commit/mbarrier completion) but
 * has not been run on hardware here — validate against the FP32 fallback.
 *
 * @param q_bf16        (n_ih, ihd) BF16 queries, row-major.
 * @param k_fp8         (n_comp, ihd) FP8 E4M3 keys, row-major.
 * @param k_scale       (n_comp) FP32 per-key dequant scales.
 * @param w_h_bf16      (n_ih) BF16 head weights.
 * @param topk_indices  (top_k) int32 output.
 */
template<int SK_TILE=128>
__global__ void __launch_bounds__(192)
indexer_fp8_score_topk_kernel(
    const bf16_t* __restrict__ q_bf16,      // (n_ih, ihd) BF16 row-major
    const uint8_t* __restrict__ k_fp8,      // (n_comp, ihd) FP8_E4M3
    const float* __restrict__ k_scale,      // (n_comp,) FP32
    const bf16_t* __restrict__ w_h_bf16,    // (n_ih,) BF16
    int32_t* __restrict__ topk_indices,     // (top_k,) output
    int n_comp, int n_ih, int ihd, int top_k
) {
    constexpr int MMA_K_F8      = 32;                    // bytes of K consumed per MMA
    constexpr int NKT           = 4;                     // ihd = 128 = NKT * MMA_K_F8
    constexpr int TILE_ROWS     = 128;                   // MMA M and N extent
    constexpr int TILE_F8       = TILE_ROWS * MMA_K_F8;  // 4096 bytes per SMEM tile
    constexpr int TMEM_COLS     = 128;                   // FP32 accumulator columns
    constexpr int MAX_ROW_WARPS = TILE_ROWS / 32;        // one warp per 32 TMEM lanes

    static_assert(SK_TILE == TILE_ROWS,
                  "canonical SMEM layout and instruction descriptor assume 128-key tiles");
    static_assert(NTHREADS == NWARPS * 32, "NTHREADS must equal NWARPS warps");
    static_assert(NWARPS > MAX_ROW_WARPS, "warp 4 is reserved for TMEM alloc / MMA issue");
    static_assert(NTHREADS * INDEXER_LOCAL_K >= MAX_ROW_WARPS * SK_TILE,
                  "candidate buffer doubles as per-tile partial-sum staging");

    const int tid  = threadIdx.x;
    const int wid  = tid >> 5;
    const int lane = tid & 31;
    const bool is_mma_warp = (wid == 4);

    // ---- Preconditions (uniform across the block, checked before any alloc) ----
    const bool bad_config = (blockDim.x != (unsigned)NTHREADS) ||
                            n_ih < 1 || n_ih > TILE_ROWS ||
                            ihd != NKT * MMA_K_F8 || top_k < 1;
    if (bad_config) {
        for (int i = tid; i < top_k; i += (int)blockDim.x) topk_indices[i] = -1;
        return;
    }

    // ---- SMEM layout (must mirror the launcher's size computation) ----
    extern __shared__ __align__(128) char sbuf[];
    size_t off = 0;
    uint32_t* sTmemBase = (uint32_t*)(sbuf + off); off += 4;
    off = (off + 127) & ~(size_t)127;

    uint8_t* sQ8 = (uint8_t*)(sbuf + off); off += TILE_F8;
    off = (off + 127) & ~(size_t)127;
    uint8_t* sK8 = (uint8_t*)(sbuf + off); off += TILE_F8;
    off = (off + 127) & ~(size_t)127;

    float* sQ_scale = (float*)(sbuf + off); off += 128 * sizeof(float);
    off = (off + 127) & ~(size_t)127;

    float* sW_h = (float*)(sbuf + off); off += n_ih * sizeof(float);
    off = (off + 127) & ~(size_t)127;

    // Block-wide top-k min-heap.
    float*   sMergeScores = (float*)(sbuf + off);   off += top_k * sizeof(float);
    int32_t* sMergeBlocks = (int32_t*)(sbuf + off); off += top_k * sizeof(int32_t);

    // NTHREADS * INDEXER_LOCAL_K floats; used here as [MAX_ROW_WARPS][SK_TILE]
    // per-warp partial column sums for the tile currently being scored.
    float* sPartial = (float*)(sbuf + off);
    off += NTHREADS * INDEXER_LOCAL_K * sizeof(float);
    // The trailing NTHREADS * INDEXER_LOCAL_K int32 region stays reserved so the
    // layout keeps matching the launcher; it is not used by this kernel.

    // mbarrier signalled by tcgen05.commit when the issued MMA has completed.
    __shared__ __align__(8) uint64_t sMmaDone;
    const uint32_t mma_bar = idx_smem_addr(&sMmaDone);

    // ---- Init SMEM ----
    for (int i = tid; i < top_k; i += NTHREADS) {
        sMergeScores[i] = -INFINITY;
        sMergeBlocks[i] = -1;
    }
    for (int i = n_ih + tid; i < 128; i += NTHREADS) sQ_scale[i] = 0.0f;
    for (int i = tid; i < n_ih; i += NTHREADS) sW_h[i] = bf16_to_f32_ptx(w_h_bf16[i]);
    if (tid == 0) {
        idx_mbar_init(mma_bar, 1);
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }

    // ---- Phase 0: per-row Q scale, one warp per row ----
    // The loop bound depends only on the warp id, so every lane of a warp takes
    // part in each full-mask shuffle, and no shared scratch is needed.
    for (int h = wid; h < n_ih; h += NWARPS) {
        float row_amax = 0.0f;
        for (int d = lane; d < ihd; d += 32) {
            row_amax = fmaxf(row_amax, fabsf(bf16_to_f32_ptx(q_bf16[h * ihd + d])));
        }
        for (int o = 16; o > 0; o >>= 1) {
            row_amax = fmaxf(row_amax, __shfl_xor_sync(0xffffffffu, row_amax, o));
        }
        if (lane == 0) {
            float scale = row_amax / E4M3_MAX;
            if (scale < 1e-8f) scale = 1e-8f;   // avoid dividing by zero for all-zero rows
            sQ_scale[h] = scale;
        }
    }
    __syncthreads();

    // ---- TMEM alloc (whole warp 4: the instruction is .sync.aligned) ----
    if (is_mma_warp) tmem_alloc(idx_smem_addr(sTmemBase), TMEM_COLS);
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    __syncthreads();
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    const uint32_t tb = *sTmemBase;

    // ---- Phase 1: FP8 GEMM + fused scoring, one key tile at a time ----
    const int n_k_tiles = (n_comp + SK_TILE - 1) / SK_TILE;
    const uint32_t idesc_f8 = make_idesc_f8_e4m3(TILE_ROWS, SK_TILE);
    const int n_row_warps = (n_ih + 31) >> 5;   // warps owning TMEM lanes with real Q rows
    uint32_t mma_parity = 0;

    for (int kv_tile = 0; kv_tile < n_k_tiles; kv_tile++) {
        const int kv_start = kv_tile * SK_TILE;
        const int kv_len   = min(SK_TILE, n_comp - kv_start);

        for (int kt = 0; kt < NKT; kt++) {
            // Fill both operand tiles. canon_idx_fp8_128x32 is a bijection onto
            // [0, TILE_F8), so padding rows are zeroed in the same pass.
            for (int i = tid; i < TILE_F8; i += NTHREADS) {
                const int row = i / MMA_K_F8;
                const int col = i % MMA_K_F8;
                const int d   = kt * MMA_K_F8 + col;
                const int dst = canon_idx_fp8_128x32(row, col);

                uint8_t q_byte = 0;
                if (row < n_ih) {
                    const float val       = bf16_to_f32_ptx(q_bf16[row * ihd + d]);
                    const float inv_scale = 1.0f / sQ_scale[row];
                    q_byte = fp8_e4m3_from_f32(val * inv_scale);
                }
                uint8_t k_byte = 0;
                if (row < kv_len) {
                    k_byte = k_fp8[(int64_t)(kv_start + row) * ihd + d];
                }
                sQ8[dst] = q_byte;
                sK8[dst] = k_byte;
            }
            // Publish the generic-proxy stores to the async proxy the MMA reads through.
            asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
            __syncthreads();

            if (is_mma_warp && lane == 0) {
                asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
                const uint64_t dq = make_umma_desc_kmajor_none(idx_smem_addr(sQ8), TILE_ROWS);
                const uint64_t dk = make_umma_desc_kmajor_none(idx_smem_addr(sK8), SK_TILE);
                umma_ss_f8f6f4(tb, dq, dk, idesc_f8, kt > 0);
                idx_umma_commit(mma_bar);
            }
            // The MMA is asynchronous: wait for completion before the operand
            // tiles are overwritten (next kt) or the accumulator is read.
            idx_mbar_wait(mma_bar, mma_parity);
            mma_parity ^= 1u;
            asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
        }

        // ---- Fused epilogue: dequant -> ReLU -> weight -> sum over heads ----
        // Accumulator row h lives in TMEM lane h, and a warp may only touch the
        // lane quarter 32*(warp % 4) .. +31, so warp w handles heads 32w..32w+31.
        if (wid < n_row_warps) {
            const int   h        = wid * 32 + lane;
            const bool  row_ok   = h < n_ih;
            const float q_s      = row_ok ? sQ_scale[h] : 0.0f;
            const float w_head   = row_ok ? sW_h[h] : 0.0f;
            const uint32_t lane_base = tb + ((uint32_t)(wid * 32) << 16);

            for (int c = 0; c < kv_len; c++) {
                uint32_t raw;
                asm volatile("tcgen05.ld.sync.aligned.32x32b.x1.b32 {%0}, [%1];"
                             : "=r"(raw) : "r"(lane_base + (uint32_t)c));
                asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");

                const float logit = __uint_as_float(raw) * q_s * k_scale[kv_start + c];
                float part = (row_ok && logit > 0.0f) ? w_head * logit : 0.0f;

                // Every lane of the warp reaches this reduction (no divergence).
                for (int o = 16; o > 0; o >>= 1) {
                    part += __shfl_xor_sync(0xffffffffu, part, o);
                }
                if (lane == 0) sPartial[wid * SK_TILE + c] = part;
            }
        }
        asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
        __syncthreads();

        // Thread 0 owns the heap. It joins every later barrier, so the next
        // tile's epilogue cannot overwrite sPartial before this loop finishes.
        if (tid == 0) {
            for (int c = 0; c < kv_len; c++) {
                float score = 0.0f;
                for (int w = 0; w < n_row_warps; w++) score += sPartial[w * SK_TILE + c];
                heap_insert_shared(sMergeScores, sMergeBlocks, score, kv_start + c, top_k);
            }
        }
    }
    __syncthreads();

    // ---- TMEM dealloc ----
    if (is_mma_warp) tmem_dealloc(tb, TMEM_COLS);

    // ---- Phase 2: order the heap by descending score (in-place heapsort) ----
    // Repeatedly moving the current minimum to the shrinking tail of a min-heap
    // leaves the array sorted from largest to smallest; unfilled (-inf, -1)
    // slots therefore end up last.
    if (tid == 0) {
        for (int end = top_k - 1; end > 0; end--) {
            const float   min_s = sMergeScores[0];
            const int32_t min_b = sMergeBlocks[0];
            sMergeScores[0]   = sMergeScores[end];
            sMergeBlocks[0]   = sMergeBlocks[end];
            sMergeScores[end] = min_s;
            sMergeBlocks[end] = min_b;

            int root = 0;
            for (;;) {
                const int left = 2 * root + 1;
                if (left >= end) break;
                const int right = left + 1;
                int child = left;
                if (right < end && sMergeScores[right] < sMergeScores[left]) child = right;
                if (!(sMergeScores[child] < sMergeScores[root])) break;

                const float   ts = sMergeScores[root];
                const int32_t ti = sMergeBlocks[root];
                sMergeScores[root]  = sMergeScores[child];
                sMergeBlocks[root]  = sMergeBlocks[child];
                sMergeScores[child] = ts;
                sMergeBlocks[child] = ti;
                root = child;
            }
        }
    }
    __syncthreads();

    for (int i = tid; i < top_k; i += NTHREADS) {
        topk_indices[i] = sMergeBlocks[i];
    }
}
|
||||
|
||||
// ===========================================================================
|
||||
// PyTorch binding
|
||||
// ===========================================================================
|
||||
|
||||
/**
 * Host launcher for indexer_fp8_score_topk_kernel (single query token).
 *
 * Validates the arguments, sizes the dynamic shared memory with the same
 * sequence of additions and 128-byte alignments the kernel uses to carve it
 * up, and launches one block of 192 threads on the current CUDA stream. The
 * call is asynchronous: only launch-time errors are reported here.
 *
 * @param q_bf16        (n_ih, ihd) BF16 queries.
 * @param k_fp8         (n_comp, ihd) keys, any 1-byte dtype (uint8 or float8_e4m3fn).
 * @param k_scale       FP32 per-key scales, at least n_comp elements.
 * @param w_h           BF16 head weights, n_ih elements.
 * @param topk_indices  Contiguous int32 CUDA tensor with at least top_k elements;
 *                      receives key indices ordered by descending score.
 * @param n_ih          Number of indexer heads, 1..128 (kernel tile limit).
 * @param ihd           Indexer head dimension; the kernel is specialized to 128.
 * @param top_k         Number of indices to produce; 0 is a no-op.
 *
 * Non-contiguous inputs are made contiguous (copy) before the launch.
 * NOTE(review): tensors are assumed to live on the current device; no device
 * guard is installed here.
 */
void indexer_fp8_score_topk_cuda(
    torch::Tensor q_bf16,        // (n_ih, ihd) BF16
    torch::Tensor k_fp8,         // (n_comp, ihd) uint8/float8_e4m3fn
    torch::Tensor k_scale,       // (n_comp,) FP32
    torch::Tensor w_h,           // (n_ih,) BF16
    torch::Tensor topk_indices,  // (top_k,) int32 output
    int64_t n_ih, int64_t ihd, int64_t top_k
) {
    constexpr int64_t kThreads  = 192;   // must match the kernel's NTHREADS
    constexpr int64_t kTileRows = 128;   // MMA M extent: upper bound on n_ih
    constexpr int64_t kHeadDim  = 128;   // kernel performs exactly four K=32 MMAs
    constexpr int64_t kTileF8   = 128 * 32;

    // ---- dtype / device ----
    TORCH_CHECK(q_bf16.is_cuda() && q_bf16.scalar_type() == torch::kBFloat16,
                "q_bf16 must be a CUDA BFloat16 tensor");
    TORCH_CHECK(k_fp8.is_cuda() && k_fp8.element_size() == 1,
                "k_fp8 must be a CUDA tensor with a 1-byte dtype (uint8 / float8_e4m3fn)");
    TORCH_CHECK(k_scale.is_cuda() && k_scale.scalar_type() == torch::kFloat32,
                "k_scale must be a CUDA Float32 tensor");
    TORCH_CHECK(w_h.is_cuda() && w_h.scalar_type() == torch::kBFloat16,
                "w_h must be a CUDA BFloat16 tensor");
    TORCH_CHECK(topk_indices.is_cuda() && topk_indices.scalar_type() == torch::kInt32,
                "topk_indices must be a CUDA Int32 tensor");
    TORCH_CHECK(k_fp8.device() == q_bf16.device() && k_scale.device() == q_bf16.device() &&
                w_h.device() == q_bf16.device() && topk_indices.device() == q_bf16.device(),
                "all tensors must be on the same CUDA device");

    // ---- sizes the kernel hard-codes ----
    TORCH_CHECK(n_ih >= 1 && n_ih <= kTileRows,
                "n_ih must be in [1, 128], got ", n_ih);
    TORCH_CHECK(ihd == kHeadDim, "kernel is specialized to ihd == 128, got ", ihd);
    TORCH_CHECK(top_k >= 0 && top_k <= INT32_MAX, "top_k out of range: ", top_k);

    // ---- shapes ----
    TORCH_CHECK(k_fp8.dim() == 2 && k_fp8.size(1) == ihd,
                "k_fp8 must have shape (n_comp, ihd)");
    const int64_t n_comp64 = k_fp8.size(0);
    TORCH_CHECK(n_comp64 <= INT32_MAX, "n_comp does not fit in int32: ", n_comp64);
    TORCH_CHECK(q_bf16.numel() == n_ih * ihd, "q_bf16 must hold n_ih * ihd elements");
    TORCH_CHECK(w_h.numel() == n_ih, "w_h must hold n_ih elements");
    TORCH_CHECK(k_scale.numel() >= n_comp64, "k_scale must hold at least n_comp elements");
    TORCH_CHECK(topk_indices.is_contiguous() && topk_indices.numel() >= top_k,
                "topk_indices must be contiguous with at least top_k elements");

    if (top_k == 0) {
        return;   // nothing to write
    }
    const int n_comp = static_cast<int>(n_comp64);

    // The kernel indexes raw pointers with dense row-major strides.
    auto q_c  = q_bf16.contiguous();
    auto k_c  = k_fp8.contiguous();
    auto ks_c = k_scale.contiguous();
    auto w_c  = w_h.contiguous();

    // Reinterpret float8 storage as bytes when needed.
    auto k8 = k_c.dtype() == torch::kUInt8 ? k_c : k_c.view(torch::kUInt8);

    // ---- dynamic SMEM size: mirror the kernel's carve-up exactly ----
    auto align128 = [](size_t v) { return (v + 127) & ~static_cast<size_t>(127); };
    size_t smem = 0;
    smem = align128(smem + 4);                                 // sTmemBase
    smem = align128(smem + kTileF8);                           // sQ8
    smem = align128(smem + kTileF8);                           // sK8
    smem = align128(smem + 128 * sizeof(float));               // sQ_scale
    smem = align128(smem + static_cast<size_t>(n_ih) * 4);     // sW_h
    smem += static_cast<size_t>(top_k) * 4;                    // sMergeScores
    smem += static_cast<size_t>(top_k) * 4;                    // sMergeBlocks
    smem += static_cast<size_t>(kThreads) * INDEXER_LOCAL_K * 4;   // sCandScores
    smem += static_cast<size_t>(kThreads) * INDEXER_LOCAL_K * 4;   // sCandBlocks
    TORCH_CHECK(smem <= static_cast<size_t>(INT32_MAX),
                "shared-memory request too large: ", smem, " bytes");

    // Fails (and is reported) if the request exceeds the device's opt-in limit.
    C10_CUDA_CHECK(cudaFuncSetAttribute(indexer_fp8_score_topk_kernel<128>,
                                        cudaFuncAttributeMaxDynamicSharedMemorySize,
                                        static_cast<int>(smem)));

    indexer_fp8_score_topk_kernel<128><<<1, static_cast<unsigned>(kThreads), smem,
                                         c10::cuda::getCurrentCUDAStream()>>>(
        reinterpret_cast<const bf16_t*>(q_c.data_ptr<at::BFloat16>()),
        k8.data_ptr<uint8_t>(),
        ks_c.data_ptr<float>(),
        reinterpret_cast<const bf16_t*>(w_c.data_ptr<at::BFloat16>()),
        topk_indices.data_ptr<int32_t>(),
        n_comp, (int)n_ih, (int)ihd, (int)top_k);

    C10_CUDA_CHECK(cudaGetLastError());
}
|
||||
|
||||
// Python binding: exposes the host launcher to the extension module under the
// name `indexer_fp8_score_topk` (positional arguments, same order as the C++
// signature).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "indexer_fp8_score_topk",
        &indexer_fp8_score_topk_cuda,
        "B2 FP8 tensor-core indexer scoring + weighted ReLU + top-k");
}
|
||||
@@ -406,40 +406,64 @@ class Indexer:
|
||||
self.compressor = Compressor(4, self.ihd, 7168, dev)
|
||||
self.compressor.load(w, pfx, dev)
|
||||
|
||||
def forward(self, q_lora, hidden_states, comp_indexer_kv, positions, layer_idx=None):
|
||||
if self.q_b_lin is None or comp_indexer_kv is None or comp_indexer_kv.shape[0] == 0:
|
||||
def forward(self, q_lora, hidden_states, kv_cache, positions, layer_idx=None):
|
||||
"""B2 FP8 tensor-core indexer scoring + weighted ReLU + top-k.
|
||||
|
||||
Pipeline:
|
||||
1. NVFP4 GEMM: q_a (lora) @ q_b_proj → (T, n_ih * ihd) BF16
|
||||
2. NVFP4 GEMM: hidden_states @ weights_proj → (T, n_ih) BF16
|
||||
3. FP8 GEMM + ReLU + weighted sum + top-k (CUDA kernel)
|
||||
|
||||
Indexer keys are consumed directly in FP8_E4M3 format — no BF16 dequant.
|
||||
"""
|
||||
if self.q_b_lin is None or kv_cache is None or not kv_cache._has_idx or kv_cache.n_comp == 0:
|
||||
return None
|
||||
dev = q_lora.device; T = q_lora.shape[0]; n_comp = comp_indexer_kv.shape[0]
|
||||
# INDEXER PROBE: print shapes at layer_idx==0 only
|
||||
dev = q_lora.device; T = q_lora.shape[0]
|
||||
li = layer_idx
|
||||
if li == 0:
|
||||
print(f"\n=== INDEXER PROBE L0 ===", flush=True)
|
||||
print(f" q_lora: shape={tuple(q_lora.shape)} dtype={q_lora.dtype}", flush=True)
|
||||
print(f" comp_idx_kv: shape={tuple(comp_indexer_kv.shape)} "
|
||||
f"dtype={comp_indexer_kv.dtype} stride={comp_indexer_kv.stride()} "
|
||||
f"contig={comp_indexer_kv.is_contiguous()}", flush=True)
|
||||
print(f" self.n_ih={self.n_ih} self.ihd={self.ihd} n_ih*ihd={self.n_ih * self.ihd}", flush=True)
|
||||
print(f" self.q_b_lin.in_features={self.q_b_lin.in_features} out_features={self.q_b_lin.out_features}", flush=True)
|
||||
print(f" self.wp_lin.in_features={self.wp_lin.in_features} out_features={self.wp_lin.out_features}", flush=True)
|
||||
if self.compressor is not None:
|
||||
print(f" self.compressor.kv_dim={self.compressor.kv_dim} ratio={self.compressor.ratio} hd={self.compressor.hd}", flush=True)
|
||||
|
||||
q_idx = self.q_b_lin(q_lora).reshape(T, self.n_ih, self.ihd) # (T, n_ih, ihd)
|
||||
w_h = self.wp_lin(hidden_states) # (T, n_ih)
|
||||
# Stored indexer keys are (n_comp, ihd) — one vector per compressed block,
|
||||
# shared across all indexer heads (paper's c_I = ihd = 128).
|
||||
# NOT (n_comp, n_ih, ihd) — there is no per-head key decomposition.
|
||||
k_idx = comp_indexer_kv # (n_comp, ihd)
|
||||
|
||||
# B2: FP8 tensor-core scoring path.
|
||||
# Indexer keys are stored as FP8_E4M3 in the KV cache.
|
||||
# No BF16 dequantization — the CUDA kernel consumes FP8 directly.
|
||||
k_fp8 = kv_cache.comp_idx_fp8[:kv_cache.n_comp] # (n_comp, ihd) uint8
|
||||
k_scale = kv_cache.comp_idx_scale[:kv_cache.n_comp] # (n_comp,) FP32
|
||||
n_comp = kv_cache.n_comp
|
||||
|
||||
if li == 0:
|
||||
print(f"--- INDEXER L0 SCORING TENSORS ---", flush=True)
|
||||
print(f"\n=== INDEXER PROBE L0 (B2 FP8) ===", flush=True)
|
||||
print(f" q_idx: shape={tuple(q_idx.shape)} dtype={q_idx.dtype}", flush=True)
|
||||
print(f" k_idx: shape={tuple(k_idx.shape)} dtype={k_idx.dtype}", flush=True)
|
||||
print(f" k_fp8: shape={tuple(k_fp8.shape)} dtype={k_fp8.dtype}", flush=True)
|
||||
print(f" k_scale: shape={tuple(k_scale.shape)} dtype={k_scale.dtype}", flush=True)
|
||||
print(f" w_h: shape={tuple(w_h.shape)} dtype={w_h.dtype}", flush=True)
|
||||
# Weighted ReLU MQA scoring (eq. 16):
|
||||
# score(t, c) = sum_h w_h(t,h) * ReLU(q(t,h) · k(c))
|
||||
# k is shared across heads: einsum 'tnd,cd->tnc' (c=n_comp, d=ihd)
|
||||
scores = torch.einsum('tnd,cd->tnc', q_idx.float(), k_idx.float()) # (T, n_ih, n_comp)
|
||||
|
||||
# For T=1 decode: use the B2 FP8 CUDA kernel
|
||||
if T == 1 and self.ihd == 128 and self.n_ih == 64:
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
mod = get_cuda_module("indexer_fp8_score_topk", ["indexer_fp8_score_topk.cu"],
|
||||
extra_cuda_cflags=[
|
||||
"-gencode=arch=compute_100a,code=sm_100a",
|
||||
"-O3", "--use_fast_math", "--expt-relaxed-constexpr",
|
||||
])
|
||||
q_2d = q_idx.squeeze(0).contiguous() # (n_ih, ihd) BF16
|
||||
w_1d = w_h.squeeze(0).contiguous() # (n_ih,) BF16
|
||||
tk = min(self.top_k, n_comp)
|
||||
topk_indices = torch.empty(tk, dtype=torch.int32, device=dev)
|
||||
mod.indexer_fp8_score_topk(
|
||||
q_2d, k_fp8, k_scale, w_1d, topk_indices,
|
||||
self.n_ih, self.ihd, tk)
|
||||
return topk_indices.unsqueeze(0) # (1, top_k)
|
||||
|
||||
# Fallback for T>1 or non-standard dimensions — FP32 einsum
|
||||
k_idx = k_fp8 # still FP8, need dequant for einsum
|
||||
if k_idx.dtype == torch.uint8 or str(k_idx.dtype) == 'torch.float8_e4m3fn':
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
|
||||
k_idx = kv_mod.dequant_fp8_e4m3(k_fp8, k_scale) # (n_comp, ihd) BF16
|
||||
scores = torch.einsum('tnd,cd->tnc', q_idx.float(), k_idx.float())
|
||||
scores = F.relu(scores)
|
||||
total = (scores * w_h.unsqueeze(-1).float()).sum(1) # (T, n_comp)
|
||||
total = (scores * w_h.unsqueeze(-1).float()).sum(1)
|
||||
tk = min(self.top_k, n_comp); _, idx = total.topk(tk, -1); return idx
|
||||
|
||||
# =====================================================================
|
||||
@@ -834,7 +858,7 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
# 4. Indexer top-k (CSA)
|
||||
topk_idx = None
|
||||
if indexer is not None and ratio == 4:
|
||||
topk_idx = indexer.forward(q_a, x_normed, kv_cache.comp_idx_kv, positions, layer_idx=li)
|
||||
topk_idx = indexer.forward(q_a, x_normed, kv_cache, positions, layer_idx=li)
|
||||
|
||||
# 5. Gather KV — B1 storage-native mixed path.
|
||||
# noPE remains FP8_E4M3 + per-row scale; RoPE remains BF16.
|
||||
|
||||
Reference in New Issue
Block a user