fix: correct SMEM size for row-major (not swizzled)

This commit is contained in:
2026-05-28 08:44:55 +00:00
parent c64bd7b875
commit 2a765be715

View File

@@ -64,14 +64,8 @@ int main() {
cudaMemcpy(dk, hkb, B*s_k*HD*2, cudaMemcpyHostToDevice);
cudaMemset(ds_out, 0, s_k*4);
// SMEM = 4 (tmem_base) + 128B align + 128*HD*2*2 (sQ + sK, swizzled) + slack
// Swizzled layout may use more space than row-major due to atom padding
// MN_SW128 atom = 1024*8 = 8192 BF16 per atom. For (128, HD) = 128*HD BF16.
// With 1 MN tile and ceil(HD/8) K tiles, total = 8192 * ceil(HD/8) BF16
int atoms_n = (HD + 7) / 8; // number of K-tiles (atom has 8 BF16 in K dim)
int smem_q = 1024 * 8 * atoms_n; // BF16 elements for Q (padded)
int smem_k = 1024 * 8 * atoms_n; // BF16 elements for K (padded)
int smem = 4 + 128 + smem_q * 2 + smem_k * 2 + 1024; // bytes
// SMEM = 4 (tmem_base) + 128B align + 128*HD*2 (sQ) + 128*HD*2 (sK) + slack
int smem = 4 + 128 + 128 * HD * 2 + 128 * HD * 2 + 4096;
dim3 grid(1, H, B);
dim3 block(NTHREADS);