From f2592ea0da27801ceb59b4b7b158aa92beafdb60 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 30 May 2026 07:01:42 +0000 Subject: [PATCH] fix: native TMEM columns for hd_chunk (no remapping) --- .../fmha_6warp_tma_multirow_multitile.cuh | 28 ++++--------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh b/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh index 81c6d98d..26bcfd17 100644 --- a/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh +++ b/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh @@ -130,17 +130,6 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params) const int n_sub_end = n_sub_start + N_NSUB_CHUNK; int phase = 0; // reset mbarrier phase per hd_chunk - // DEBUG: skip hd_chunk > 0 to test just chunk 0 - // DEBUG: skip hd_chunk 0 to test just chunk 1 - bool skip_chunk = (hd_chunk == 0); // test chunk 1 only - if (skip_chunk) { - if (my_row_active) { - for (int d = 0; d < HD_CHUNK; d++) - o_head[my_row * HD + hd_chunk_start + d] = f32_to_bf16(0.0f); - } - continue; - } - // Zero accumulator for (int i = tid; i < MAX_ROWS * HD_CHUNK; i += 192) sOacc[i] = 0.0f; for (int i = tid; i < MAX_ROWS; i += 192) { @@ -288,8 +277,8 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params) uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), 128); uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16); // TMEM column offset: (n_sub - n_sub_start) * 16 - int tmem_col = (n_sub - n_sub_start) * 16; - if (tid == 128) umma_ss_f16(tb + tmem_col, dp, dv, idesc_pv, pv_kt > 0); + // PV output at NATIVE TMEM column (not remapped) + if (tid == 128) umma_ss_f16(tb + n_sub * 16, dp, dv, idesc_pv, pv_kt > 0); asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); } __syncthreads(); @@ -303,14 +292,16 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params) float tile_max = my_row_active ? sTileRowMax[my_row] : 0.0f; float tile_sum = my_row_active ? sTileRowSum[my_row] : 0.0f; - // Read O_chunk from TMEM (warp-collective — outside my_row_active guard) + // Read O_chunk from TMEM at NATIVE column offsets (warp-collective) float o_tile_buf[HD_CHUNK]; for (int n = 0; n < TMEM_READS_PER_CHUNK; n++) { float tmp[8]; + // Read from hd_chunk's native TMEM columns + int tmem_read_addr = tb + (n_sub_start * 16) + n * 8; asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) - : "r"(tb + n * 8)); + : "r"(tmem_read_addr)); asm volatile("tcgen05.wait::ld.sync.aligned;"); if (my_row_active) { for (int c = 0; c < 8; c++) { @@ -341,13 +332,6 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params) // ---- Write chunk to GMEM ---- if (my_row_active) { float inv_rs = 1.0f / sRunningSum[my_row]; - // Debug: for first hd_chunk, first row, print some values - if (my_row == 0 && hd_chunk == 0) { - // Print first 8 O values and running stats - printf("hd_chunk=%d rmax=%f rsum=%f O[0..3]=%f %f %f %f\n", - hd_chunk, sRunningMax[0], sRunningSum[0], - sOacc[0] * inv_rs, sOacc[1] * inv_rs, sOacc[2] * inv_rs, sOacc[3] * inv_rs); - } for (int d = 0; d < HD_CHUNK; d++) { o_head[my_row * HD + hd_chunk_start + d] = f32_to_bf16(sOacc[my_row * HD_CHUNK + d] * inv_rs); }