SMEM-P: implement full 128-value write in softmax loop using coordinate mapping
This commit is contained in:
@@ -164,6 +164,13 @@ class FmhaKernel:
|
||||
qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_as)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
# Create coordinate tensor for QK C-fragment layout
|
||||
# Each element maps to its logical coordinate ((m,n),0,0)
|
||||
if self.use_smem_p:
|
||||
cP_qk = cute.make_identity_tensor(tStS0.shape)
|
||||
print(f"[SMEM-P CUTLASS] Created cP_qk shape: {cute.shape(cP_qk)}")
|
||||
|
||||
pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_as)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
@@ -309,6 +316,11 @@ class FmhaKernel:
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
# Coordinate fragments for SMEM-P mapping
|
||||
if self.use_smem_p:
|
||||
tTMEM_LOADcS_frg = cute.logical_divide(tTMEM_LOADcS, cute.make_layout(frg_tile))
|
||||
print(f"[SMEM-P CUTLASS] Created tTMEM_LOADcS_frg shape: {cute.shape(tTMEM_LOADcS_frg)}")
|
||||
|
||||
for j in range(frg_cnt):
|
||||
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
|
||||
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2)
|
||||
@@ -332,6 +344,25 @@ class FmhaKernel:
|
||||
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
|
||||
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max
|
||||
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
|
||||
|
||||
# If using SMEM-P, write P value directly to SMEM
|
||||
if self.use_smem_p:
|
||||
# Get QK coordinate for this position
|
||||
qk_coord = tTMEM_LOADcS_frg[k, j]
|
||||
mn = qk_coord[0]
|
||||
m = mn[0]
|
||||
n = mn[1]
|
||||
|
||||
# Map to PV SMEM coordinate
|
||||
n0 = n % 16
|
||||
n1 = (n // 16) % 4
|
||||
n2 = n // 64
|
||||
pv_coord = ((m, n0), 0, (n1, n2), 0)
|
||||
|
||||
# Convert the Float32 P value to self.q_dtype (presumably BF16, per the variable name; confirm) and write it to SMEM
|
||||
p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype)
|
||||
sP[pv_coord] = p_val_bf16
|
||||
|
||||
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
@@ -341,51 +372,13 @@ class FmhaKernel:
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
else:
|
||||
# SMEM-P: Manual addressing with CUTLASS LLM pattern
|
||||
print(f"[SMEM-P CUTLASS] Starting manual P write to SMEM")
|
||||
|
||||
# Get thread index for coordinate computations
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
|
||||
# Function to map QK coordinate to PV SMEM coordinate
|
||||
# QK: ((m, n), 0, 0) → PV: ((m, n % 16), 0, ((n // 16) % 4, n // 64), 0)
|
||||
def qk_to_pv_coord(m, n):
|
||||
n0 = n % 16
|
||||
n1 = (n // 16) % 4
|
||||
n2 = n // 64
|
||||
return ((m, n0), 0, (n1, n2), 0)
|
||||
|
||||
# Each thread handles 32×1 tile × 4 fragments = 128 P values
|
||||
# We need to map each of these 128 values to SMEM
|
||||
|
||||
# For testing: write a simple pattern to verify mapping works
|
||||
# Each thread writes to different coordinate for testing
|
||||
# Use thread-relative simple coordinates
|
||||
thread_offset = tidx % 16 # 0-15
|
||||
test_m = thread_offset
|
||||
test_n = thread_offset
|
||||
test_coord = qk_to_pv_coord(test_m, test_n)
|
||||
|
||||
# Write constant test value (0.5)
|
||||
test_val = BFloat16(0.5)
|
||||
sP[test_coord] = test_val
|
||||
print(f"[SMEM-P CUTLASS] Thread wrote test to coord {test_coord}")
|
||||
|
||||
# TODO: Implement full 128-value mapping
|
||||
# Need to:
|
||||
# 1. Create coordinate tensor with make_identity_tensor(tStS0.shape)
|
||||
# 2. Partition it the same way as rP_bf16
|
||||
# 3. For each of the 128 P values, get its QK coordinate
|
||||
# 4. Map to PV coordinate using qk_to_pv_coord
|
||||
# 5. Write to sP[dst_coord]
|
||||
|
||||
# For now, zero rest of sP (except our test value)
|
||||
# This is WRONG but allows compilation
|
||||
print(f"[SMEM-P CUTLASS] WARNING: Only wrote test value, rest zeroed (incomplete)")
|
||||
|
||||
# SMEM-P: Already wrote P values to SMEM in softmax loop
|
||||
# Just need fence and barrier
|
||||
print(f"[SMEM-P CUTLASS] P values already written to SMEM, proceeding to fence")
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout)
|
||||
|
||||
# Barrier for both TMEM-P and SMEM-P paths
|
||||
softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout)
|
||||
if kt > 0:
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
|
||||
Reference in New Issue
Block a user