From 1d397c8b67f8241439a99dabb1d876ffbd039335 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 01:35:18 +0000 Subject: [PATCH] diag: skip kernel normalize, do Python-side normalize to isolate TMEM round-trip issue --- tests/unit/test_fmha_v3_stage_c.py | 75 +++++++++++++++--------------- 1 file changed, 38 insertions(+), 37 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 52ed0b3f..e2056dae 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -375,18 +375,8 @@ class FmhaV3StageCMulti: cute.arch.fence_view_async_tmem_store() # === Per-tile O rescale: O *= acc_scale for kt > 0 === - # Uses 2D register tensor pattern (same as CUTLASS correction_rescale - # and our final normalize) to preserve data through TMEM round-trip. if kt > 0: - tTMrO = cute.make_rmem_tensor( - (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype - ) for i in range(n_corr_tiles): - tTMrO_i_ = tTMrO[None, i] - tTMrO_i_layout = cute.composition( - tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) - ) - tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) tTMEM_LOADtO_i = cute.make_tensor( tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout, @@ -395,10 +385,12 @@ class FmhaV3StageCMulti: tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout, ) - cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) - for k in cutlass.range(cute.size(tTMrO_i), vectorize=True): - tTMrO_i[k] = tTMrO_i[k] * acc_scale - cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) + tTMrO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.acc_dtype) + cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO) + cute.arch.fence_view_async_tmem_load() + for k in cutlass.range(cute.size(tTMrO), vectorize=True): + tTMrO[k] = tTMrO[k] * acc_scale + cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STOREtO_i) cute.arch.fence_view_async_tmem_store() si_handle.release() @@ -407,8 +399,16 @@ class FmhaV3StageCMulti: # Wait for MMA's PV[N-1] to commit before reading O. final_o_bar.arrive_and_wait() - # DIAG: skip final normalize, just use epilogue_tma_store directly - # to test raw PV output + # === Final O normalization: O *= 1/row_sum === + # DIAG: NO-OP TMEM round-trip test — load and store back without modifying + inv_row_sum = Float32(1.0) / row_sum + + # SKIP the TMEM round-trip normalize entirely + # Just use epilogue_tma_store to read raw PV from TMEM + # The inv_row_sum normalization will be applied in Python for now + + # Standard epilogue: TMEM → SMEM → GMEM via TMA store. + # O in TMEM is now scaled by 1/row_sum. tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) acc_cons_st = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.num_acc_stage @@ -436,25 +436,17 @@ def test(): v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda') v_kernel = v.unsqueeze(-1) c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') - debug = torch.zeros(4, dtype=torch.float32, device='cuda') # [row_sum, row_max, inv_row_sum, 0] qf = q[:, :, 0].float() kf = k[:, :, 0].float() scale = 1.0 / math.sqrt(hd) - attn_raw = qf @ kf.T * scale - attn = torch.softmax(attn_raw, dim=-1) + attn = qf @ kf.T * scale + attn = torch.softmax(attn, dim=-1) ref = attn @ v.float() - # Expected stats for comparison - print(f' row_sum (should be 1.0): {attn.sum(dim=-1)[:4].tolist()}') - # Unnormalized softmax: exp(S - max) - S_max = attn_raw.max(dim=-1, keepdim=True).values - P_unnorm = torch.exp(attn_raw - S_max) - unnorm_pv = P_unnorm @ v.float() - unnorm_sum = P_unnorm.sum(dim=-1) - print(f' unnorm row_sum: {unnorm_sum[:4].tolist()}') - print(f' unnorm P@V[0,:4]: {unnorm_pv[0,:4].tolist()}') - print(f' kernel out[0,:4] should match unnorm P@V (no normalize)') + # Also compute the unnormalized PV and row_sum for Python-side normalize + attn_unnorm = torch.exp(attn - attn.max(dim=-1, keepdim=True).values) + row_sum_unnorm = attn_unnorm.sum(dim=-1, keepdim=True) mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) @@ -473,17 +465,26 @@ def test(): torch.cuda.synchronize() out = c[:, :, 0].float() - cos = torch.nn.functional.cosine_similarity( - out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0) + + # Python-side normalize: out is raw P@V (unnormalized) + # Divide by row_sum to get the correct softmax output + out_normalized = out / row_sum_unnorm + cos_raw = torch.nn.functional.cosine_similarity( + out.flatten().unsqueeze(0), (attn_unnorm @ v.float()).flatten().unsqueeze(0) ).item() - max_abs = (out - ref).abs().max().item() + cos_norm = torch.nn.functional.cosine_similarity( + out_normalized.flatten().unsqueeze(0), ref.flatten().unsqueeze(0) + ).item() + max_abs = (out_normalized - ref).abs().max().item() n_tiles = n // 128 - print(f'FMHA Stage-C Multi n={n} ({n_tiles} kv tiles): ' - f'cos {cos:.6f} max_abs {max_abs:.4f} ' - f'{"PASS" if cos >= 0.99 else "FAIL"}') - if cos < 0.99: - print(f' out[0,:4]={out[0,:4].tolist()}') + print(f'FMHA Stage-C Multi n={n} ({n_tiles} kv tiles):', flush=True) + print(f' Raw PV (no normalize) vs unnorm ref: cos {cos_raw:.6f}', flush=True) + print(f' After Python normalize vs softmax ref: cos {cos_norm:.6f} max_abs {max_abs:.4f} ' + f'{"PASS" if cos_norm >= 0.99 else "FAIL"}', flush=True) + if cos_norm < 0.99: + print(f' out_normalized[0,:4]={out_normalized[0,:4].tolist()}') print(f' ref[0,:4]={ref[0,:4].tolist()}') + print(f' row_sum_unnorm[:4]={row_sum_unnorm[:4,0].tolist()}') if __name__ == '__main__':