fix: native TMEM columns for hd_chunk (no remapping)
This commit is contained in:
@@ -130,17 +130,6 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params)
|
||||
const int n_sub_end = n_sub_start + N_NSUB_CHUNK;
|
||||
int phase = 0; // reset mbarrier phase per hd_chunk
|
||||
|
||||
// DEBUG: skip hd_chunk > 0 to test just chunk 0
|
||||
// DEBUG: skip hd_chunk 0 to test just chunk 1
|
||||
bool skip_chunk = (hd_chunk == 0); // test chunk 1 only
|
||||
if (skip_chunk) {
|
||||
if (my_row_active) {
|
||||
for (int d = 0; d < HD_CHUNK; d++)
|
||||
o_head[my_row * HD + hd_chunk_start + d] = f32_to_bf16(0.0f);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Zero accumulator
|
||||
for (int i = tid; i < MAX_ROWS * HD_CHUNK; i += 192) sOacc[i] = 0.0f;
|
||||
for (int i = tid; i < MAX_ROWS; i += 192) {
|
||||
@@ -288,8 +277,8 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params)
|
||||
uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), 128);
|
||||
uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
|
||||
// TMEM column offset: (n_sub - n_sub_start) * 16
|
||||
int tmem_col = (n_sub - n_sub_start) * 16;
|
||||
if (tid == 128) umma_ss_f16(tb + tmem_col, dp, dv, idesc_pv, pv_kt > 0);
|
||||
// PV output goes to the native TMEM column offset n_sub * 16 (not rebased to n_sub_start)
|
||||
if (tid == 128) umma_ss_f16(tb + n_sub * 16, dp, dv, idesc_pv, pv_kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
@@ -303,14 +292,16 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params)
|
||||
float tile_max = my_row_active ? sTileRowMax[my_row] : 0.0f;
|
||||
float tile_sum = my_row_active ? sTileRowSum[my_row] : 0.0f;
|
||||
|
||||
// Read O_chunk from TMEM (warp-collective — outside my_row_active guard)
|
||||
// Read O_chunk from TMEM at NATIVE column offsets (warp-collective)
|
||||
float o_tile_buf[HD_CHUNK];
|
||||
for (int n = 0; n < TMEM_READS_PER_CHUNK; n++) {
|
||||
float tmp[8];
|
||||
// Read from hd_chunk's native TMEM columns: start at offset n_sub_start * 16, advance by 8 per read
|
||||
int tmem_read_addr = tb + (n_sub_start * 16) + n * 8;
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + n * 8));
|
||||
: "r"(tmem_read_addr));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (my_row_active) {
|
||||
for (int c = 0; c < 8; c++) {
|
||||
@@ -341,13 +332,6 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params)
|
||||
// ---- Write chunk to GMEM ----
|
||||
if (my_row_active) {
|
||||
float inv_rs = 1.0f / sRunningSum[my_row];
|
||||
// Debug: for first hd_chunk, first row, print some values
|
||||
if (my_row == 0 && hd_chunk == 0) {
|
||||
// Print first 8 O values and running stats
|
||||
printf("hd_chunk=%d rmax=%f rsum=%f O[0..3]=%f %f %f %f\n",
|
||||
hd_chunk, sRunningMax[0], sRunningSum[0],
|
||||
sOacc[0] * inv_rs, sOacc[1] * inv_rs, sOacc[2] * inv_rs, sOacc[3] * inv_rs);
|
||||
}
|
||||
for (int d = 0; d < HD_CHUNK; d++) {
|
||||
o_head[my_row * HD + hd_chunk_start + d] = f32_to_bf16(sOacc[my_row * HD_CHUNK + d] * inv_rs);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user