fix: correct SMEM size calculation in multirow test

This commit is contained in:
2026-05-28 22:53:46 +00:00
parent 863a030c3b
commit 4cfb707405

View File

@@ -120,9 +120,21 @@ static int test_single_T(int T, int n_h = 1, int batch = 1) {
params.lse_head_stride = T;
params.lse_batch_stride = n_h * T;
// SMEM: sTmemBase(8) + sRowMax(512) + sRowSum(512) + padding(128) + sQ0(4096) + sK0(4096) + sPk(4096) + sV(512) + slack(256)
// NO s_p_vals buffer — P is streamed per K-tile through sPk
int smem = 8 + MAX_T*4 + MAX_T*4 + 128 + TILE_SZ*2 + TILE_SZ*2 + V_SUB_SZ*2 + 256;
// SMEM: must match kernel's layout exactly.
// sTmemBase(8) + sRowMax(512) + sRowSum(512) + align128 + sQ0(4096) + sK0(4096) + align128 + sPk(4096) + align128 + sV(512) + slack
// Pessimistic estimate: accumulate each buffer's size in order, rounding the running offset up to 128 bytes before sQ0, sPk, and sV.
size_t smem_off = 0;
smem_off += 8; // sTmemBase + alignment
smem_off += 128 * sizeof(float) * 2; // sRowMax + sRowSum
smem_off = (smem_off + 127) & ~(size_t)127; // align for sQ0
smem_off += 128 * MMA_K_BF16 * sizeof(bf16_t); // sQ0
smem_off += 128 * MMA_K_BF16 * sizeof(bf16_t); // sK0
smem_off = (smem_off + 127) & ~(size_t)127; // align for sPk
smem_off += 128 * MMA_K_BF16 * sizeof(bf16_t); // sPk
smem_off = (smem_off + 127) & ~(size_t)127; // align for sV
smem_off += 16 * MMA_K_BF16 * sizeof(bf16_t); // sV
smem_off += 256; // slack
int smem = (int)smem_off;
smem = (smem + 127) & ~127;
if (smem > 48 * 1024) {