diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py index d84c773e..f16a7c14 100644 --- a/tests/unit/test_fmha_v3_stage_c_full.py +++ b/tests/unit/test_fmha_v3_stage_c_full.py @@ -236,6 +236,24 @@ class FmhaV3StageC: tTMEM_STOREcP = thr_store.partition_S(tScP) # --- Online softmax loop --- + # --- O rescale + normalize setup (before softmax loop) --- + cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) + tOcO = pv_thr.partition_C(cO) + corr_tile_size = 16 + tOtO_i_layout = cute.composition(tOtO.layout, cute.make_layout((128, corr_tile_size))) + tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) + tOtO_i = cute.make_tensor(tOtO.iterator, tOtO_i_layout) + tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) + tmem_load_o_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype) + tmem_store_o_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype) + tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i) + tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i) + thr_load_o = tiled_tmem_load_o.get_slice(sfw_idx) + thr_store_o = tiled_tmem_store_o.get_slice(sfw_idx) + tTMEM_LOAD_OtO = thr_load_o.partition_S(tOtO_i) + tTMEM_LOAD_OcO = thr_load_o.partition_D(tOcO_i) + tTMEM_STORE_OtO = thr_store_o.partition_D(tOtO_i) + row_max = -Float32.inf; row_sum = Float32(0.0) scale_log2 = Float32(self.scale_softmax_log2) @@ -281,26 +299,25 @@ class FmhaV3StageC: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() + + # O rescale: if kt > 0, rescale O in TMEM by exp2((old_max - new_max) * scale_log2) + if kt > 0: + corr_scale = cute.math.exp2(scale_log2 * (old_row_max - row_max_safe), fastmath=True) + o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size + for i in range(o_col_tiles): + tTMEM_LOAD_O_i = cute.make_tensor(tTMEM_LOAD_OtO.iterator + i * corr_tile_size, tTMEM_LOAD_OtO.layout) + tTMEM_STORE_O_i = cute.make_tensor(tTMEM_STORE_OtO.iterator + i * corr_tile_size, tTMEM_STORE_OtO.layout) + tTMrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype) + cute.copy(tiled_tmem_load_o, tTMEM_LOAD_O_i, tTMrO) + for k in cutlass.range(cute.size(tTMrO), vectorize=True): + tTMrO[k] = tTMrO[k] * corr_scale + cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STORE_O_i) + cute.arch.fence_view_async_tmem_store() + si_handle.release() softmax_done_bar.arrive() - # --- Normalize O in TMEM by row_sum --- - cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) - tOcO = pv_thr.partition_C(cO) - corr_tile_size = 16 - tOtO_i_layout = cute.composition(tOtO.layout, cute.make_layout((128, corr_tile_size))) - tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) - tOtO_i = cute.make_tensor(tOtO.iterator, tOtO_i_layout) - tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) - tmem_load_o_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype) - tmem_store_o_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype) - tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i) - tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i) - thr_load_o = tiled_tmem_load_o.get_slice(sfw_idx) - thr_store_o = tiled_tmem_store_o.get_slice(sfw_idx) - tTMEM_LOAD_OtO = thr_load_o.partition_S(tOtO_i) - tTMEM_LOAD_OcO = thr_load_o.partition_D(tOcO_i) - tTMEM_STORE_OtO = thr_store_o.partition_D(tOtO_i) + # --- Final O normalization by row_sum --- inv_row_sum = Float32(1.0) / row_sum o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size for i in range(o_col_tiles): @@ -326,8 +343,9 @@ class FmhaV3StageC: def test(): torch.manual_seed(42) - for n in [128]: - for seed in [42, 123, 999]: + # Test single tile (n=128) and multi-tile (n=256) + for n in [128, 256]: + for seed in [42]: torch.manual_seed(seed) m, hd = 128, HEAD_DIM q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') @@ -345,7 +363,7 @@ def test(): mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - kernel = FmhaV3StageC() + kernel = FmhaV3StageC(s_k=128) if seed == 42: print(f'seed={seed}: Compiling...', flush=True) compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)