Fix SMEM-P copy: use tcgen05.copy.St32x32bOp with Float32 and copy from rP_words (Float32), not rP_bf16
This commit is contained in:
@@ -261,9 +261,11 @@ class FmhaKernel:
|
||||
# Uses make_tiled_copy_C to partition threads by QK MMA's C-fragment layout.
|
||||
# Softmax warps have P values in QK C-fragment layout (same as rP_bf16).
|
||||
# This copy writes those values to sP which has PV A-operand SMEM layout.
|
||||
# According to STAGE_D.md: use tcgen05.copy.St32x32bOp with Float32 (not BF16)
|
||||
# and use make_tiled_copy_C(store_atom, qk_mma) to partition by QK C-fragment
|
||||
smem_copy_atom = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
self.q_dtype,
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)),
|
||||
Float32,
|
||||
num_bits_per_copy=128,
|
||||
)
|
||||
tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
|
||||
@@ -344,16 +346,16 @@ class FmhaKernel:
|
||||
else:
|
||||
# SMEM-P: Use QK C-fragment layout for source (not TMEM layout)
|
||||
# rP_bf16 uses tTMEM_LOADrS.layout (TMEM layout) causing rank mismatch
|
||||
# Create view with QK C-fragment layout (tStS0.layout)
|
||||
# Create view with QK C-fragment layout (tStS0.layout) using Float32 source (rP_words)
|
||||
rP_qk_layout = tStS0.layout # QK C-fragment layout for this thread
|
||||
rP_qk = cute.make_tensor(cute.recast_ptr(rP_bf16.iterator, dtype=self.q_dtype), rP_qk_layout)
|
||||
rP_qk = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=Float32), rP_qk_layout)
|
||||
|
||||
# Partition source with QK layout
|
||||
tSMEM_CPYrP_qk = thr_smem_copy.partition_S(rP_qk)
|
||||
|
||||
# Debug shapes
|
||||
print(f"[SMEM-P PROPER] rP_bf16 shape: {cute.shape(rP_bf16)}, layout: TMEM")
|
||||
print(f"[SMEM-P PROPER] rP_qk shape: {cute.shape(rP_qk)}, layout: QK C-fragment")
|
||||
print(f"[SMEM-P PROPER] rP_qk shape: {cute.shape(rP_qk)}, layout: QK C-fragment (Float32)")
|
||||
print(f"[SMEM-P PROPER] tSMEM_CPYrP_qk shape: {cute.shape(tSMEM_CPYrP_qk)} rank: {len(cute.shape(tSMEM_CPYrP_qk))}")
|
||||
print(f"[SMEM-P PROPER] tSMEM_CPYsP shape: {cute.shape(tSMEM_CPYsP)} rank: {len(cute.shape(tSMEM_CPYsP))}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user