Fix p_cols_fp32: use pv_mma_tiler[2] (K-dim) not [1] (N-dim)
This commit is contained in:
@@ -346,7 +346,7 @@ class FmhaKernel:
|
||||
|
||||
# ── P store setup (always define both paths — CuTeDSL scoping) ──
|
||||
# TMEM-P: register bridge for P → TMEM
|
||||
p_cols_fp32 = self.pv_mma_tiler[1] * self.q_dtype.width // self.qk_acc_dtype.width
|
||||
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
|
||||
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
|
||||
tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
|
||||
Reference in New Issue
Block a user