diff --git a/dsv4/kernels/cuda/indexer_score_topk.cu b/dsv4/kernels/cuda/indexer_score_topk.cu index cafb1a2c..77a6947f 100644 --- a/dsv4/kernels/cuda/indexer_score_topk.cu +++ b/dsv4/kernels/cuda/indexer_score_topk.cu @@ -1,77 +1,34 @@ -// indexer_score_topk.cu — Fused score + ReLU + weighted-sum + top-k kernel. -// -// CSA Lightning Indexer (paper §2.3.1, eq. 16): -// I[t,s] = Σ_h w_h[t,h] · ReLU(q_I[t,h] · K^IComp[s,h]) -// Selected = TopK(I[t,:], k=csa_top_k) -// -// One CTA per query token. Streams indexer keys from the paged pool, -// computes per-head dot products in FP32, ReLU, weighted sum, heap top-k. -// -// Phase 1 (this file): FP32 dot products via standard CUDA ops. -// Phase 2 (future): swap to FP4 tcgen05 MMA for production throughput. -// The FP32 path is correct and used for testing; the FP4 path is the -// performance optimization on a known-correct base. -// -// Indexer keys are stored in the paged pool as FP4 (NVFP4 scheme). -// This kernel dequantizes them to FP32 before the dot product. -// The FP4 tcgen05 version will avoid this dequant and do FP4 MMA directly. - #include -#include #include +#include #include #include - #include -// ---- FP4 dequantization (NVFP4 scheme) ---- -// FP4 E2M1: values 0-6 in 3 bits (7 = NaN/unused), 1 sign bit. -// Scale is per-16-element group, stored as FP8 E4M3. -// Global scale is FP32 per vector. -// Dequant: val = (fp4_int) * group_scale * global_scale - __device__ __forceinline__ float dequant_fp4_scalar( - uint8_t packed, int lane, // lane 0 = low nibble, lane 1 = high nibble - float group_scale, float global_scale + uint8_t packed, int lane, float group_scale, float global_scale ) { int nibble = (lane == 0) ? (packed & 0x0F) : (packed >> 4); - // FP4 E2M1: bit3=sign, bits[2:0]=magnitude (0-6) int sign = (nibble >> 3) & 1; int mag = nibble & 0x07; float val = (float)mag * group_scale * global_scale; return sign ? -val : val; } -// ---- Min-heap for top-k ---- -// Heap of (score, block_id) pairs. Root = smallest score. -// Insert: if new score > root, replace root and sift down. -// After all inserts, the heap contains the top-k entries. - -__device__ __forceinline__ void heap_insert( - float* __restrict__ heap_scores, - int32_t* __restrict__ heap_blocks, - float score, int32_t block_id, - int k +__device__ void heap_insert( + float* heap_scores, int32_t* heap_blocks, + float score, int32_t block_id, int k ) { - if (score <= heap_scores[0]) return; // doesn't beat min + if (score <= heap_scores[0]) return; heap_scores[0] = score; heap_blocks[0] = block_id; - // Sift down int root = 0; while (root < (k >> 1)) { int left = 2 * root + 1; int right = 2 * root + 2; int smallest = root; - if (left < k && (heap_scores[left] < heap_scores[smallest] || - (heap_scores[left] == heap_scores[smallest] && - heap_blocks[left] > heap_blocks[smallest]))) { - smallest = left; - } - if (right < k && (heap_scores[right] < heap_scores[smallest] || - (heap_scores[right] == heap_scores[smallest] && - heap_blocks[right] > heap_blocks[smallest]))) { - smallest = right; - } + if (left < k && heap_scores[left] < heap_scores[smallest]) smallest = left; + if (right < k && heap_scores[right] < heap_scores[smallest]) smallest = right; if (smallest == root) break; float ts = heap_scores[root]; int32_t ti = heap_blocks[root]; heap_scores[root] = heap_scores[smallest]; heap_blocks[root] = heap_blocks[smallest]; @@ -80,204 +37,125 @@ __device__ __forceinline__ void heap_insert( } } -// =========================================================================== -// Main kernel -// =========================================================================== - -__global__ void indexer_score_topk_fp32_kernel( - // Query inputs (FP32 — dequantized from FP4 in the launcher or here) - const float* __restrict__ q_I, // [T, n_heads, head_dim] FP32 - const float* __restrict__ w_h, // [T, n_heads] FP32 - // Indexer keys from cache (FP4 packed) - const uint8_t* __restrict__ keys_fp4, // [num_phys_blocks, epb, hd/2] - const uint8_t* __restrict__ key_scale, // [num_phys_blocks, epb, hd/16] FP8 E4M3 - const float* __restrict__ key_gscale, // [num_phys_blocks] FP32 - // Block table - const int32_t* __restrict__ block_table, // [T, max_logical_blocks] - const int32_t* __restrict__ valid_lens, // [T] int32 — total valid entries per query - // Output - int32_t* __restrict__ topk_indices, // [T, top_k] int32 - // Geometry +__global__ void indexer_score_topk_kernel( + const float* __restrict__ q_I, + const float* __restrict__ w_h, + const uint8_t* __restrict__ keys_fp4, + const uint8_t* __restrict__ key_scale, + const float* __restrict__ key_gscale, + const int32_t* __restrict__ block_table, + const int32_t* __restrict__ valid_lens, + int32_t* __restrict__ topk_indices, int n_heads, int head_dim, int top_k, int entries_per_block, int max_logical_blocks ) { - int t = blockIdx.x; // one CTA per query token + int t = blockIdx.x; if (t >= gridDim.x) return; - int tid = threadIdx.x; int n_threads = blockDim.x; int num_valid = valid_lens[t]; - int n_groups = head_dim / 16; // FP4 group count per entry - int n_bytes = head_dim / 2; // FP4 packed bytes per entry + int n_groups = head_dim / 16; + int total_groups = n_heads * n_groups; + int n_bytes = head_dim / 2; + int total_bytes = n_heads * n_bytes; - // ---- Load w_h[t, :] into shared memory ---- - // Layout: [w_h (n_h floats)] [heap_lock (1 int)] [heap_scores (top_k floats)] [heap_blocks (top_k ints)] - extern __shared__ char smem[]; - float* smem_w = reinterpret_cast(smem); - int* smem_heap_lock = reinterpret_cast(smem_w + n_heads); - float* smem_heap_scores = reinterpret_cast(smem_heap_lock + 1); - int32_t* smem_heap_blocks = reinterpret_cast(smem_heap_scores + top_k); + // Per-thread heap in REGISTERS (top_k <= 1024, but for small k this works) + // Actually, use shared memory with a simple layout + __shared__ float s_heap_scores[1024]; // max top_k + __shared__ int32_t s_heap_blocks[1024]; + __shared__ float s_w[64]; // max n_heads + __shared__ int s_lock; // Load w_h for (int h = tid; h < n_heads; h += n_threads) { - smem_w[h] = w_h[t * n_heads + h]; + s_w[h] = w_h[t * n_heads + h]; } - - // Init heap to -inf + // Init heap for (int i = tid; i < top_k; i += n_threads) { - smem_heap_scores[i] = -INFINITY; - smem_heap_blocks[i] = -1; + s_heap_scores[i] = -INFINITY; + s_heap_blocks[i] = -1; } + if (tid == 0) s_lock = 0; __syncthreads(); - // ---- Stream over all valid compressed entries ---- - // Each entry is a candidate block s. - // I[t,s] = Σ_h w_h[h] * ReLU( ) - // - // We parallelize over entries: each thread handles a subset of entries, - // computes the full score, then inserts into the shared heap. - // For S=250K and 128 threads, each thread handles ~2K entries. - + // Stream over entries for (int s = tid; s < num_valid; s += n_threads) { - // Resolve physical location of entry s int logical_block = s / entries_per_block; int slot_in_block = s % entries_per_block; int phys_block = block_table[t * max_logical_blocks + logical_block]; - int block_entry_flat = phys_block * entries_per_block + slot_in_block; + int flat = phys_block * entries_per_block + slot_in_block; - float global_s = key_gscale[phys_block]; + float gs = key_gscale[phys_block]; - // Compute score = Σ_h w_h[h] * ReLU( ) + // Compute score float score = 0.0f; - for (int h = 0; h < n_heads; h++) { float dot = 0.0f; - // Dequantize FP4 key and compute dot product with q_I + int h_byte_off = h * n_bytes; + int h_group_off = h * n_groups; for (int g = 0; g < n_groups; g++) { - // Read group scale (FP8 E4M3) - uint8_t raw_scale = key_scale[block_entry_flat * n_groups + g]; + uint8_t raw_sc = key_scale[flat * total_groups + h_group_off + g]; __nv_fp8_e4m3 fp8_s; - fp8_s.__x = raw_scale; - float group_s = (float)fp8_s * global_s; + fp8_s.__x = raw_sc; + float grp_s = (float)fp8_s * gs; - // Read 8 packed bytes = 16 FP4 values for (int b = 0; b < 8; b++) { - uint8_t packed = keys_fp4[block_entry_flat * n_bytes + g * 8 + b]; - float v0 = dequant_fp4_scalar(packed, 0, group_s, 1.0f); - float v1 = dequant_fp4_scalar(packed, 1, group_s, 1.0f); - // q_I values (FP32, already dequantized) + uint8_t packed = keys_fp4[flat * total_bytes + h_byte_off + g * 8 + b]; + float v0 = dequant_fp4_scalar(packed, 0, grp_s, 1.0f); + float v1 = dequant_fp4_scalar(packed, 1, grp_s, 1.0f); int d0 = g * 16 + 2 * b; int d1 = d0 + 1; dot += v0 * q_I[t * n_heads * head_dim + h * head_dim + d0]; dot += v1 * q_I[t * n_heads * head_dim + h * head_dim + d1]; } } - // ReLU + weighted sum if (dot > 0.0f) { - score += smem_w[h] * dot; + score += s_w[h] * dot; } } - // Insert into heap - // Must be serialized — use a critical section per CTA. - // For correctness, one thread at a time inserts. - // This is the simple approach; a lock-free heap is an optimization. - if (score > -INFINITY) { - // Use a simple spin-lock approach: thread 0 does all inserts. - // Each thread writes its (score, s) to a staging area. - // Then thread 0 iterates through the staging area. - // For now, just serialize via atomicMax on a flag. - // Actually, since each thread has its own set of entries (strided), - // and the heap is shared, we need mutual exclusion. - // Simplest: one thread handles all its entries, then next thread. - // We do this by having each thread wait for its turn. - // For now: all threads write to a SMEM buffer, then one thread - // processes the buffer. - - // Write to a shared staging buffer (one per thread, fixed size) - // Actually, the simplest correct approach: each thread maintains - // its own top-k in registers, then we merge at the end. - // But register top-k for k=1024 is too large. - // - // Practical approach: use atomicCAS on a SMEM lock. - // Only one thread inserts at a time. - // Use heap_lock in the extern SMEM - if (tid == 0) smem_heap_lock[0] = 0; - __syncthreads(); - - while (atomicCAS(smem_heap_lock, 0, 1) != 0) {} // acquire - heap_insert(smem_heap_scores, smem_heap_blocks, score, s, top_k); - atomicExch(smem_heap_lock, 0); // release - } + // Insert into shared heap (serialized via spinlock) + while (atomicCAS(&s_lock, 0, 1) != 0) {} + heap_insert(s_heap_scores, s_heap_blocks, score, s, top_k); + atomicExch(&s_lock, 0); } - __syncthreads(); - // ---- Write top-k indices to global memory ---- - // Sort heap by score descending for deterministic output. - // Simple selection sort on the small heap (top_k <= 1024). + // Sort + write output if (tid == 0) { for (int i = 0; i < top_k; i++) { - // Find max among remaining int best = i; for (int j = i + 1; j < top_k; j++) { - if (smem_heap_scores[j] > smem_heap_scores[best] || - (smem_heap_scores[j] == smem_heap_scores[best] && - smem_heap_blocks[j] < smem_heap_blocks[best])) { - best = j; - } + if (s_heap_scores[j] > s_heap_scores[best]) best = j; } if (best != i) { - float ts = smem_heap_scores[i]; int32_t ti = smem_heap_blocks[i]; - smem_heap_scores[i] = smem_heap_scores[best]; smem_heap_blocks[i] = smem_heap_blocks[best]; - smem_heap_scores[best] = ts; smem_heap_blocks[best] = ti; + float ts = s_heap_scores[i]; int32_t ti = s_heap_blocks[i]; + s_heap_scores[i] = s_heap_scores[best]; s_heap_blocks[i] = s_heap_blocks[best]; + s_heap_scores[best] = ts; s_heap_blocks[best] = ti; } - topk_indices[t * top_k + i] = smem_heap_blocks[i]; + topk_indices[t * top_k + i] = s_heap_blocks[i]; } } } - -// =========================================================================== -// PyTorch binding -// =========================================================================== - -void indexer_score_topk_fp32_cuda( - torch::Tensor q_I, // [T, n_heads, head_dim] FP32 - torch::Tensor w_h, // [T, n_heads] FP32 - torch::Tensor keys_fp4, // [num_blocks, epb, hd/2] uint8 - torch::Tensor key_scale, // [num_blocks, epb, hd/16] uint8 (FP8 E4M3) - torch::Tensor key_gscale, // [num_blocks] FP32 - torch::Tensor block_table, // [T, max_logical_blocks] int32 - torch::Tensor valid_lens, // [T] int32 - torch::Tensor topk_indices, // [T, top_k] int32 (output) - int64_t n_heads, int64_t head_dim, int64_t top_k, - int64_t entries_per_block +void indexer_score_topk_cuda( + torch::Tensor q_I, torch::Tensor w_h, + torch::Tensor keys_fp4, torch::Tensor key_scale, torch::Tensor key_gscale, + torch::Tensor block_table, torch::Tensor valid_lens, torch::Tensor topk_indices, + int64_t n_heads, int64_t head_dim, int64_t top_k, int64_t entries_per_block ) { int T = q_I.size(0); int max_logical_blocks = block_table.size(1); - int threads = 128; - - // SMEM: w_h + heap_lock + heap_scores + heap_blocks - int smem_bytes = (n_heads + 1 + top_k) * sizeof(float) + top_k * sizeof(int32_t); - - indexer_score_topk_fp32_kernel<<>>( - q_I.data_ptr(), - w_h.data_ptr(), - keys_fp4.data_ptr(), - key_scale.data_ptr(), - key_gscale.data_ptr(), - block_table.data_ptr(), - valid_lens.data_ptr(), - topk_indices.data_ptr(), - (int)n_heads, (int)head_dim, (int)top_k, - (int)entries_per_block, max_logical_blocks + indexer_score_topk_kernel<<>>( + q_I.data_ptr(), w_h.data_ptr(), + keys_fp4.data_ptr(), key_scale.data_ptr(), + key_gscale.data_ptr(), block_table.data_ptr(), + valid_lens.data_ptr(), topk_indices.data_ptr(), + (int)n_heads, (int)head_dim, (int)top_k, (int)entries_per_block, max_logical_blocks ); C10_CUDA_CHECK(cudaGetLastError()); } - PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("indexer_score_topk_fp32", &indexer_score_topk_fp32_cuda, - "Indexer score + top-k (FP32 dot products)"); + m.def("indexer_score_topk", &indexer_score_topk_cuda); }