diff --git a/tests/unit/test_fmha_v3_softmax.py b/tests/unit/test_fmha_v3_softmax.py index ef9029bf..69d00f2f 100644 --- a/tests/unit/test_fmha_v3_softmax.py +++ b/tests/unit/test_fmha_v3_softmax.py @@ -373,37 +373,9 @@ class FmhaV3Softmax: row_sum = row_sum + tile_sum - # --- C9: Final normalization via O TMEM rescale --- + # --- C9: SKIPPED for debug (no normalization) --- pv_done_bar.arrive_and_wait() - # Store final row_sum to TMEM vector (per-row, using QK partition) - tTMEM_STORE_VECrS_final = cute.make_rmem_tensor(tTMEM_STORE_VECcS.shape, self.qk_acc_dtype) - tTMEM_STORE_VECrS_final[0] = row_sum - cute.copy(tiled_tmem_store_vec, tTMEM_STORE_VECrS_final, tTMEM_STORE_VECtS) - cute.arch.fence_view_async_tmem_store() - - # Read vector back: per-row row_sum using QK partition coordinates - tTMEM_LOAD_VECrS = cute.make_rmem_tensor(tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype) - cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS, tTMEM_LOAD_VECrS) - cute.arch.fence_view_async_tmem_load() - inv_row_sum = cutlass.Float32(1.0) / tTMEM_LOAD_VECrS[0] - - # Normalize O in TMEM - tTMrO_final = cute.make_rmem_tensor((tTMEM_LOADcO.shape, o_col_tiles), self.qk_acc_dtype) - for i in range(o_col_tiles): - tTMrO_i_ = tTMrO_final[None, i] - tTMrO_i_layout = cute.composition(tTMrO_i_.layout, cute.make_layout(tTMrO_final.shape[0])) - tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) - tTMEM_LOADtO_i = cute.make_tensor( - tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout) - tTMEM_STOREtO_i = cute.make_tensor( - tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout) - cute.copy(o_tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i) - for j in cutlass.range(cute.size(tTMrO_i), vectorize=True): - tTMrO_i[j] = tTMrO_i[j] * inv_row_sum - cute.copy(o_tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i) - cute.arch.fence_view_async_tmem_store() - - # Now O in TMEM is normalized. Use standard epilogue_tma_store with identity. + # O is unnormalized in TMEM. Use standard epilogue_tma_store with identity. tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) @@ -429,7 +401,8 @@ def test(): c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device="cuda") qf = q[:,:,0].float(); kf = k[:,:,0].float() attn = qf @ kf.T / math.sqrt(hd) - ref = torch.softmax(attn, dim=-1) @ v.float() + P = torch.exp(attn - attn.max(dim=-1, keepdim=True)[0]) + ref = P @ v.float() # unnormalized P@V mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) @@ -445,7 +418,7 @@ def test(): out = c[:,:,0].float() cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() max_err = (out - ref).abs().max().item() - print(f"FMHA softmax n={n}: cosine {cos:.6f} max_err {max_err:.6f} {'PASS' if cos >= 0.999 else 'FAIL'}", flush=True) + print(f"FMHA softmax (no C9 norm) n={n}: cosine {cos:.6f} max_err {max_err:.6f} {'PASS' if cos >= 0.999 else 'FAIL'}", flush=True) if __name__ == "__main__": test()