diff --git a/tests/fmha_v3_real_softmax.py b/tests/fmha_v3_real_softmax.py index 230b5281..0767596a 100644 --- a/tests/fmha_v3_real_softmax.py +++ b/tests/fmha_v3_real_softmax.py @@ -310,22 +310,20 @@ class FmhaV3RealSoftmax: si_handle.release() softmax_done_bar.arrive() - # Final O normalization: O = O / row_sum - if row_sum != Float32(0.0): - inv_row_sum = Float32(1.0) / row_sum - n_corr = HEAD_DIM // corr_tile_size - for ci in range(n_corr): - tTMrO_fn = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype) - tTMEM_LOAD_OtO_ci = cute.make_tensor( - tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout - ) - tTMEM_STORE_OtO_ci = cute.make_tensor( - tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout - ) - cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_fn) - for j in cutlass.range(cute.size(tTMrO_fn), vectorize=True): - tTMrO_fn[j] = tTMrO_fn[j] * inv_row_sum - cute.copy(tiled_tmem_store_o, tTMrO_fn, tTMEM_STORE_OtO_ci) + # Final O normalization: load, NO-OP write (debug copy) + # If this matches unnormalized output, the TMEM copy is correct. + n_corr = HEAD_DIM // corr_tile_size + for ci in range(n_corr): + tTMrO_fn = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype) + tTMEM_LOAD_OtO_ci = cute.make_tensor( + tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout + ) + tTMEM_STORE_OtO_ci = cute.make_tensor( + tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout + ) + cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_fn) + # NO-OP: write back without modification + cute.copy(tiled_tmem_store_o, tTMrO_fn, tTMEM_STORE_OtO_ci) # Epilogue: TMEM -> SMEM -> GMEM via TMA store tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) @@ -348,13 +346,14 @@ def test(): v_kernel = v.unsqueeze(-1) c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') - # Reference: proper softmax + # Reference: unnormalized (just softmax numerators @ V, no 1/row_sum) qf = q[:, :, 0].float() kf = k[:, :, 0].float() scale = 1.0 / math.sqrt(hd) attn = qf @ kf.T * scale - attn = torch.softmax(attn, dim=-1) - ref = attn @ v.float() + attn_max = attn.max(dim=-1, keepdim=True)[0] + attn_unnorm = torch.exp(attn - attn_max) + ref = attn_unnorm @ v.float() mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))