D1: align P store and PV A-fragment layouts via tP
Key insight: tP (PV A-fragment base) used p_tmem_s.outer layout, but P store used QK C-fragment composition layout. These diverge at hd>64. Fix: tP now uses the same QK C-fragment composition layout (tStP_layout) as the P store. PV A-fragment is derived from tP, so it automatically uses the same layout. No double-offset since tP includes P offset.
This commit is contained in:
@@ -157,12 +157,18 @@ class FmhaKernel:
|
||||
tOtO = pv_thr.make_fragment_C(pv_as)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
# P tensor for PV A-fragment: use QK C-fragment composition layout so P store and PV read agree
|
||||
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
|
||||
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
|
||||
# PV A-fragment offset: p0_offset in BF16 columns = p0_offset * (FP32_width / BF16_width)
|
||||
tP = cute.make_tensor(
|
||||
tStS.iterator + self.tmem_p0_offset * (self.qk_acc_dtype.width // self.q_dtype.width),
|
||||
tStP_layout,
|
||||
)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None,None,None,0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
# tP already starts at P offset, no additional offset needed
|
||||
tOrP0 = tOrP
|
||||
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
@@ -225,17 +231,15 @@ class FmhaKernel:
|
||||
|
||||
# P store atoms
|
||||
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
|
||||
# P store: use PV A-fragment layout (tOrP0) so PV GEMM reads correct TMEM columns.
|
||||
# Store as BF16 (matching PV A-fragment), not FP32 (which causes layout mismatch at hd>64).
|
||||
tStP0_bf16 = cute.make_tensor(tOrP0.iterator, tOrP0.layout)
|
||||
tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.q_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0_bf16)
|
||||
# P store uses QK C-fragment composition layout (same layout as tP/PV A-fragment)
|
||||
tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtP = thr_store.partition_D(tStP0_bf16)
|
||||
# Coordinate tensor for P store
|
||||
cP = cute.make_identity_tensor((self.pv_mma_tiler[0], p_cols_fp32))
|
||||
tOcP = pv_thr.partition_A(cP)
|
||||
tTMEM_STOREcP = thr_store.partition_S(cute.make_tensor(tOcP.iterator, tOrP0.layout))
|
||||
tTMEM_STOREtP = thr_store.partition_D(tStP0)
|
||||
tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
|
||||
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
|
||||
tTMEM_STOREcP = thr_store.partition_S(tScP)
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
@@ -290,10 +294,12 @@ class FmhaKernel:
|
||||
acc_scale = Float32(0.0)
|
||||
row_sum *= acc_scale
|
||||
|
||||
# Store P to TMEM using PV A-fragment layout (BF16)
|
||||
rP_bf16_reg = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.q_dtype)
|
||||
# Store P to TMEM (FP32, using QK C-fragment composition layout)
|
||||
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
|
||||
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
|
||||
minus_row_max = Float32(0.0) - row_max_safe
|
||||
rP_bf16_frg = cute.logical_divide(rP_bf16_reg, cute.make_layout(frg_tile))
|
||||
|
||||
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
|
||||
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max
|
||||
@@ -302,7 +308,7 @@ class FmhaKernel:
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
|
||||
cute.copy(tiled_tmem_store, rP_bf16_reg, tTMEM_STOREtP)
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# Per-tile O rescale (hand-constructed atoms with logical_divide layout)
|
||||
|
||||
Reference in New Issue
Block a user