Fix TMEM-P offset calc: match Stage C with p_cols_fp32 from pv_mma_tiler[2]
This commit is contained in:
@@ -72,8 +72,10 @@ class FmhaKernel:
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32
|
||||
s_cols = self.qk_mma_tiler[1]
|
||||
p_cols = self.pv_mma_tiler[1] * self.q_dtype.width // self.qk_acc_dtype.width
|
||||
self.tmem_o0_offset = max(s_cols, p_cols)
|
||||
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
|
||||
p_end = self.tmem_p0_offset + p_cols_fp32
|
||||
o_after = max(s_cols, p_end)
|
||||
self.tmem_o0_offset = ((o_after + 31) // 32) * 32
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
total = self.tmem_o0_offset + o_cols
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user