From a75a9843af15e5b6a9e7315ee381de00737df2fc Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 3 Jun 2026 00:59:06 +0000 Subject: [PATCH] Fix B2 indexer: add sLogits scratch buffer to SMEM layout --- dsv4/kernels/cuda/indexer_fp8_score_topk.cu | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dsv4/kernels/cuda/indexer_fp8_score_topk.cu b/dsv4/kernels/cuda/indexer_fp8_score_topk.cu index 7929fe99..e68fb21b 100644 --- a/dsv4/kernels/cuda/indexer_fp8_score_topk.cu +++ b/dsv4/kernels/cuda/indexer_fp8_score_topk.cu @@ -176,6 +176,9 @@ indexer_fp8_score_topk_kernel( float* sW_h = (float*)(sbuf + off); off += n_ih * sizeof(float); off = (off + 127) & ~(size_t)127; + // Scratch buffer for cross-warp score accumulation (SK_TILE floats) + float* sLogits = (float*)(sbuf + off); off += SK_TILE * sizeof(float); + // Merge buffer for top-k: scores (top_k floats) + indices (top_k ints) float* sMergeScores = (float*)(sbuf + off); off += top_k * sizeof(float); int32_t* sMergeBlocks = (int32_t*)(sbuf + off); off += top_k * sizeof(int32_t); @@ -431,7 +434,7 @@ void indexer_fp8_score_topk_cuda( smem += 128 * 32; smem = (smem + 127) & ~127; // sK8 smem += 128 * 4; smem = (smem + 127) & ~127; // sQ_scale smem += n_ih * 4; smem = (smem + 127) & ~127; // sW_h - // sLogits not needed — on-the-fly processing during TMEM read + smem += 128 * 4; // sLogits scratch (SK_TILE=128) smem += top_k * 4; // sMergeScores smem += top_k * 4; // sMergeBlocks smem += 192 * INDEXER_LOCAL_K * 4; // sCandScores