Fix B2 indexer: add sLogits scratch buffer to SMEM layout

This commit is contained in:
2026-06-03 00:59:06 +00:00
parent cc7b17fdaa
commit a75a9843af

View File

@@ -176,6 +176,9 @@ indexer_fp8_score_topk_kernel(
float* sW_h = (float*)(sbuf + off); off += n_ih * sizeof(float);
off = (off + 127) & ~(size_t)127;
// Scratch buffer for cross-warp score accumulation (SK_TILE floats)
float* sLogits = (float*)(sbuf + off); off += SK_TILE * sizeof(float);
// Merge buffer for top-k: scores (top_k floats) + indices (top_k ints)
float* sMergeScores = (float*)(sbuf + off); off += top_k * sizeof(float);
int32_t* sMergeBlocks = (int32_t*)(sbuf + off); off += top_k * sizeof(int32_t);
@@ -431,7 +434,7 @@ void indexer_fp8_score_topk_cuda(
smem += 128 * 32; smem = (smem + 127) & ~127; // sK8
smem += 128 * 4; smem = (smem + 127) & ~127; // sQ_scale
smem += n_ih * 4; smem = (smem + 127) & ~127; // sW_h
// sLogits not needed — on-the-fly processing during TMEM read
smem += 128 * 4; // sLogits scratch (SK_TILE=128)
smem += top_k * 4; // sMergeScores
smem += top_k * 4; // sMergeBlocks
smem += 192 * INDEXER_LOCAL_K * 4; // sCandScores