fix: SMEM calc in decode test

This commit is contained in:
2026-05-28 23:08:54 +00:00
parent 58ff781388
commit f2124b9378

View File

@@ -197,6 +197,8 @@ int main() {
cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice);
cudaMemcpy(d_v, h_v, HD*SK*sizeof(bf16_t), cudaMemcpyHostToDevice);
constexpr int TILE_SZ = 128 * MMA_K_BF16;
constexpr int V_SUB_SZ = 16 * MMA_K_BF16;
size_t smem_off = 8 + 128*4 + 128*4;
smem_off = ((smem_off + 127) & ~(size_t)127);
smem_off += TILE_SZ + TILE_SZ;