Fix O normalize: use 2D register tensor indexing
This commit is contained in:
@@ -299,26 +299,13 @@ class FmhaV3RealSoftmax:
|
||||
# Final O normalization: O = O / row_sum
|
||||
if row_sum != Float32(0.0):
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
# Load O from TMEM, multiply by 1/row_sum, write back
|
||||
n_corr = 128 // corr_tile_size
|
||||
for ci in range(n_corr):
|
||||
tOtO_ci = cute.make_tensor(tOtO_i.iterator, tOtO_i.layout)
|
||||
tOcO_ci = cute.make_tensor(tOcO_i.iterator, tOcO_i.layout)
|
||||
# Sub-tile index offset
|
||||
tOtO_sub = tOtO_ci[None, ci, None]
|
||||
tOcO_sub = tOcO_ci[None, ci, None]
|
||||
# Use the base partitioned tensors with offset
|
||||
# Actually, just load the full O sub-tile
|
||||
pass
|
||||
# Simple approach: load/store via the partitioned tensors
|
||||
tTMEM_LOAD_OrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO, tTMEM_LOAD_OrO)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
# Iterate with proper indexing
|
||||
n_o_frg = cute.size(tTMEM_LOAD_OrO, mode=[0])
|
||||
for fi in range(n_o_frg):
|
||||
# The register tensor from the O partition is 2D: (frg, corr_tile)
|
||||
for fi in range(cute.size(tTMEM_LOAD_OrO, mode=[0])):
|
||||
for fj in range(cute.size(tTMEM_LOAD_OrO, mode=[1])):
|
||||
tTMEM_LOAD_OrO[fi, fj, None] = tTMEM_LOAD_OrO[fi, fj, None] * inv_row_sum
|
||||
tTMEM_LOAD_OrO[fi, fj] = tTMEM_LOAD_OrO[fi, fj] * inv_row_sum
|
||||
cute.copy(tiled_tmem_store_o, tTMEM_LOAD_OrO, tTMEM_STORE_OtO)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user