SMEM-P: add debug_p_one flag to write constant P=1.0
This commit is contained in:
@@ -22,6 +22,7 @@ class FmhaKernel:
|
||||
self.pv_n_tile = min(head_dim, 256) # tcgen05 MMA max N=256
|
||||
self.n_pv_tiles = head_dim // self.pv_n_tile
|
||||
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
|
||||
self.debug_p_one = True # DEBUG: write constant P=1.0 to verify mapping
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
|
||||
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
|
||||
@@ -360,6 +361,11 @@ class FmhaKernel:
|
||||
# Compute inverse row sum for normalization
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
|
||||
# DEBUG: If debug flag set, write constant P=1.0 to verify mapping
|
||||
if self.debug_p_one:
|
||||
inv_row_sum = Float32(1.0)
|
||||
print("[DEBUG] Writing constant P=1.0 to verify SMEM mapping")
|
||||
|
||||
# Phase 2: Normalize P values and write to SMEM (if using SMEM-P)
|
||||
for j in range(frg_cnt):
|
||||
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
|
||||
|
||||
Reference in New Issue
Block a user