D1: P store as BF16 using PV A-fragment layout

- Changed P store from FP32 QK C-fragment layout to BF16 PV A-fragment layout
- rP_bf16_reg stores directly to TMEM using tOrP0 layout
- Ensures softmax writes P to same TMEM columns that PV GEMM reads
This commit is contained in:
2026-05-23 03:38:24 +00:00
parent 2efd6be8af
commit 059c2e6cd9

View File

@@ -225,16 +225,16 @@ class FmhaKernel:
# P store atoms
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
# P store must use the PV A-fragment layout so PV reads correct TMEM columns.
# Use tOrP0's layout (PV A-fragment) for the store target, not QK C-fragment composition.
tStP0 = cute.make_tensor(tOrP0.iterator, tOrP0.layout)
tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
# P store: use PV A-fragment layout (tOrP0) so PV GEMM reads correct TMEM columns.
# Store as BF16 (matching PV A-fragment), not FP32 (which causes layout mismatch at hd>64).
tStP0_bf16 = cute.make_tensor(tOrP0.iterator, tOrP0.layout)
tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.q_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0_bf16)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtP = thr_store.partition_D(tStP0)
# tScP: coordinate tensor for P store — use PV A-fragment layout
tTMEM_STOREtP = thr_store.partition_D(tStP0_bf16)
# Coordinate tensor for P store
cP = cute.make_identity_tensor((self.pv_mma_tiler[0], p_cols_fp32))
tOcP = pv_thr.partition_A(cP) # partition using PV thread slice
tOcP = pv_thr.partition_A(cP)
tTMEM_STOREcP = thr_store.partition_S(cute.make_tensor(tOcP.iterator, tOrP0.layout))
row_max = -Float32.inf
@@ -290,11 +290,10 @@ class FmhaKernel:
acc_scale = Float32(0.0)
row_sum *= acc_scale
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
# Store P to TMEM using PV A-fragment layout (BF16)
rP_bf16_reg = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.q_dtype)
minus_row_max = Float32(0.0) - row_max_safe
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
rP_bf16_frg = cute.logical_divide(rP_bf16_reg, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max
@@ -303,7 +302,7 @@ class FmhaKernel:
s_vec = tTMEM_LOADrS_frg[None, j].load()
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.copy(tiled_tmem_store, rP_bf16_reg, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
# Per-tile O rescale (hand-constructed atoms with logical_divide layout)