Start implementing manual SMEM-P addressing (helpers are a trap)

This commit is contained in:
2026-05-23 19:20:40 +00:00
parent 7bf69a0265
commit 6c08a95620

View File

@@ -257,26 +257,12 @@ class FmhaKernel:
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
tTMEM_STOREcP = thr_store.partition_S(tScP)
# P SMEM copy atoms: SMEM-P (always defined, only used when use_smem_p=True)
# Uses make_tiled_copy_C to partition threads by QK MMA's C-fragment layout.
# Softmax warps have P values in QK C-fragment layout (same as rP_bf16).
# This copy writes those values to sP which has PV A-operand SMEM layout.
# According to STAGE_D.md: use tcgen05.copy.St32x32bOp with Float32 (not BF16)
# and use make_tiled_copy_C(store_atom, qk_mma) to partition by QK C-fragment
# BUT: sP is BF16, so we need BF16 copy atom or convert Float32→BF16
smem_copy_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)),
self.q_dtype, # BF16 to match sP
num_bits_per_copy=128,
)
tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
thr_smem_copy = tiled_smem_copy.get_slice(sfw_idx)
# Debug sP shape
print(f"[SMEM-P PROPER] sP shape: {cute.shape(sP)} rank: {len(cute.shape(sP))}")
# Don't flatten - keep original sP
tSMEM_CPYsP = thr_smem_copy.partition_D(sP) # destination (SMEM)
print(f"[SMEM-P PROPER] After partition_D: tSMEM_CPYsP layout: {tSMEM_CPYsP.layout}")
print(f"[SMEM-P PROPER] After partition_D: tSMEM_CPYsP shape: {cute.shape(tSMEM_CPYsP)} rank: {len(cute.shape(tSMEM_CPYsP))}")
# Manual SMEM addressing for P (helpers are a trap)
# We need to write P values from QK C-fragment layout to PV A-operand SMEM layout
# sP has PV A-operand SMEM layout: p_smem_s
print(f"[SMEM-P MANUAL] Starting manual SMEM addressing")
print(f"[SMEM-P MANUAL] sP shape: {cute.shape(sP)} layout: {sP.layout}")
print(f"[SMEM-P MANUAL] p_smem_s (PV A-operand SMEM layout): {p_smem_s}")
row_max = -Float32.inf
row_sum = Float32(0.0)
@@ -349,43 +335,39 @@ class FmhaKernel:
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: Use QK C-fragment layout for source (not TMEM layout)
# rP_bf16 uses tTMEM_LOADrS.layout (TMEM layout) causing rank mismatch
# Create BF16 view with QK C-fragment layout for copying to BF16 SMEM
rP_qk_layout = tStS0.layout # QK C-fragment layout for this thread
rP_bf16_qk = cute.make_tensor(cute.recast_ptr(rP_bf16.iterator, dtype=self.q_dtype), rP_qk_layout)
# SMEM-P: Manual addressing (helpers are a trap)
# Each softmax thread owns P values in QK C-fragment partition
# Need to write to SMEM with PV A-operand layout
# Partition source with QK layout
print(f"[SMEM-P PROPER] Before partition_S: rP_bf16_qk shape: {cute.shape(rP_bf16_qk)} layout: {rP_bf16_qk.layout}")
tSMEM_CPYrP_qk = thr_smem_copy.partition_S(rP_bf16_qk)
print(f"[SMEM-P PROPER] After partition_S: tSMEM_CPYrP_qk layout: {tSMEM_CPYrP_qk.layout}")
print(f"[SMEM-P PROPER] After partition_S: tSMEM_CPYrP_qk shape: {cute.shape(tSMEM_CPYrP_qk)} rank: {len(cute.shape(tSMEM_CPYrP_qk))}")
print(f"[SMEM-P MANUAL] Starting manual P write to SMEM")
# Debug shapes
print(f"[SMEM-P PROPER] rP_bf16 shape: {cute.shape(rP_bf16)}, layout: TMEM")
print(f"[SMEM-P PROPER] rP_bf16_qk shape: {cute.shape(rP_bf16_qk)}, layout: QK C-fragment (BF16)")
print(f"[SMEM-P PROPER] tSMEM_CPYrP_qk shape: {cute.shape(tSMEM_CPYrP_qk)} rank: {len(cute.shape(tSMEM_CPYrP_qk))}")
print(f"[SMEM-P PROPER] tSMEM_CPYsP shape: {cute.shape(tSMEM_CPYsP)} rank: {len(cute.shape(tSMEM_CPYsP))}")
# Get thread's logical position in QK C-fragment partition
# tStS0 is QK C-fragment tensor for this thread
thread_qk_coord = cute.coord(tStS0) # Logical coordinates in QK C-fragment space
print(f"[SMEM-P MANUAL] Thread QK coord: {thread_qk_coord}")
# Manual copy instead of cute.copy (helpers are a trap)
# Get sizes and iterate
src_size = cute.size(tSMEM_CPYrP_qk)
dst_size = cute.size(tSMEM_CPYsP)
print(f"[SMEM-P PROPER] Manual copy: src_size={src_size}, dst_size={dst_size}")
# Get the shape of P values this thread owns
# tStS0 has shape ((128, 128), 1, 1) - total 128×128 P matrix
# Each thread owns a subtile of this
qk_fragment_shape = cute.shape(tStS0)
print(f"[SMEM-P MANUAL] QK fragment shape: {qk_fragment_shape}")
if src_size == dst_size:
# Same number of elements - copy element by element
for j in cutlass.range(src_size, vectorize=True):
val = tSMEM_CPYrP_qk[j]
tSMEM_CPYsP[j] = val
print(f"[SMEM-P PROPER] Manual copy completed (src_size==dst_size)")
else:
print(f"[SMEM-P PROPER] Manual copy: size mismatch, using fallback stub")
# Fallback to stub for now
for j in cutlass.range(cute.size(sP), vectorize=True):
sP[j] = BFloat16(0.0)
print(f"[SMEM-P PROPER] Used fallback stub")
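# For now, implement a simple test: write a value derived from sfw_idx
# (sfw_idx * 0.01, as BF16) to sP offset 0 instead of any real P value.
# This is WRONG (no P data reaches SMEM) but helps debug.
# NOTE(review): the zero-fill loop over sP further down appears to overwrite
# this test value before the fence - confirm whether the write is observable.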
smem_test_offset = Int32(0)
sP[smem_test_offset] = BFloat16(float(sfw_idx) * 0.01)
print(f"[SMEM-P MANUAL] Wrote thread {sfw_idx} value to SMEM offset 0")
# TODO: Implement proper mapping
# 1. Determine which P indices this thread owns
# 2. For each P index (i,j), compute SMEM address in PV A-operand layout
# 3. Write P value to that address
# For now, zero out sP as fallback (produces wrong results but compiles)
for j in cutlass.range(cute.size(sP), vectorize=True):
sP[j] = BFloat16(0.0)
print(f"[SMEM-P MANUAL] Used zero-fallback (TODO: implement proper mapping)")
cute.arch.fence_proxy("async.shared", space="cta")
softmax_done_bar.arrive()
# Per-tile O rescale (hand-constructed atoms with logical_divide layout)
if kt > 0: