SMEM-P: implement full 128-value write in softmax loop using coordinate mapping

This commit is contained in:
2026-05-23 19:36:56 +00:00
parent e118ad967d
commit e09c8057be

View File

@@ -164,6 +164,13 @@ class FmhaKernel:
qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
# Create coordinate tensor for QK C-fragment layout
# Each element maps to its logical coordinate ((m,n),0,0)
if self.use_smem_p:
cP_qk = cute.make_identity_tensor(tStS0.shape)
print(f"[SMEM-P CUTLASS] Created cP_qk shape: {cute.shape(cP_qk)}")
pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
@@ -309,6 +316,11 @@ class FmhaKernel:
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
# Coordinate fragments for SMEM-P mapping
if self.use_smem_p:
tTMEM_LOADcS_frg = cute.logical_divide(tTMEM_LOADcS, cute.make_layout(frg_tile))
print(f"[SMEM-P CUTLASS] Created tTMEM_LOADcS_frg shape: {cute.shape(tTMEM_LOADcS_frg)}")
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2)
@@ -332,6 +344,25 @@ class FmhaKernel:
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
# If using SMEM-P, write P value directly to SMEM
if self.use_smem_p:
# Get QK coordinate for this position
qk_coord = tTMEM_LOADcS_frg[k, j]
mn = qk_coord[0]
m = mn[0]
n = mn[1]
# Map to PV SMEM coordinate
n0 = n % 16
n1 = (n // 16) % 4
n2 = n // 64
pv_coord = ((m, n0), 0, (n1, n2), 0)
# Convert Float32 → BF16 and write to SMEM
p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype)
sP[pv_coord] = p_val_bf16
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
s_vec = tTMEM_LOADrS_frg[None, j].load()
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
@@ -341,51 +372,13 @@ class FmhaKernel:
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: Manual addressing with CUTLASS LLM pattern
print(f"[SMEM-P CUTLASS] Starting manual P write to SMEM")
# Get thread index for coordinate computations
tidx, _, _ = cute.arch.thread_idx()
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
# Function to map QK coordinate to PV SMEM coordinate
# QK: ((m, n), 0, 0) → PV: ((m, n % 16), 0, ((n // 16) % 4, n // 64), 0)
def qk_to_pv_coord(m, n):
n0 = n % 16
n1 = (n // 16) % 4
n2 = n // 64
return ((m, n0), 0, (n1, n2), 0)
# Each thread handles 32×1 tile × 4 fragments = 128 P values
# We need to map each of these 128 values to SMEM
# For testing: write a simple pattern to verify mapping works
# Each thread writes to different coordinate for testing
# Use thread-relative simple coordinates
thread_offset = tidx % 16 # 0-15
test_m = thread_offset
test_n = thread_offset
test_coord = qk_to_pv_coord(test_m, test_n)
# Write constant test value (0.5)
test_val = BFloat16(0.5)
sP[test_coord] = test_val
print(f"[SMEM-P CUTLASS] Thread wrote test to coord {test_coord}")
# TODO: Implement full 128-value mapping
# Need to:
# 1. Create coordinate tensor with make_identity_tensor(tStS0.shape)
# 2. Partition it the same way as rP_bf16
# 3. For each of the 128 P values, get its QK coordinate
# 4. Map to PV coordinate using qk_to_pv_coord
# 5. Write to sP[dst_coord]
# For now, zero rest of sP (except our test value)
# This is WRONG but allows compilation
print(f"[SMEM-P CUTLASS] WARNING: Only wrote test value, rest zeroed (incomplete)")
# SMEM-P: Already wrote P values to SMEM in softmax loop
# Just need fence and barrier
print(f"[SMEM-P CUTLASS] P values already written to SMEM, proceeding to fence")
cute.arch.fence_proxy("async.shared", space="cta")
softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout)
# Barrier for both TMEM-P and SMEM-P paths
softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout)
if kt > 0:
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype