Stage C: add online O rescaling for multi-tile KV + test n=256
- Move O TMEM load/store setup before softmax loop - After P store: rescale O in TMEM by exp2((old_max - new_max) * scale) - Only rescale for kt > 0 (first tile has no prior O to rescale) - Use same TMEM load/modify/store pattern as final normalization - Test both n=128 (1 tile) and n=256 (2 tiles)
This commit is contained in:
@@ -236,6 +236,24 @@ class FmhaV3StageC:
|
||||
tTMEM_STOREcP = thr_store.partition_S(tScP)
|
||||
|
||||
# --- Online softmax loop ---
|
||||
# --- O rescale + normalize setup (before softmax loop) ---
|
||||
cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
|
||||
tOcO = pv_thr.partition_C(cO)
|
||||
corr_tile_size = 16
|
||||
tOtO_i_layout = cute.composition(tOtO.layout, cute.make_layout((128, corr_tile_size)))
|
||||
tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size)))
|
||||
tOtO_i = cute.make_tensor(tOtO.iterator, tOtO_i_layout)
|
||||
tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
|
||||
tmem_load_o_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype)
|
||||
tmem_store_o_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype)
|
||||
tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i)
|
||||
tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i)
|
||||
thr_load_o = tiled_tmem_load_o.get_slice(sfw_idx)
|
||||
thr_store_o = tiled_tmem_store_o.get_slice(sfw_idx)
|
||||
tTMEM_LOAD_OtO = thr_load_o.partition_S(tOtO_i)
|
||||
tTMEM_LOAD_OcO = thr_load_o.partition_D(tOcO_i)
|
||||
tTMEM_STORE_OtO = thr_store_o.partition_D(tOtO_i)
|
||||
|
||||
row_max = -Float32.inf; row_sum = Float32(0.0)
|
||||
scale_log2 = Float32(self.scale_softmax_log2)
|
||||
|
||||
@@ -281,26 +299,25 @@ class FmhaV3StageC:
|
||||
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# O rescale: if kt > 0, rescale O in TMEM by exp2((old_max - new_max) * scale_log2)
|
||||
if kt > 0:
|
||||
corr_scale = cute.math.exp2(scale_log2 * (old_row_max - row_max_safe), fastmath=True)
|
||||
o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size
|
||||
for i in range(o_col_tiles):
|
||||
tTMEM_LOAD_O_i = cute.make_tensor(tTMEM_LOAD_OtO.iterator + i * corr_tile_size, tTMEM_LOAD_OtO.layout)
|
||||
tTMEM_STORE_O_i = cute.make_tensor(tTMEM_STORE_OtO.iterator + i * corr_tile_size, tTMEM_STORE_OtO.layout)
|
||||
tTMrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_O_i, tTMrO)
|
||||
for k in cutlass.range(cute.size(tTMrO), vectorize=True):
|
||||
tTMrO[k] = tTMrO[k] * corr_scale
|
||||
cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STORE_O_i)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
si_handle.release()
|
||||
softmax_done_bar.arrive()
|
||||
|
||||
# --- Normalize O in TMEM by row_sum ---
|
||||
cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
|
||||
tOcO = pv_thr.partition_C(cO)
|
||||
corr_tile_size = 16
|
||||
tOtO_i_layout = cute.composition(tOtO.layout, cute.make_layout((128, corr_tile_size)))
|
||||
tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size)))
|
||||
tOtO_i = cute.make_tensor(tOtO.iterator, tOtO_i_layout)
|
||||
tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
|
||||
tmem_load_o_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype)
|
||||
tmem_store_o_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype)
|
||||
tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i)
|
||||
tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i)
|
||||
thr_load_o = tiled_tmem_load_o.get_slice(sfw_idx)
|
||||
thr_store_o = tiled_tmem_store_o.get_slice(sfw_idx)
|
||||
tTMEM_LOAD_OtO = thr_load_o.partition_S(tOtO_i)
|
||||
tTMEM_LOAD_OcO = thr_load_o.partition_D(tOcO_i)
|
||||
tTMEM_STORE_OtO = thr_store_o.partition_D(tOtO_i)
|
||||
# --- Final O normalization by row_sum ---
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size
|
||||
for i in range(o_col_tiles):
|
||||
@@ -326,8 +343,9 @@ class FmhaV3StageC:
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
for n in [128]:
|
||||
for seed in [42, 123, 999]:
|
||||
# Test single tile (n=128) and multi-tile (n=256)
|
||||
for n in [128, 256]:
|
||||
for seed in [42]:
|
||||
torch.manual_seed(seed)
|
||||
m, hd = 128, HEAD_DIM
|
||||
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
@@ -345,7 +363,7 @@ def test():
|
||||
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = FmhaV3StageC()
|
||||
kernel = FmhaV3StageC(s_k=128)
|
||||
if seed == 42:
|
||||
print(f'seed={seed}: Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
|
||||
|
||||
Reference in New Issue
Block a user