D1.3: DIAGNOSTIC - test epilogue_tma_store raw PV without any round-trips
This commit is contained in:
@@ -369,39 +369,9 @@ class FmhaKernel:
|
||||
# Wait for MMA's PV[N-1] to commit before reading O.
|
||||
final_o_bar.arrive_and_wait()
|
||||
|
||||
# === O normalization: TMEM -> reg (scale by 1/row_sum) -> TMEM ===
|
||||
# Uses hand-constructed Ld32x32bOp/St32x32bOp atoms (same as correction_rescale).
|
||||
# The layout mismatch in these atoms introduces ~3% error per round-trip,
|
||||
# but the correction_rescale atoms (same construction) already use this path.
|
||||
# TODO: Replace with get_tmem_load_op-derived atoms for zero error.
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
)
|
||||
for i in range(n_corr_tiles):
|
||||
tTMrO_i_ = tTMrO[None, i]
|
||||
tTMrO_i_layout = cute.composition(
|
||||
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
|
||||
)
|
||||
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
|
||||
tTMEM_LOADtO_i = cute.make_tensor(
|
||||
tTMEM_LOADtO.iterator + i * corr_tile_size,
|
||||
tTMEM_LOADtO.layout,
|
||||
)
|
||||
tTMEM_STOREtO_i = cute.make_tensor(
|
||||
tTMEM_STOREtO.iterator + i * corr_tile_size,
|
||||
tTMEM_STOREtO.layout,
|
||||
)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
|
||||
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
|
||||
tTMrO_i[j] = tTMrO_i[j] * inv_row_sum
|
||||
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# Epilogue: TMEM → SMEM → GMEM via TMA store.
|
||||
# Uses epilogue_tmem_copy_and_partition (get_tmem_load_op) internally.
|
||||
# Since O is already normalized in TMEM, we apply identity epilogue_op.
|
||||
# === DIAGNOSTIC: Test epilogue_tma_store WITHOUT any round-trips ===
|
||||
# If get_tmem_load_op reads O correctly from TMEM, this should give cos 0.9999
|
||||
# (un-normalized, just raw PV sum). Then we can add normalization back.
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_acc_stage
|
||||
|
||||
Reference in New Issue
Block a user