fix: correct SMEM size for row-major (not swizzled)
This commit is contained in:
@@ -64,14 +64,8 @@ int main() {
|
||||
cudaMemcpy(dk, hkb, B*s_k*HD*2, cudaMemcpyHostToDevice);
|
||||
cudaMemset(ds_out, 0, s_k*4);
|
||||
|
||||
// SMEM = 4 (tmem_base) + 128B align + 128*HD*2*2 (sQ + sK, swizzled) + slack
|
||||
// Swizzled layout may use more space than row-major due to atom padding
|
||||
// MN_SW128 atom = 1024*8 = 8192 BF16 per atom. For (128, HD) = 128*HD BF16.
|
||||
// With 1 MN tile and ceil(HD/8) K tiles, total = 8192 * ceil(HD/8) BF16
|
||||
int atoms_n = (HD + 7) / 8; // number of K-tiles (atom has 8 BF16 in K dim)
|
||||
int smem_q = 1024 * 8 * atoms_n; // BF16 elements for Q (padded)
|
||||
int smem_k = 1024 * 8 * atoms_n; // BF16 elements for K (padded)
|
||||
int smem = 4 + 128 + smem_q * 2 + smem_k * 2 + 1024; // bytes
|
||||
// SMEM = 4 (tmem_base) + 128B align + 128*HD*2 (sQ) + 128*HD*2 (sK) + slack
|
||||
int smem = 4 + 128 + 128 * HD * 2 + 128 * HD * 2 + 4096;
|
||||
|
||||
dim3 grid(1, H, B);
|
||||
dim3 block(NTHREADS);
|
||||
|
||||
Reference in New Issue
Block a user