diff --git a/tests/fmha_v3_real_softmax.py b/tests/fmha_v3_real_softmax.py index 87459380..95056662 100644 --- a/tests/fmha_v3_real_softmax.py +++ b/tests/fmha_v3_real_softmax.py @@ -276,7 +276,29 @@ class FmhaV3RealSoftmax: acc_scale = Float32(0.0) row_sum *= acc_scale - # TODO: O rescale in TMEM (skip for now, test softmax + P only) + # O rescale in TMEM: multiply existing O by acc_scale = exp2(old_max - new_max) + # Only for kt > 0 (first tile: no existing O to rescale) + if kt > 0: + n_corr = 128 // corr_tile_size + tTMrO_rs = cute.make_rmem_tensor( + (tTMEM_LOAD_OcO.shape, n_corr), self.acc_dtype + ) + for ci in range(n_corr): + tTMrO_ci_ = tTMrO_rs[None, ci] + tTMrO_ci_layout = cute.composition( + tTMrO_ci_.layout, cute.make_layout(tTMrO_rs.shape[0]) + ) + tTMrO_ci = cute.make_tensor(tTMrO_ci_.iterator, tTMrO_ci_layout) + tTMEM_LOAD_OtO_ci = cute.make_tensor( + tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout + ) + tTMEM_STORE_OtO_ci = cute.make_tensor( + tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout + ) + cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_ci) + for j in cutlass.range(cute.size(tTMrO_ci), vectorize=True): + tTMrO_ci[j] = tTMrO_ci[j] * acc_scale + cute.copy(tiled_tmem_store_o, tTMrO_ci, tTMEM_STORE_OtO_ci) # Pass 2: P = exp2(S * scale_log2 - row_max), accumulate row_sum rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype) @@ -296,9 +318,29 @@ class FmhaV3RealSoftmax: si_handle.release() softmax_done_bar.arrive() - # TODO: Final O normalization (disabled — corrupts output) - # The O sub-tile read-modify-write needs more debugging. - # For now, verify softmax P computation is correct with unnormalized output. + # Final O normalization: O = O / row_sum + if row_sum != Float32(0.0): + inv_row_sum = Float32(1.0) / row_sum + n_corr = 128 // corr_tile_size + tTMrO_fn = cute.make_rmem_tensor( + (tTMEM_LOAD_OcO.shape, n_corr), self.acc_dtype + ) + for ci in range(n_corr): + tTMrO_ci_ = tTMrO_fn[None, ci] + tTMrO_ci_layout = cute.composition( + tTMrO_ci_.layout, cute.make_layout(tTMrO_fn.shape[0]) + ) + tTMrO_ci = cute.make_tensor(tTMrO_ci_.iterator, tTMrO_ci_layout) + tTMEM_LOAD_OtO_ci = cute.make_tensor( + tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout + ) + tTMEM_STORE_OtO_ci = cute.make_tensor( + tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout + ) + cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_ci) + for j in cutlass.range(cute.size(tTMrO_ci), vectorize=True): + tTMrO_ci[j] = tTMrO_ci[j] * inv_row_sum + cute.copy(tiled_tmem_store_o, tTMrO_ci, tTMEM_STORE_OtO_ci) # Epilogue: TMEM -> SMEM -> GMEM via TMA store tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) @@ -321,15 +363,13 @@ def test(): v_kernel = v.unsqueeze(-1) c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') - # Reference: unnormalized softmax numerators @ V - # (no O rescale or 1/row_sum normalization in kernel yet) + # Reference: proper softmax qf = q[:, :, 0].float() kf = k[:, :, 0].float() scale = 1.0 / math.sqrt(hd) attn = qf @ kf.T * scale - attn_max = attn.max(dim=-1, keepdim=True)[0] - attn_unnorm = torch.exp(attn - attn_max) # softmax numerators - ref = attn_unnorm @ v.float() + attn = torch.softmax(attn, dim=-1) + ref = attn @ v.float() mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))