Start implementing manual SMEM-P addressing (helpers are a trap)
This commit is contained in:
@@ -257,26 +257,12 @@ class FmhaKernel:
|
||||
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
|
||||
tTMEM_STOREcP = thr_store.partition_S(tScP)
|
||||
|
||||
# P SMEM copy atoms: SMEM-P (always defined, only used when use_smem_p=True)
|
||||
# Uses make_tiled_copy_C to partition threads by QK MMA's C-fragment layout.
|
||||
# Softmax warps have P values in QK C-fragment layout (same as rP_bf16).
|
||||
# This copy writes those values to sP which has PV A-operand SMEM layout.
|
||||
# According to STAGE_D.md: use tcgen05.copy.St32x32bOp with Float32 (not BF16)
|
||||
# and use make_tiled_copy_C(store_atom, qk_mma) to partition by QK C-fragment
|
||||
# BUT: sP is BF16, so we need either a BF16 copy atom or a Float32→BF16 conversion
|
||||
smem_copy_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)),
|
||||
self.q_dtype, # BF16 to match sP
|
||||
num_bits_per_copy=128,
|
||||
)
|
||||
tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
|
||||
thr_smem_copy = tiled_smem_copy.get_slice(sfw_idx)
|
||||
# Debug sP shape
|
||||
print(f"[SMEM-P PROPER] sP shape: {cute.shape(sP)} rank: {len(cute.shape(sP))}")
|
||||
# Don't flatten - keep original sP
|
||||
tSMEM_CPYsP = thr_smem_copy.partition_D(sP) # destination (SMEM)
|
||||
print(f"[SMEM-P PROPER] After partition_D: tSMEM_CPYsP layout: {tSMEM_CPYsP.layout}")
|
||||
print(f"[SMEM-P PROPER] After partition_D: tSMEM_CPYsP shape: {cute.shape(tSMEM_CPYsP)} rank: {len(cute.shape(tSMEM_CPYsP))}")
|
||||
# Manual SMEM addressing for P (helpers are a trap)
|
||||
# We need to write P values from QK C-fragment layout to PV A-operand SMEM layout
|
||||
# sP has PV A-operand SMEM layout: p_smem_s
|
||||
print(f"[SMEM-P MANUAL] Starting manual SMEM addressing")
|
||||
print(f"[SMEM-P MANUAL] sP shape: {cute.shape(sP)} layout: {sP.layout}")
|
||||
print(f"[SMEM-P MANUAL] p_smem_s (PV A-operand SMEM layout): {p_smem_s}")
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
@@ -349,43 +335,39 @@ class FmhaKernel:
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
else:
|
||||
# SMEM-P: Use QK C-fragment layout for source (not TMEM layout)
|
||||
# rP_bf16 uses tTMEM_LOADrS.layout (the TMEM layout), which causes a rank mismatch
|
||||
# Create BF16 view with QK C-fragment layout for copying to BF16 SMEM
|
||||
rP_qk_layout = tStS0.layout # QK C-fragment layout for this thread
|
||||
rP_bf16_qk = cute.make_tensor(cute.recast_ptr(rP_bf16.iterator, dtype=self.q_dtype), rP_qk_layout)
|
||||
# SMEM-P: Manual addressing (helpers are a trap)
|
||||
# Each softmax thread owns P values in QK C-fragment partition
|
||||
# Need to write to SMEM with PV A-operand layout
|
||||
|
||||
# Partition source with QK layout
|
||||
print(f"[SMEM-P PROPER] Before partition_S: rP_bf16_qk shape: {cute.shape(rP_bf16_qk)} layout: {rP_bf16_qk.layout}")
|
||||
tSMEM_CPYrP_qk = thr_smem_copy.partition_S(rP_bf16_qk)
|
||||
print(f"[SMEM-P PROPER] After partition_S: tSMEM_CPYrP_qk layout: {tSMEM_CPYrP_qk.layout}")
|
||||
print(f"[SMEM-P PROPER] After partition_S: tSMEM_CPYrP_qk shape: {cute.shape(tSMEM_CPYrP_qk)} rank: {len(cute.shape(tSMEM_CPYrP_qk))}")
|
||||
print(f"[SMEM-P MANUAL] Starting manual P write to SMEM")
|
||||
|
||||
# Debug shapes
|
||||
print(f"[SMEM-P PROPER] rP_bf16 shape: {cute.shape(rP_bf16)}, layout: TMEM")
|
||||
print(f"[SMEM-P PROPER] rP_bf16_qk shape: {cute.shape(rP_bf16_qk)}, layout: QK C-fragment (BF16)")
|
||||
print(f"[SMEM-P PROPER] tSMEM_CPYrP_qk shape: {cute.shape(tSMEM_CPYrP_qk)} rank: {len(cute.shape(tSMEM_CPYrP_qk))}")
|
||||
print(f"[SMEM-P PROPER] tSMEM_CPYsP shape: {cute.shape(tSMEM_CPYsP)} rank: {len(cute.shape(tSMEM_CPYsP))}")
|
||||
# Get thread's logical position in QK C-fragment partition
|
||||
# tStS0 is QK C-fragment tensor for this thread
|
||||
thread_qk_coord = cute.coord(tStS0) # Logical coordinates in QK C-fragment space
|
||||
print(f"[SMEM-P MANUAL] Thread QK coord: {thread_qk_coord}")
|
||||
|
||||
# Manual copy instead of cute.copy (helpers are a trap)
|
||||
# Get sizes and iterate
|
||||
src_size = cute.size(tSMEM_CPYrP_qk)
|
||||
dst_size = cute.size(tSMEM_CPYsP)
|
||||
print(f"[SMEM-P PROPER] Manual copy: src_size={src_size}, dst_size={dst_size}")
|
||||
# Get the shape of P values this thread owns
|
||||
# tStS0 has shape ((128, 128), 1, 1) - total 128×128 P matrix
|
||||
# Each thread owns a subtile of this
|
||||
qk_fragment_shape = cute.shape(tStS0)
|
||||
print(f"[SMEM-P MANUAL] QK fragment shape: {qk_fragment_shape}")
|
||||
|
||||
if src_size == dst_size:
|
||||
# Same number of elements - copy element by element
|
||||
for j in cutlass.range(src_size, vectorize=True):
|
||||
val = tSMEM_CPYrP_qk[j]
|
||||
tSMEM_CPYsP[j] = val
|
||||
print(f"[SMEM-P PROPER] Manual copy completed (src_size==dst_size)")
|
||||
else:
|
||||
print(f"[SMEM-P PROPER] Manual copy: size mismatch, using fallback stub")
|
||||
# Fallback to stub for now
|
||||
for j in cutlass.range(cute.size(sP), vectorize=True):
|
||||
sP[j] = BFloat16(0.0)
|
||||
print(f"[SMEM-P PROPER] Used fallback stub")
|
||||
# For now, a simple debug write: store this thread's index (sfw_idx, scaled by 0.01) in the first element of sP
|
||||
# This is WRONG but helps debug
|
||||
smem_test_offset = Int32(0)
|
||||
sP[smem_test_offset] = BFloat16(float(sfw_idx) * 0.01)
|
||||
print(f"[SMEM-P MANUAL] Wrote thread {sfw_idx} value to SMEM offset 0")
|
||||
|
||||
# TODO: Implement proper mapping
|
||||
# 1. Determine which P indices this thread owns
|
||||
# 2. For each P index (i,j), compute SMEM address in PV A-operand layout
|
||||
# 3. Write P value to that address
|
||||
|
||||
# For now, zero out sP as fallback (produces wrong results but compiles)
|
||||
for j in cutlass.range(cute.size(sP), vectorize=True):
|
||||
sP[j] = BFloat16(0.0)
|
||||
|
||||
print(f"[SMEM-P MANUAL] Used zero-fallback (TODO: implement proper mapping)")
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
softmax_done_bar.arrive()
# Per-tile O rescale (hand-constructed atoms with logical_divide layout)
|
||||
if kt > 0:
|
||||
|
||||
Reference in New Issue
Block a user