Fix B2 indexer: add sLogits scratch buffer to SMEM layout
This commit is contained in:
@@ -176,6 +176,9 @@ indexer_fp8_score_topk_kernel(
|
||||
float* sW_h = (float*)(sbuf + off); off += n_ih * sizeof(float);
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
|
||||
// Scratch buffer for cross-warp score accumulation (SK_TILE floats)
|
||||
float* sLogits = (float*)(sbuf + off); off += SK_TILE * sizeof(float);
|
||||
|
||||
// Merge buffer for top-k: scores (top_k floats) + indices (top_k ints)
|
||||
float* sMergeScores = (float*)(sbuf + off); off += top_k * sizeof(float);
|
||||
int32_t* sMergeBlocks = (int32_t*)(sbuf + off); off += top_k * sizeof(int32_t);
|
||||
@@ -431,7 +434,7 @@ void indexer_fp8_score_topk_cuda(
|
||||
smem += 128 * 32; smem = (smem + 127) & ~127; // sK8
|
||||
smem += 128 * 4; smem = (smem + 127) & ~127; // sQ_scale
|
||||
smem += n_ih * 4; smem = (smem + 127) & ~127; // sW_h
|
||||
// Previously: "sLogits not needed — on-the-fly processing during TMEM read"; superseded — the sLogits scratch buffer is now reserved just below.
|
||||
smem += 128 * 4; // sLogits scratch (SK_TILE=128)
|
||||
smem += top_k * 4; // sMergeScores
|
||||
smem += top_k * 4; // sMergeBlocks
|
||||
smem += 192 * INDEXER_LOCAL_K * 4; // sCandScores
|
||||
|
||||
Reference in New Issue
Block a user