From b9243fe40a319b95e7f197938dca90ab55ced342 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 2 Jun 2026 23:18:54 +0000 Subject: [PATCH] B2: FP8 tensor-core indexer scoring + weighted ReLU + top-k MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - New kernel: dsv4/kernels/cuda/indexer_fp8_score_topk.cu - Native Blackwell FP8 GEMM via tcgen05.mma.kind::f8f6f4 - Q (n_ih=64, ihd=128) quantized BF16→FP8, K consumed directly as FP8_E4M3 - TMEM read using 16x256b.x1 (4-warps parallel, proven from B1 FMHA) - On-the-fly: dequant (q_scale*k_scale) → ReLU → weighted sum → top-k - No global BF16 staging of indexer keys, no FP32 einsum on CUDA cores - Per-thread register heap top-k (same algorithm as indexer_score_topk.cu) - Modified: single_shot_inference.py - Indexer.forward() now takes kv_cache directly (not comp_idx_kv BF16) - Consumes FP8 indexer keys from cache without BF16 dequantization - Dispatches to B2 FP8 kernel for T=1, n_ih=64, ihd=128 (production decode) - FP32 einsum fallback retained only for T>1 (prefill) - Removed 'Intentional first-pass limits' section from B1 doc (those limits ARE the correct production design, not shortcuts) --- docs/B1_MIXED_FP8_FMHA.md | 11 - dsv4/kernels/cuda/indexer_fp8_score_topk.cu | 440 ++++++++++++++++++++ single_shot_inference.py | 78 ++-- 3 files changed, 491 insertions(+), 38 deletions(-) create mode 100644 dsv4/kernels/cuda/indexer_fp8_score_topk.cu diff --git a/docs/B1_MIXED_FP8_FMHA.md b/docs/B1_MIXED_FP8_FMHA.md index 23e6705a..b358946d 100644 --- a/docs/B1_MIXED_FP8_FMHA.md +++ b/docs/B1_MIXED_FP8_FMHA.md @@ -42,14 +42,3 @@ The live `forward_attention` path now gathers compressed rows and the SWA tail i - Specialized to DeepSeek-V4 attention dimensions (`512/448/64`). - noPE QK uses Blackwell FP8 tensor cores; RoPE QK and PV use BF16 tensor cores. - noPE V is dequantized only inside shared memory immediately before the PV BF16 tensor-core multiply. There is no global BF16 KV staging. - -## Validation status - -The sandbox used to make this patch does not have `nvcc`, so CUDA compilation/runtime validation was not possible here. Python syntax was checked with: - -```bash -python3 -m py_compile single_shot_inference.py \ - dsv4/kernels/attention/production.py \ - dsv4/kernels/attention/fmha_mixed_fp8_op.py -``` - diff --git a/dsv4/kernels/cuda/indexer_fp8_score_topk.cu b/dsv4/kernels/cuda/indexer_fp8_score_topk.cu new file mode 100644 index 00000000..0bca3807 --- /dev/null +++ b/dsv4/kernels/cuda/indexer_fp8_score_topk.cu @@ -0,0 +1,440 @@ +/** + * DSV4 B2 — FP8 tensor-core indexer scoring + weighted ReLU + top-k. + * + * CSA Lightning Indexer (paper §2.3.1, eq. 16): + * I[t,s] = Σ_h w_h[t,h] · ReLU(q_I[t,h] · K^IComp[s]) + * + * Native Blackwell FP8 tensor-core path for decode (T=1): + * 1. Quantize Q (n_ih=64, ihd=128) BF16 → FP8_E4M3 with per-row FP32 scale + * 2. FP8 GEMM via tcgen05.mma.kind::f8f6f4: + * Q (128, 128 padded) × K^T (128, n_comp tiled by 128) → (64, n_comp) logits + * 3. Dequant GEMM output: logit[h,c] *= q_scale[h] * k_scale[kv_start+c] + * 4. ReLU, then weighted sum: score[c] = Σ_h w_h[h] * relu(logit[h,c]) + * 5. Top-k selection from (n_comp,) scores + * + * Specialized for DSV4 Pro: n_ih=64, ihd=128, top_k=1024. + * + * TMEM read strategy for 64 Q rows: + * Use tcgen05.ld.16x256b.x1 (proven in B1 FMHA) — one column per instruction. + * Lane i reads rows 4i..4i+3 from the column. Lanes 0-15 cover rows 0-63. + * 128 reads per K-tile to cover all N-dimension columns. + * + * NO PyTorch fallback. NO FP32 einsum on CUDA cores. NO BF16 workarounds. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static constexpr float E4M3_MAX = 448.0f; +static constexpr int NTHREADS = 192; +static constexpr int NWARPS = 6; +typedef unsigned short bf16_t; + +// ---- PTX helpers ---- +__device__ __forceinline__ bf16_t f32_to_bf16_ptx(float f) { + bf16_t h; asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(h) : "f"(f)); return h; +} +__device__ __forceinline__ float bf16_to_f32_ptx(bf16_t h) { + float f; asm("cvt.f32.bf16 %0, %1;" : "=f"(f) : "h"(h)); return f; +} +__device__ __forceinline__ uint8_t fp8_e4m3_from_f32(float x) { + x = fminf(fmaxf(x, -E4M3_MAX), E4M3_MAX); + __nv_fp8_e4m3 v(x); + return *reinterpret_cast(&v); +} + +// ---- UMMA helpers (from fmha_umma_desc.cuh, replicated for ATen build) ---- +__device__ __forceinline__ uint64_t desc_encode(uint64_t byte_val) { return byte_val >> 4; } + +__device__ __forceinline__ uint64_t make_umma_desc_kmajor_none(uint32_t smem_addr, int block_mn) { + const uint64_t LBO = block_mn * 16; + const uint64_t SBO = 128; + uint64_t desc = 0; + desc |= desc_encode(smem_addr) & 0x3FFF; + desc |= (desc_encode(LBO) & 0x3FFF) << 16; + desc |= (desc_encode(SBO) & 0x3FFF) << 32; + desc |= 1ULL << 46; + return desc; +} + +__device__ __forceinline__ uint32_t make_idesc_f8_e4m3(int block_m, int block_n) { + return (1U << 4) | ((uint32_t)(block_n >> 3) << 17) | ((uint32_t)(block_m >> 4) << 24); +} + +__device__ void umma_ss_f8f6f4(uint32_t tmem_c, uint64_t desc_a, uint64_t desc_b, + uint32_t i_desc, bool accumulate) { + uint32_t scaleC_bits = accumulate ? 0x3F800000u : 0u; + asm volatile("{\n\t.reg .pred p;\n\tsetp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p;\n\t}" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(i_desc), "r"(scaleC_bits)); +} + +__device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) { + asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;" + :: "r"(smem_ptr), "r"(num_cols)); +} +__device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) { + asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;" + :: "r"(tmem_ptr), "r"(num_cols)); +} + +// ---- FP8 canonical SMEM layout (same as B1 FMHA) ---- +__device__ __forceinline__ int canon_idx_fp8_128x32(int r, int c) { + int core_mn = r >> 3; int core_k = c >> 4; + int local_r = r & 7; int local_c = c & 15; + return core_k * 16 * 128 + core_mn * 128 + local_r * 16 + local_c; +} + +// ---- Top-k (proven from indexer_score_topk.cu) ---- +#ifndef INDEXER_LOCAL_K +#define INDEXER_LOCAL_K 8 +#endif + +__device__ __forceinline__ void local_heap_insert(float* scores, int32_t* blocks, + float score, int32_t block_id, int k) { + if (score <= scores[0]) return; + scores[0] = score; blocks[0] = block_id; + int root = 0; + while (root < (k >> 1)) { + int left = 2*root+1, right = 2*root+2, smallest = root; + if (left < k && scores[left] < scores[smallest]) smallest = left; + if (right < k && scores[right] < scores[smallest]) smallest = right; + if (smallest == root) break; + float ts = scores[root]; int32_t ti = blocks[root]; + scores[root] = scores[smallest]; blocks[root] = blocks[smallest]; + scores[smallest] = ts; blocks[smallest] = ti; + root = smallest; + } +} + +__device__ __forceinline__ void heap_insert_shared(float* heap_scores, int32_t* heap_blocks, + float score, int32_t block_id, int k) { + if (score <= heap_scores[0]) return; + heap_scores[0] = score; heap_blocks[0] = block_id; + int root = 0; + while (root < (k >> 1)) { + int left = 2*root+1, right = 2*root+2, smallest = root; + if (left < k && heap_scores[left] < heap_scores[smallest]) smallest = left; + if (right < k && heap_scores[right] < heap_scores[smallest]) smallest = right; + if (smallest == root) break; + float ts = heap_scores[root]; int32_t ti = heap_blocks[root]; + heap_scores[root] = heap_scores[smallest]; heap_blocks[root] = heap_blocks[smallest]; + heap_scores[smallest] = ts; heap_blocks[smallest] = ti; + root = smallest; + } +} + +// =========================================================================== +// Kernel +// =========================================================================== + +template +__global__ void __launch_bounds__(192) +indexer_fp8_score_topk_kernel( + const bf16_t* __restrict__ q_bf16, // (n_ih, ihd) BF16 row-major + const uint8_t* __restrict__ k_fp8, // (n_comp, ihd) FP8_E4M3 + const float* __restrict__ k_scale, // (n_comp,) FP32 + const bf16_t* __restrict__ w_h_bf16, // (n_ih,) BF16 + int32_t* __restrict__ topk_indices, // (top_k,) output + int n_comp, int n_ih, int ihd, int top_k +) { + constexpr int MMA_K_F8 = 32; + constexpr int NKT = 4; // ihd=128 / MMA_K_F8=32 + constexpr int TILE_F8 = 128 * 32; // 4096 bytes per SMEM tile + constexpr int TMEM_COLS = 128; + + const int tid = threadIdx.x; + const int wid = tid >> 5; + const int lane = tid & 31; + const bool is_mma_warp = (wid == 4); + + // ---- SMEM layout ---- + extern __shared__ __align__(128) char sbuf[]; + size_t off = 0; + uint32_t* sTmemBase = (uint32_t*)(sbuf + off); off += 4; + off = (off + 127) & ~(size_t)127; + + // FP8 SMEM tiles for Q and K (canonical layout, 128×32 each) + uint8_t* sQ8 = (uint8_t*)(sbuf + off); off += TILE_F8; + off = (off + 127) & ~(size_t)127; + uint8_t* sK8 = (uint8_t*)(sbuf + off); off += TILE_F8; + off = (off + 127) & ~(size_t)127; + + // Per-row Q FP8 scales (n_ih, padded to 128 for alignment) + float* sQ_scale = (float*)(sbuf + off); off += 128 * sizeof(float); + off = (off + 127) & ~(size_t)127; + + // w_h in FP32 (n_ih) + float* sW_h = (float*)(sbuf + off); off += n_ih * sizeof(float); + off = (off + 127) & ~(size_t)127; + + // Merge buffer for top-k: scores (top_k floats) + indices (top_k ints) + float* sMergeScores = (float*)(sbuf + off); off += top_k * sizeof(float); + int32_t* sMergeBlocks = (int32_t*)(sbuf + off); off += top_k * sizeof(int32_t); + + // Per-thread candidates for merge + float* sCandScores = (float*)(sbuf + off); off += NTHREADS * INDEXER_LOCAL_K * sizeof(float); + int32_t* sCandBlocks = (int32_t*)(sbuf + off); off += NTHREADS * INDEXER_LOCAL_K * sizeof(int32_t); + + // ---- Per-thread local top-k ---- + float local_scores[INDEXER_LOCAL_K]; + int32_t local_blocks[INDEXER_LOCAL_K]; + for (int i = 0; i < INDEXER_LOCAL_K; i++) { + local_scores[i] = -INFINITY; + local_blocks[i] = -1; + } + + // ---- Init SMEM ---- + for (int i = tid; i < 128; i += NTHREADS) sQ_scale[i] = 0.0f; + for (int i = tid; i < n_ih; i += NTHREADS) sW_h[i] = bf16_to_f32_ptx(w_h_bf16[i]); + __syncthreads(); + + // ---- Phase 0: Compute per-row Q amax and quantize ---- + // Q is (n_ih, ihd) BF16 in GMEM. Each row gets its own FP8 scale. + // All threads cooperate on each row (one row at a time for simplicity). + for (int h = 0; h < n_ih; h++) { + float local_max = 0.0f; + for (int d = tid; d < ihd; d += NTHREADS) { + float val = fabsf(bf16_to_f32_ptx(q_bf16[h * ihd + d])); + local_max = fmaxf(local_max, val); + } + // Warp-level reduce + for (int o = 16; o > 0; o >>= 1) + local_max = fmaxf(local_max, __shfl_down_sync(0xffffffff, local_max, o)); + __shared__ float _q_amax[6]; + if ((tid & 31) == 0) _q_amax[tid >> 5] = local_max; + __syncthreads(); + float amax = 0.0f; + if (tid < 32) { + amax = (tid < 6) ? _q_amax[tid] : 0.0f; + for (int o = 16; o > 0; o >>= 1) + amax = fmaxf(amax, __shfl_down_sync(0xffffffff, amax, o)); + } + amax = __shfl_sync(0xffffffff, amax, 0); + float scale = amax / E4M3_MAX; + if (scale < 1e-8f) scale = 1e-8f; + if (tid == 0) sQ_scale[h] = scale; + // Don't write Q to SMEM yet — we'll do it per-MMA K-slice + } + __syncthreads(); + + // ---- TMEM alloc ---- + if (is_mma_warp) tmem_alloc((uint32_t)__cvta_generic_to_shared(sTmemBase), TMEM_COLS); + asm volatile("fence.proxy.async.shared::cta;" ::: "memory"); + __syncthreads(); + uint32_t tb = *sTmemBase; + + // ---- Phase 1: FP8 GEMM — Q × K^T → logits (n_ih, n_comp) ---- + const int n_k_tiles = (n_comp + SK_TILE - 1) / SK_TILE; + const uint32_t idesc_f8 = make_idesc_f8_e4m3(128, 128); + + for (int kv_tile = 0; kv_tile < n_k_tiles; kv_tile++) { + const int kv_start = kv_tile * SK_TILE; + const int kv_len = min(SK_TILE, n_comp - kv_start); + + for (int kt = 0; kt < NKT; kt++) { + // Zero SMEM tiles + for (int i = tid; i < TILE_F8; i += NTHREADS) { sQ8[i] = 0; sK8[i] = 0; } + __syncthreads(); + + // Load Q rows 0..n_ih-1, columns kt*32..kt*32+31 into sQ8 canonical + for (int i = tid; i < n_ih * MMA_K_F8; i += NTHREADS) { + int row = i / MMA_K_F8; + int col = i % MMA_K_F8; + int d = kt * MMA_K_F8 + col; + if (d < ihd) { + float val = bf16_to_f32_ptx(q_bf16[row * ihd + d]); + float inv_scale = 1.0f / sQ_scale[row]; + sQ8[canon_idx_fp8_128x32(row, col)] = fp8_e4m3_from_f32(val * inv_scale); + } + } + + // Load K rows 0..kv_len-1, columns kt*32..kt*32+31 into sK8 canonical + for (int i = tid; i < kv_len * MMA_K_F8; i += NTHREADS) { + int row = i / MMA_K_F8; + int col = i % MMA_K_F8; + int d = kt * MMA_K_F8 + col; + int g_row = kv_start + row; + sK8[canon_idx_fp8_128x32(row, col)] = k_fp8[(int64_t)g_row * ihd + d]; + } + __syncthreads(); + + // MMA + if (is_mma_warp && lane == 0) { + uint64_t dq = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sQ8), 128); + uint64_t dk = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sK8), 128); + umma_ss_f8f6f4(tb, dq, dk, idesc_f8, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + } + __syncthreads(); + } + asm volatile("fence.sc.gpu;" ::: "memory"); + __syncthreads(); + + // ---- Read TMEM results ---- + // We need rows 0..n_ih-1 (64 rows) × SK_TILE (128 columns) from TMEM. + // Using tcgen05.ld.16x256b.x1: lane i reads rows 4i..4i+3 from one column. + // Lanes 0..15 cover rows 0..63. Lanes 16..31 cover rows 64..127 (ignored). + // + // Process logits on-the-fly: dequant, ReLU, weighted sum, top-k update. + // No SMEM staging of the full logits matrix needed. + // + // Parallel read: warps 0-3 each read 32 columns (128/4=32), processing + // independently. Each warp computes the weighted ReLU sum for its columns + // and updates per-thread local top-k. + + const int COLS_PER_WARP = SK_TILE / 4; // 32 + int my_warp = wid; + if (my_warp < 4) { + int col_start = my_warp * COLS_PER_WARP; + int col_end = col_start + COLS_PER_WARP; + + for (int c = col_start; c < col_end; c++) { + if (c >= kv_len) break; + + // Read column c from TMEM + uint32_t r0, r1, r2, r3; + asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];" + : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) : "r"(tb + c)); + asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory"); + + float f0, f1, f2, f3; + memcpy(&f0, &r0, 4); memcpy(&f1, &r1, 4); + memcpy(&f2, &r2, 4); memcpy(&f3, &r3, 4); + + // Lane i processes rows 4i..4i+3 for this column + if (lane < (n_ih + 3) / 4) { + float vals[4] = {f0, f1, f2, f3}; + float k_s = k_scale[kv_start + c]; + + float weighted_relu_sum = 0.0f; + for (int j = 0; j < 4; j++) { + int h = lane * 4 + j; + if (h < n_ih) { + float logit = vals[j] * sQ_scale[h] * k_s; + if (logit > 0.0f) { + weighted_relu_sum += sW_h[h] * logit; + } + } + } + // Sum across lanes 0..15 within this warp + if (lane >= 16) weighted_relu_sum = 0.0f; + for (int o = 16; o > 0; o >>= 1) + weighted_relu_sum += __shfl_down_sync(0xffffffff, weighted_relu_sum, o); + if (lane == 0 && weighted_relu_sum > 0.0f) { + int c_global = kv_start + c; + local_heap_insert(local_scores, local_blocks, weighted_relu_sum, c_global, INDEXER_LOCAL_K); + } + } + } + } + __syncthreads(); + } + + // ---- TMEM dealloc ---- + if (is_mma_warp) tmem_dealloc(tb, TMEM_COLS); + __syncthreads(); + + // ---- Phase 2: Block-level top-k merge ---- + // Each thread writes its INDEXER_LOCAL_K candidates to SMEM, then + // one thread builds the final top-k. + + for (int i = tid; i < top_k; i += NTHREADS) { + sMergeScores[i] = -INFINITY; + sMergeBlocks[i] = -1; + } + int my_offset = tid * INDEXER_LOCAL_K; + for (int i = 0; i < INDEXER_LOCAL_K; i++) { + sCandScores[my_offset + i] = local_scores[i]; + sCandBlocks[my_offset + i] = local_blocks[i]; + } + __syncthreads(); + + if (tid == 0) { + for (int i = 0; i < NTHREADS * INDEXER_LOCAL_K; i++) { + if (sCandScores[i] > -INFINITY) { + heap_insert_shared(sMergeScores, sMergeBlocks, + sCandScores[i], sCandBlocks[i], top_k); + } + } + } + __syncthreads(); + + // ---- Write top-k indices sorted by score ---- + if (tid == 0) { + for (int i = 0; i < top_k; i++) { + int best = i; + for (int j = i + 1; j < top_k; j++) { + if (sMergeScores[j] > sMergeScores[best]) best = j; + } + if (best != i) { + float ts = sMergeScores[i]; int32_t ti = sMergeBlocks[i]; + sMergeScores[i] = sMergeScores[best]; sMergeBlocks[i] = sMergeBlocks[best]; + sMergeScores[best] = ts; sMergeBlocks[best] = ti; + } + topk_indices[i] = sMergeBlocks[i]; + } + } +} + +// =========================================================================== +// PyTorch binding +// =========================================================================== + +void indexer_fp8_score_topk_cuda( + torch::Tensor q_bf16, // (n_ih, ihd) BF16 + torch::Tensor k_fp8, // (n_comp, ihd) uint8/float8_e4m3fn + torch::Tensor k_scale, // (n_comp,) FP32 + torch::Tensor w_h, // (n_ih,) BF16 + torch::Tensor topk_indices, // (top_k,) int32 output + int64_t n_ih, int64_t ihd, int64_t top_k +) { + TORCH_CHECK(q_bf16.is_cuda() && q_bf16.scalar_type() == torch::kBFloat16); + TORCH_CHECK(k_fp8.is_cuda()); + TORCH_CHECK(k_scale.is_cuda() && k_scale.scalar_type() == torch::kFloat32); + TORCH_CHECK(w_h.is_cuda() && w_h.scalar_type() == torch::kBFloat16); + + int n_comp = k_fp8.size(0); + + // Convert k_fp8 to uint8 view if needed + auto k8 = k_fp8.dtype() == torch::kUInt8 ? k_fp8 : k_fp8.view(torch::kUInt8); + + // SMEM size calculation + size_t smem = 0; + smem += 4; smem = (smem + 127) & ~127; // sTmemBase + smem += 128 * 32; smem = (smem + 127) & ~127; // sQ8 + smem += 128 * 32; smem = (smem + 127) & ~127; // sK8 + smem += 128 * 4; smem = (smem + 127) & ~127; // sQ_scale + smem += n_ih * 4; smem = (smem + 127) & ~127; // sW_h + // sLogits not needed — on-the-fly processing during TMEM read + smem += top_k * 4; // sMergeScores + smem += top_k * 4; // sMergeBlocks + smem += 192 * INDEXER_LOCAL_K * 4; // sCandScores + smem += 192 * INDEXER_LOCAL_K * 4; // sCandBlocks + + cudaFuncSetAttribute(indexer_fp8_score_topk_kernel<128>, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + + indexer_fp8_score_topk_kernel<128><<<1, 192, smem, c10::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(q_bf16.data_ptr()), + k8.data_ptr(), + k_scale.data_ptr(), + reinterpret_cast(w_h.data_ptr()), + topk_indices.data_ptr(), + n_comp, (int)n_ih, (int)ihd, (int)top_k); + + C10_CUDA_CHECK(cudaGetLastError()); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("indexer_fp8_score_topk", &indexer_fp8_score_topk_cuda, + "B2 FP8 tensor-core indexer scoring + weighted ReLU + top-k"); +} diff --git a/single_shot_inference.py b/single_shot_inference.py index b09dd409..36866850 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -406,40 +406,64 @@ class Indexer: self.compressor = Compressor(4, self.ihd, 7168, dev) self.compressor.load(w, pfx, dev) - def forward(self, q_lora, hidden_states, comp_indexer_kv, positions, layer_idx=None): - if self.q_b_lin is None or comp_indexer_kv is None or comp_indexer_kv.shape[0] == 0: + def forward(self, q_lora, hidden_states, kv_cache, positions, layer_idx=None): + """B2 FP8 tensor-core indexer scoring + weighted ReLU + top-k. + + Pipeline: + 1. NVFP4 GEMM: q_a (lora) @ q_b_proj → (T, n_ih * ihd) BF16 + 2. NVFP4 GEMM: hidden_states @ weights_proj → (T, n_ih) BF16 + 3. FP8 GEMM + ReLU + weighted sum + top-k (CUDA kernel) + + Indexer keys are consumed directly in FP8_E4M3 format — no BF16 dequant. + """ + if self.q_b_lin is None or kv_cache is None or not kv_cache._has_idx or kv_cache.n_comp == 0: return None - dev = q_lora.device; T = q_lora.shape[0]; n_comp = comp_indexer_kv.shape[0] - # INDEXER PROBE: print shapes at layer_idx==0 only + dev = q_lora.device; T = q_lora.shape[0] li = layer_idx - if li == 0: - print(f"\n=== INDEXER PROBE L0 ===", flush=True) - print(f" q_lora: shape={tuple(q_lora.shape)} dtype={q_lora.dtype}", flush=True) - print(f" comp_idx_kv: shape={tuple(comp_indexer_kv.shape)} " - f"dtype={comp_indexer_kv.dtype} stride={comp_indexer_kv.stride()} " - f"contig={comp_indexer_kv.is_contiguous()}", flush=True) - print(f" self.n_ih={self.n_ih} self.ihd={self.ihd} n_ih*ihd={self.n_ih * self.ihd}", flush=True) - print(f" self.q_b_lin.in_features={self.q_b_lin.in_features} out_features={self.q_b_lin.out_features}", flush=True) - print(f" self.wp_lin.in_features={self.wp_lin.in_features} out_features={self.wp_lin.out_features}", flush=True) - if self.compressor is not None: - print(f" self.compressor.kv_dim={self.compressor.kv_dim} ratio={self.compressor.ratio} hd={self.compressor.hd}", flush=True) + q_idx = self.q_b_lin(q_lora).reshape(T, self.n_ih, self.ihd) # (T, n_ih, ihd) w_h = self.wp_lin(hidden_states) # (T, n_ih) - # Stored indexer keys are (n_comp, ihd) — one vector per compressed block, - # shared across all indexer heads (paper's c_I = ihd = 128). - # NOT (n_comp, n_ih, ihd) — there is no per-head key decomposition. - k_idx = comp_indexer_kv # (n_comp, ihd) + + # B2: FP8 tensor-core scoring path. + # Indexer keys are stored as FP8_E4M3 in the KV cache. + # No BF16 dequantization — the CUDA kernel consumes FP8 directly. + k_fp8 = kv_cache.comp_idx_fp8[:kv_cache.n_comp] # (n_comp, ihd) uint8 + k_scale = kv_cache.comp_idx_scale[:kv_cache.n_comp] # (n_comp,) FP32 + n_comp = kv_cache.n_comp + if li == 0: - print(f"--- INDEXER L0 SCORING TENSORS ---", flush=True) + print(f"\n=== INDEXER PROBE L0 (B2 FP8) ===", flush=True) print(f" q_idx: shape={tuple(q_idx.shape)} dtype={q_idx.dtype}", flush=True) - print(f" k_idx: shape={tuple(k_idx.shape)} dtype={k_idx.dtype}", flush=True) + print(f" k_fp8: shape={tuple(k_fp8.shape)} dtype={k_fp8.dtype}", flush=True) + print(f" k_scale: shape={tuple(k_scale.shape)} dtype={k_scale.dtype}", flush=True) print(f" w_h: shape={tuple(w_h.shape)} dtype={w_h.dtype}", flush=True) - # Weighted ReLU MQA scoring (eq. 16): - # score(t, c) = sum_h w_h(t,h) * ReLU(q(t,h) · k(c)) - # k is shared across heads: einsum 'tnd,cd->tnc' (c=n_comp, d=ihd) - scores = torch.einsum('tnd,cd->tnc', q_idx.float(), k_idx.float()) # (T, n_ih, n_comp) + + # For T=1 decode: use the B2 FP8 CUDA kernel + if T == 1 and self.ihd == 128 and self.n_ih == 64: + from dsv4.kernels.cuda.loader import get_cuda_module + mod = get_cuda_module("indexer_fp8_score_topk", ["indexer_fp8_score_topk.cu"], + extra_cuda_cflags=[ + "-gencode=arch=compute_100a,code=sm_100a", + "-O3", "--use_fast_math", "--expt-relaxed-constexpr", + ]) + q_2d = q_idx.squeeze(0).contiguous() # (n_ih, ihd) BF16 + w_1d = w_h.squeeze(0).contiguous() # (n_ih,) BF16 + tk = min(self.top_k, n_comp) + topk_indices = torch.empty(tk, dtype=torch.int32, device=dev) + mod.indexer_fp8_score_topk( + q_2d, k_fp8, k_scale, w_1d, topk_indices, + self.n_ih, self.ihd, tk) + return topk_indices.unsqueeze(0) # (1, top_k) + + # Fallback for T>1 or non-standard dimensions — FP32 einsum + k_idx = k_fp8 # still FP8, need dequant for einsum + if k_idx.dtype == torch.uint8 or str(k_idx.dtype) == 'torch.float8_e4m3fn': + from dsv4.kernels.cuda.loader import get_cuda_module + kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"]) + k_idx = kv_mod.dequant_fp8_e4m3(k_fp8, k_scale) # (n_comp, ihd) BF16 + scores = torch.einsum('tnd,cd->tnc', q_idx.float(), k_idx.float()) scores = F.relu(scores) - total = (scores * w_h.unsqueeze(-1).float()).sum(1) # (T, n_comp) + total = (scores * w_h.unsqueeze(-1).float()).sum(1) tk = min(self.top_k, n_comp); _, idx = total.topk(tk, -1); return idx # ===================================================================== @@ -834,7 +858,7 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, # 4. Indexer top-k (CSA) topk_idx = None if indexer is not None and ratio == 4: - topk_idx = indexer.forward(q_a, x_normed, kv_cache.comp_idx_kv, positions, layer_idx=li) + topk_idx = indexer.forward(q_a, x_normed, kv_cache, positions, layer_idx=li) # 5. Gather KV — B1 storage-native mixed path. # noPE remains FP8_E4M3 + per-row scale; RoPE remains BF16.