fix: correct SMEM size calculation in multirow test
This commit is contained in:
@@ -120,9 +120,21 @@ static int test_single_T(int T, int n_h = 1, int batch = 1) {
|
||||
params.lse_head_stride = T;
|
||||
params.lse_batch_stride = n_h * T;
|
||||
|
||||
// SMEM: sTmemBase(8) + sRowMax(512) + sRowSum(512) + padding(128) + sQ0(4096) + sK0(4096) + sPk(4096) + sV(512) + slack(256)
|
||||
// NO s_p_vals buffer — P is streamed per K-tile through sPk
|
||||
int smem = 8 + MAX_T*4 + MAX_T*4 + 128 + TILE_SZ*2 + TILE_SZ*2 + V_SUB_SZ*2 + 256;
|
||||
// SMEM: must match kernel's layout exactly.
|
||||
// sTmemBase(8) + sRowMax(512) + sRowSum(512) + align128 + sQ0(4096) + sK0(4096) + align128 + sPk(4096) + align128 + sV(512) + slack
|
||||
// Pessimistic: accumulate each buffer's size in order, rounding the running offset up to 128 bytes before sQ0, sPk and sV
|
||||
size_t smem_off = 0;
|
||||
smem_off += 8; // sTmemBase + alignment
|
||||
smem_off += 128 * sizeof(float) * 2; // sRowMax + sRowSum
|
||||
smem_off = (smem_off + 127) & ~(size_t)127; // align for sQ0
|
||||
smem_off += 128 * MMA_K_BF16 * sizeof(bf16_t); // sQ0
|
||||
smem_off += 128 * MMA_K_BF16 * sizeof(bf16_t); // sK0
|
||||
smem_off = (smem_off + 127) & ~(size_t)127; // align for sPk
|
||||
smem_off += 128 * MMA_K_BF16 * sizeof(bf16_t); // sPk
|
||||
smem_off = (smem_off + 127) & ~(size_t)127; // align for sV
|
||||
smem_off += 16 * MMA_K_BF16 * sizeof(bf16_t); // sV
|
||||
smem_off += 256; // slack
|
||||
int smem = (int)smem_off;
|
||||
smem = (smem + 127) & ~127;
|
||||
|
||||
if (smem > 48 * 1024) {
|
||||
|
||||
Reference in New Issue
Block a user