D1: align P store and PV A-fragment layouts via tP

Key insight: tP (PV A-fragment base) used p_tmem_s.outer layout,
but P store used QK C-fragment composition layout. These diverge at hd>64.

Fix: tP now uses the same QK C-fragment composition layout (tStP_layout)
as the P store. PV A-fragment is derived from tP, so it automatically
uses the same layout. No double-offset since tP includes P offset.
This commit is contained in:
2026-05-23 03:40:10 +00:00
parent 95cf4159f2
commit 902bda5c31

View File

@@ -157,12 +157,18 @@ class FmhaKernel:
tOtO = pv_thr.make_fragment_C(pv_as)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
# P tensor for PV A-fragment: use QK C-fragment composition layout so P store and PV read agree
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
# PV A-fragment offset: p0_offset in BF16 columns = p0_offset * (FP32_width / BF16_width)
tP = cute.make_tensor(
tStS.iterator + self.tmem_p0_offset * (self.qk_acc_dtype.width // self.q_dtype.width),
tStP_layout,
)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None,None,None,0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
# tP already starts at P offset, no additional offset needed
tOrP0 = tOrP
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
@@ -225,17 +231,15 @@ class FmhaKernel:
# P store atoms
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
# P store: use PV A-fragment layout (tOrP0) so PV GEMM reads correct TMEM columns.
# Store as BF16 (matching PV A-fragment), not FP32 (which causes layout mismatch at hd>64).
tStP0_bf16 = cute.make_tensor(tOrP0.iterator, tOrP0.layout)
tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.q_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0_bf16)
# P store uses QK C-fragment composition layout (same layout as tP/PV A-fragment)
tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout)
tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtP = thr_store.partition_D(tStP0_bf16)
# Coordinate tensor for P store
cP = cute.make_identity_tensor((self.pv_mma_tiler[0], p_cols_fp32))
tOcP = pv_thr.partition_A(cP)
tTMEM_STOREcP = thr_store.partition_S(cute.make_tensor(tOcP.iterator, tOrP0.layout))
tTMEM_STOREtP = thr_store.partition_D(tStP0)
tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
tTMEM_STOREcP = thr_store.partition_S(tScP)
row_max = -Float32.inf
row_sum = Float32(0.0)
@@ -290,10 +294,12 @@ class FmhaKernel:
acc_scale = Float32(0.0)
row_sum *= acc_scale
# Store P to TMEM using PV A-fragment layout (BF16)
rP_bf16_reg = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.q_dtype)
# Store P to TMEM (FP32, using QK C-fragment composition layout)
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
minus_row_max = Float32(0.0) - row_max_safe
rP_bf16_frg = cute.logical_divide(rP_bf16_reg, cute.make_layout(frg_tile))
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max
@@ -302,7 +308,7 @@ class FmhaKernel:
s_vec = tTMEM_LOADrS_frg[None, j].load()
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
cute.copy(tiled_tmem_store, rP_bf16_reg, tTMEM_STOREtP)
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
# Per-tile O rescale (hand-constructed atoms with logical_divide layout)