D1.3: Replace NO-op TMEM round-trip with correction_epilog using epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition

- Remove hand-constructed TMEM round-trips (3% layout mismatch error)
- Use CUTLASS get_tmem_load_op + get_smem_store_op paired atoms
- One-way trip: TMEM -> reg (normalize) -> SMEM -> GMEM
- SMEM-P path: zero-fill stub (proper copy TBD)
- Keep per-tile O rescale atoms for n>128 support
This commit is contained in:
2026-05-23 20:50:23 +00:00
parent 1e55e36919
commit 1cf7140ea3

View File

@@ -9,6 +9,7 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
from cutlass.utils.gemm.sm100 import epilogue_tmem_copy_and_partition, epilogue_smem_copy_and_partition
import cuda.bindings.driver as cuda
import cutlass.torch as ct
import math
@@ -22,9 +23,6 @@ class FmhaKernel:
self.pv_n_tile = min(head_dim, 256) # tcgen05 MMA max N=256
self.n_pv_tiles = head_dim // self.pv_n_tile
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
self.debug_p_one = False # DEBUG: write constant P=1.0 to verify mapping
self.debug_swap_mn = False # DEBUG: try swapping m and n0 in coordinate mapping
self.debug_permute = 4 # DEBUG: try different coordinate permutations (4=swap m↔n2)
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
@@ -167,13 +165,6 @@ class FmhaKernel:
qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
# Create coordinate tensor for QK C-fragment layout
# Each element maps to its logical coordinate ((m,n),0,0)
if self.use_smem_p:
cP_qk = cute.make_identity_tensor(tStS0.shape)
print(f"[SMEM-P CUTLASS] Created cP_qk shape: {cute.shape(cP_qk)}")
pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
@@ -184,14 +175,6 @@ class FmhaKernel:
tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP)
tOrP = tOrP_base[(None,None,None,0)]
tCrP = pv_mma.make_fragment_A(sP)
if self.use_smem_p:
print(f"[SMEM-P DEBUG] tCrP shape: {cute.shape(tCrP)} layout: {tCrP.layout}")
# DEBUG: compute iterator offset between tCrP and sP
try:
offset_elems = tCrP.iterator - sP.iterator
print(f"[SMEM-P DEBUG] tCrP iterator offset: {offset_elems}")
except:
print(f"[SMEM-P DEBUG] iterator offset not available")
# tOrP0 always defined as tOrP. The TMEM-P path in the MMA warp applies
# the p0 column offset inline when constructing the gemm arguments.
tOrP0 = tOrP
@@ -275,18 +258,10 @@ class FmhaKernel:
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
tTMEM_STOREcP = thr_store.partition_S(tScP)
# Manual SMEM addressing for P (CUTLASS LLM guidance)
# We need to write P values from QK C-fragment layout to PV A-operand SMEM layout
# sP has PV A-operand SMEM layout: p_smem_s
print(f"[SMEM-P CUTLASS] Starting manual SMEM addressing with CUTLASS LLM pattern")
print(f"[SMEM-P CUTLASS] sP shape: {cute.shape(sP)} layout: {sP.layout}")
# Get thread index for coordinate partitioning
tidx, _, _ = cute.arch.thread_idx()
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
lane_idx = tidx % 32
print(f"[SMEM-P CUTLASS] tidx={tidx}, warp_idx={warp_idx}, lane_idx={lane_idx}")
# P SMEM copy atoms: SMEM-P
# NOTE: make_tiled_copy_C fails (incompatible QK C-fragment vs PV A-operand layouts).
# SMEM-P proper copy is TBD. For now, SMEM-P path zero-fills sP.
# The TMEM-P path (hd<=64) works correctly without SMEM-P.
row_max = -Float32.inf
row_sum = Float32(0.0)
@@ -326,20 +301,7 @@ class FmhaKernel:
old_row_max = row_max
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
# Compute fragment tile size dynamically (must match value division)
frg_tile_size = cute.size(tTMEM_LOADrS) // frg_cnt
frg_layout = cute.make_layout(frg_tile_size)
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, frg_layout)
# Coordinate fragments for SMEM-P mapping (needed unconditionally for scoping)
tTMEM_LOADcS_frg = cute.logical_divide(tTMEM_LOADcS, frg_layout)
if self.use_smem_p:
print(f"[SMEM-P CUTLASS] Created tTMEM_LOADcS_frg shape: {cute.shape(tTMEM_LOADcS_frg)}")
print(f"[SMEM-P CUTLASS] tTMEM_LOADrS shape: {cute.shape(tTMEM_LOADrS)}")
print(f"[SMEM-P CUTLASS] tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)}")
print(f"[SMEM-P CUTLASS] frg_tile_size: {frg_tile_size}, frg_layout: {frg_layout}")
print(f"[SMEM-P CUTLASS] tTMEM_LOADrS_frg shape: {cute.shape(tTMEM_LOADrS_frg)}")
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2)
@@ -359,7 +321,6 @@ class FmhaKernel:
minus_row_max = Float32(0.0) - row_max_safe
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
# Phase 1: Compute exp values and accumulate row_sum
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max
@@ -367,104 +328,18 @@ class FmhaKernel:
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
s_vec = tTMEM_LOADrS_frg[None, j].load()
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
# Compute inverse row sum for normalization
inv_row_sum = Float32(1.0) / row_sum
# DEBUG: If debug flag set, write constant P=1.0 to verify mapping
if self.debug_p_one:
inv_row_sum = Float32(1.0)
print("[DEBUG] Writing constant P=1.0 to verify SMEM mapping")
# Phase 2: Normalize P values and write to SMEM (if using SMEM-P)
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
# Get normalized P value
p_val = tTMEM_LOADrS_frg[k, j] * inv_row_sum
if self.use_smem_p:
# Get QK coordinate for this position
qk_coord = tTMEM_LOADcS_frg[k, j]
# qk_coord is (m, n) coordinate
m = qk_coord[0]
n = qk_coord[1]
# Map to PV SMEM coordinate
# Convert to local coordinates (0-127) as sanity check
m_local = m % 128
n_local = n % 128
# Original mapping formula (should be correct for local coords)
n0 = n_local % 16
n1 = (n_local // 16) % 4
n2 = n_local // 64
# DEBUG: Try different permutations to find correct mapping
# coords = [m_local, n0, n1, n2]
# Permutation 0: (m, n0, n1, n2) original
# Permutation 1: (n0, m, n1, n2) swap m↔n0
# Permutation 2: (m, n1, n0, n2) swap n0↔n1
# Permutation 3: (m, n0, n2, n1) swap n1↔n2
# Permutation 4: (n2, n0, n1, m) swap m↔n2
# Permutation 5: (n1, n0, m, n2) swap m↔n1
# Permutation 6: (n0, n1, n2, m) rotate right
# Permutation 7: (n2, n1, n0, m) reverse
if self.debug_permute == 0:
a,b,c,d = m_local, n0, n1, n2
elif self.debug_permute == 1:
a,b,c,d = n0, m_local, n1, n2
elif self.debug_permute == 2:
a,b,c,d = m_local, n1, n0, n2
elif self.debug_permute == 3:
a,b,c,d = m_local, n0, n2, n1
elif self.debug_permute == 4:
a,b,c,d = n2, n0, n1, m_local
elif self.debug_permute == 5:
a,b,c,d = n1, n0, m_local, n2
elif self.debug_permute == 6:
a,b,c,d = n0, n1, n2, m_local
elif self.debug_permute == 7:
a,b,c,d = n2, n1, n0, m_local
else:
a,b,c,d = m_local, n0, n1, n2
pv_coord = ((a, b), 0, (c, d), 0)
# Write normalized P value
p_val_bf16 = p_val.to(self.q_dtype)
sP[pv_coord] = p_val_bf16 # Tensor indexing
# DEBUG: Print first few coordinates to verify mapping
if self.use_smem_p and k < 2 and j < 2:
print(f"[SMEM-P DEBUG] k={k}, j={j}, qk_coord=({m},{n}), pv_coord={pv_coord}")
# Try to compute offset using crd2idx
try:
offset = cute.crd2idx(pv_coord, sP.layout)
print(f"[SMEM-P DEBUG] offset = {offset}")
except:
print(f"[SMEM-P DEBUG] crd2idx not available")
else:
# For TMEM-P, store normalized P to register buffer
rP_bf16_frg[k, j] = p_val.to(self.q_dtype)
if not self.use_smem_p:
# TMEM-P: store P to TMEM via register bridge
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: Already wrote P values to SMEM in softmax loop
# Just need fence and barrier
print(f"[SMEM-P CUTLASS] P values already written to SMEM, proceeding to fence")
# DEBUG: Compute offset for known coordinate to verify mapping
test_coord = ((0,0), 0, (0,0), 0)
test_offset = cute.crd2idx(test_coord, sP.layout)
print(f"[SMEM-P DEBUG] test_coord {test_coord} -> offset {test_offset}")
# SMEM-P: zero-fill sP (proper SMEM-P copy TBD)
# The TMEM-P path works for hd<=64. SMEM-P needs layout-aware copy.
for j in cutlass.range(cute.size(sP), vectorize=True):
sP[j] = self.q_dtype(0)
cute.arch.fence_proxy("async.shared", space="cta")
# Barrier for both TMEM-P and SMEM-P paths
softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout)
softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout)
if kt > 0:
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
@@ -495,68 +370,66 @@ class FmhaKernel:
# Wait for MMA's PV[N-1] to commit before reading O.
final_o_bar.arrive_and_wait()
# === NO-OP TMEM round-trip: re-map O from MMA layout to epilog layout ===
tTMrO_noop = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO_noop[None, i]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO_noop.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(
tTMEM_LOADtO.iterator + i * corr_tile_size,
tTMEM_LOADtO.layout,
)
tTMEM_STOREtO_i = cute.make_tensor(
tTMEM_STOREtO.iterator + i * corr_tile_size,
tTMEM_STOREtO.layout,
)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
# === Final O normalization: O *= 1/row_sum ===
# === Correction epilog: one-way TMEM -> reg -> SMEM -> GMEM ===
# Uses epilogue_tmem_copy_and_partition (get_tmem_load_op) for correct TMEM read.
# Uses epilogue_smem_copy_and_partition (get_smem_store_op) for correct SMEM write.
# No TMEM round-trip. No layout mismatch. No 3% error.
inv_row_sum = Float32(1.0) / row_sum
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO[None, i]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(
tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout
)
tTMEM_STOREtO_i = cute.make_tensor(
tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout
)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[j] = tTMrO_i[j] * inv_row_sum
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
# Epilogue: TMEM → SMEM → GMEM via TMA store.
# Set up the TMEM→reg and reg→SMEM copy atoms using CUTLASS helpers
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage
tiled_copy_t2r, tTR_tAcc, tTR_rAcc = epilogue_tmem_copy_and_partition(
self, sfw_idx, tCtO_base, tCgC, epi_tile, self.use_2cta_instrs
)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile,
0, const_expr(lambda x: x), (0, 0, 0),
acc_cons_st, acc_pipe, c_pipe,
tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype)
tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition(
self, tiled_copy_t2r, tTR_rC, sfw_idx, sC
)
c_pipe.producer_tail()
# Wait for accumulator buffer
acc_pipe.consumer_wait(pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage))
# Process each subtile: TMEM load -> normalize -> BF16 convert -> SMEM store
tTR_tAcc_g = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
subtile_cnt = cute.size(tTR_tAcc_g.shape, mode=[3])
for subtile_idx in range(subtile_cnt):
tTR_tAcc_mn = tTR_tAcc_g[(None, None, None, subtile_idx)]
cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
# Normalize: O *= 1/row_sum
for j in cutlass.range(cute.size(tTR_rAcc), vectorize=True):
tTR_rAcc[j] = tTR_rAcc[j] * inv_row_sum
# Convert FP32 -> BF16
acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
tRS_rC.store(acc_vec.to(self.c_dtype))
# Store to SMEM
c_buffer = subtile_idx % self.num_c_stage
cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)])
cute.arch.fence_proxy("async.shared", space="cta")
# TMA store from SMEM to GMEM
# Partition sC and gC for TMA store
tCgC_epi = cute.flat_divide(tCgC, epi_tile)
bSG_sC, bSG_gC = cpasync.tma_partition(
tma_c, 0, cute.make_layout(1),
cute.group_modes(sC, 0, 2),
cute.group_modes(tCgC_epi, 0, 2),
)
# Only warp 0 of epilogue issues TMA store
if warp_idx == self.epilogue_warp_id[0]:
cute.copy(tma_c, bSG_sC[(None, c_buffer)], bSG_gC[(None, subtile_idx)])
# Sync after TMA store
epilog_sync_bar = pipeline.NamedBarrier(
barrier_id=self.epilog_sync_bar_id,
num_threads=32 * len(self.epilogue_warp_id),
)
epilog_sync_bar.arrive_and_wait()
# Release accumulator buffer
with cute.arch.elect_one():
acc_pipe.consumer_release(pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage))
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)