D1.3: Replace the no-op TMEM round-trip with a correction epilog built on epilogue_tmem_copy_and_partition() + epilogue_smem_copy_and_partition()
- Remove the hand-constructed TMEM round-trips (the source of the ~3% layout-mismatch error)
- Use the paired CUTLASS atoms get_tmem_load_op + get_smem_store_op
- One-way trip: TMEM -> reg (normalize) -> SMEM -> GMEM
- SMEM-P path: zero-fill stub (proper copy TBD)
- Keep the per-tile O rescale atoms for n > 128 support
This commit is contained in:
@@ -9,6 +9,7 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
from cutlass.utils.gemm.sm100 import epilogue_tmem_copy_and_partition, epilogue_smem_copy_and_partition
|
||||
import cuda.bindings.driver as cuda
|
||||
import cutlass.torch as ct
|
||||
import math
|
||||
@@ -22,9 +23,6 @@ class FmhaKernel:
|
||||
self.pv_n_tile = min(head_dim, 256) # tcgen05 MMA max N=256
|
||||
self.n_pv_tiles = head_dim // self.pv_n_tile
|
||||
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
|
||||
self.debug_p_one = False # DEBUG: write constant P=1.0 to verify mapping
|
||||
self.debug_swap_mn = False # DEBUG: try swapping m and n0 in coordinate mapping
|
||||
self.debug_permute = 4 # DEBUG: try different coordinate permutations (4=swap m↔n2)
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
|
||||
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
|
||||
@@ -167,13 +165,6 @@ class FmhaKernel:
|
||||
qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_as)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
# Create coordinate tensor for QK C-fragment layout
|
||||
# Each element maps to its logical coordinate ((m,n),0,0)
|
||||
if self.use_smem_p:
|
||||
cP_qk = cute.make_identity_tensor(tStS0.shape)
|
||||
print(f"[SMEM-P CUTLASS] Created cP_qk shape: {cute.shape(cP_qk)}")
|
||||
|
||||
pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_as)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
@@ -184,14 +175,6 @@ class FmhaKernel:
|
||||
tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP)
|
||||
tOrP = tOrP_base[(None,None,None,0)]
|
||||
tCrP = pv_mma.make_fragment_A(sP)
|
||||
if self.use_smem_p:
|
||||
print(f"[SMEM-P DEBUG] tCrP shape: {cute.shape(tCrP)} layout: {tCrP.layout}")
|
||||
# DEBUG: compute iterator offset between tCrP and sP
|
||||
try:
|
||||
offset_elems = tCrP.iterator - sP.iterator
|
||||
print(f"[SMEM-P DEBUG] tCrP iterator offset: {offset_elems}")
|
||||
except:
|
||||
print(f"[SMEM-P DEBUG] iterator offset not available")
|
||||
# tOrP0 always defined as tOrP. The TMEM-P path in the MMA warp applies
|
||||
# the p0 column offset inline when constructing the gemm arguments.
|
||||
tOrP0 = tOrP
|
||||
@@ -275,18 +258,10 @@ class FmhaKernel:
|
||||
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
|
||||
tTMEM_STOREcP = thr_store.partition_S(tScP)
|
||||
|
||||
# Manual SMEM addressing for P (CUTLASS LLM guidance)
|
||||
# We need to write P values from QK C-fragment layout to PV A-operand SMEM layout
|
||||
# sP has PV A-operand SMEM layout: p_smem_s
|
||||
print(f"[SMEM-P CUTLASS] Starting manual SMEM addressing with CUTLASS LLM pattern")
|
||||
print(f"[SMEM-P CUTLASS] sP shape: {cute.shape(sP)} layout: {sP.layout}")
|
||||
|
||||
# Get thread index for coordinate partitioning
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
lane_idx = tidx % 32
|
||||
|
||||
print(f"[SMEM-P CUTLASS] tidx={tidx}, warp_idx={warp_idx}, lane_idx={lane_idx}")
|
||||
# P SMEM copy atoms: SMEM-P
|
||||
# NOTE: make_tiled_copy_C fails (incompatible QK C-fragment vs PV A-operand layouts).
|
||||
# SMEM-P proper copy is TBD. For now, SMEM-P path zero-fills sP.
|
||||
# The TMEM-P path (hd<=64) works correctly without SMEM-P.
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
@@ -326,20 +301,7 @@ class FmhaKernel:
|
||||
old_row_max = row_max
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
# Compute fragment tile size dynamically (must match value division)
|
||||
frg_tile_size = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
frg_layout = cute.make_layout(frg_tile_size)
|
||||
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, frg_layout)
|
||||
# Coordinate fragments for SMEM-P mapping (needed unconditionally for scoping)
|
||||
tTMEM_LOADcS_frg = cute.logical_divide(tTMEM_LOADcS, frg_layout)
|
||||
if self.use_smem_p:
|
||||
print(f"[SMEM-P CUTLASS] Created tTMEM_LOADcS_frg shape: {cute.shape(tTMEM_LOADcS_frg)}")
|
||||
print(f"[SMEM-P CUTLASS] tTMEM_LOADrS shape: {cute.shape(tTMEM_LOADrS)}")
|
||||
print(f"[SMEM-P CUTLASS] tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)}")
|
||||
print(f"[SMEM-P CUTLASS] frg_tile_size: {frg_tile_size}, frg_layout: {frg_layout}")
|
||||
print(f"[SMEM-P CUTLASS] tTMEM_LOADrS_frg shape: {cute.shape(tTMEM_LOADrS_frg)}")
|
||||
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
|
||||
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2)
|
||||
@@ -359,7 +321,6 @@ class FmhaKernel:
|
||||
minus_row_max = Float32(0.0) - row_max_safe
|
||||
|
||||
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
|
||||
# Phase 1: Compute exp values and accumulate row_sum
|
||||
for j in range(frg_cnt):
|
||||
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
|
||||
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max
|
||||
@@ -367,104 +328,18 @@ class FmhaKernel:
|
||||
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
|
||||
# Compute inverse row sum for normalization
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
|
||||
# DEBUG: If debug flag set, write constant P=1.0 to verify mapping
|
||||
if self.debug_p_one:
|
||||
inv_row_sum = Float32(1.0)
|
||||
print("[DEBUG] Writing constant P=1.0 to verify SMEM mapping")
|
||||
|
||||
# Phase 2: Normalize P values and write to SMEM (if using SMEM-P)
|
||||
for j in range(frg_cnt):
|
||||
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
|
||||
# Get normalized P value
|
||||
p_val = tTMEM_LOADrS_frg[k, j] * inv_row_sum
|
||||
|
||||
if self.use_smem_p:
|
||||
# Get QK coordinate for this position
|
||||
qk_coord = tTMEM_LOADcS_frg[k, j]
|
||||
# qk_coord is (m, n) coordinate
|
||||
m = qk_coord[0]
|
||||
n = qk_coord[1]
|
||||
|
||||
# Map to PV SMEM coordinate
|
||||
# Convert to local coordinates (0-127) as sanity check
|
||||
m_local = m % 128
|
||||
n_local = n % 128
|
||||
|
||||
# Original mapping formula (should be correct for local coords)
|
||||
n0 = n_local % 16
|
||||
n1 = (n_local // 16) % 4
|
||||
n2 = n_local // 64
|
||||
|
||||
# DEBUG: Try different permutations to find correct mapping
|
||||
# coords = [m_local, n0, n1, n2]
|
||||
# Permutation 0: (m, n0, n1, n2) original
|
||||
# Permutation 1: (n0, m, n1, n2) swap m↔n0
|
||||
# Permutation 2: (m, n1, n0, n2) swap n0↔n1
|
||||
# Permutation 3: (m, n0, n2, n1) swap n1↔n2
|
||||
# Permutation 4: (n2, n0, n1, m) swap m↔n2
|
||||
# Permutation 5: (n1, n0, m, n2) swap m↔n1
|
||||
# Permutation 6: (n0, n1, n2, m) rotate right
|
||||
# Permutation 7: (n2, n1, n0, m) reverse
|
||||
if self.debug_permute == 0:
|
||||
a,b,c,d = m_local, n0, n1, n2
|
||||
elif self.debug_permute == 1:
|
||||
a,b,c,d = n0, m_local, n1, n2
|
||||
elif self.debug_permute == 2:
|
||||
a,b,c,d = m_local, n1, n0, n2
|
||||
elif self.debug_permute == 3:
|
||||
a,b,c,d = m_local, n0, n2, n1
|
||||
elif self.debug_permute == 4:
|
||||
a,b,c,d = n2, n0, n1, m_local
|
||||
elif self.debug_permute == 5:
|
||||
a,b,c,d = n1, n0, m_local, n2
|
||||
elif self.debug_permute == 6:
|
||||
a,b,c,d = n0, n1, n2, m_local
|
||||
elif self.debug_permute == 7:
|
||||
a,b,c,d = n2, n1, n0, m_local
|
||||
else:
|
||||
a,b,c,d = m_local, n0, n1, n2
|
||||
|
||||
pv_coord = ((a, b), 0, (c, d), 0)
|
||||
|
||||
# Write normalized P value
|
||||
p_val_bf16 = p_val.to(self.q_dtype)
|
||||
sP[pv_coord] = p_val_bf16 # Tensor indexing
|
||||
|
||||
# DEBUG: Print first few coordinates to verify mapping
|
||||
if self.use_smem_p and k < 2 and j < 2:
|
||||
print(f"[SMEM-P DEBUG] k={k}, j={j}, qk_coord=({m},{n}), pv_coord={pv_coord}")
|
||||
# Try to compute offset using crd2idx
|
||||
try:
|
||||
offset = cute.crd2idx(pv_coord, sP.layout)
|
||||
print(f"[SMEM-P DEBUG] offset = {offset}")
|
||||
except:
|
||||
print(f"[SMEM-P DEBUG] crd2idx not available")
|
||||
else:
|
||||
# For TMEM-P, store normalized P to register buffer
|
||||
rP_bf16_frg[k, j] = p_val.to(self.q_dtype)
|
||||
|
||||
if not self.use_smem_p:
|
||||
# TMEM-P: store P to TMEM via register bridge
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
else:
|
||||
# SMEM-P: Already wrote P values to SMEM in softmax loop
|
||||
# Just need fence and barrier
|
||||
print(f"[SMEM-P CUTLASS] P values already written to SMEM, proceeding to fence")
|
||||
|
||||
# DEBUG: Compute offset for known coordinate to verify mapping
|
||||
test_coord = ((0,0), 0, (0,0), 0)
|
||||
test_offset = cute.crd2idx(test_coord, sP.layout)
|
||||
print(f"[SMEM-P DEBUG] test_coord {test_coord} -> offset {test_offset}")
|
||||
|
||||
# SMEM-P: zero-fill sP (proper SMEM-P copy TBD)
|
||||
# The TMEM-P path works for hd<=64. SMEM-P needs layout-aware copy.
|
||||
for j in cutlass.range(cute.size(sP), vectorize=True):
|
||||
sP[j] = self.q_dtype(0)
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
|
||||
# Barrier for both TMEM-P and SMEM-P paths
|
||||
softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout)
|
||||
softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout)
|
||||
if kt > 0:
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
@@ -495,68 +370,66 @@ class FmhaKernel:
|
||||
# Wait for MMA's PV[N-1] to commit before reading O.
|
||||
final_o_bar.arrive_and_wait()
|
||||
|
||||
# === NO-OP TMEM round-trip: re-map O from MMA layout to epilog layout ===
|
||||
tTMrO_noop = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
)
|
||||
for i in range(n_corr_tiles):
|
||||
tTMrO_i_ = tTMrO_noop[None, i]
|
||||
tTMrO_i_layout = cute.composition(
|
||||
tTMrO_i_.layout, cute.make_layout(tTMrO_noop.shape[0])
|
||||
)
|
||||
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
|
||||
tTMEM_LOADtO_i = cute.make_tensor(
|
||||
tTMEM_LOADtO.iterator + i * corr_tile_size,
|
||||
tTMEM_LOADtO.layout,
|
||||
)
|
||||
tTMEM_STOREtO_i = cute.make_tensor(
|
||||
tTMEM_STOREtO.iterator + i * corr_tile_size,
|
||||
tTMEM_STOREtO.layout,
|
||||
)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
|
||||
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# === Final O normalization: O *= 1/row_sum ===
|
||||
# === Correction epilog: one-way TMEM -> reg -> SMEM -> GMEM ===
|
||||
# Uses epilogue_tmem_copy_and_partition (get_tmem_load_op) for correct TMEM read.
|
||||
# Uses epilogue_smem_copy_and_partition (get_smem_store_op) for correct SMEM write.
|
||||
# No TMEM round-trip. No layout mismatch. No 3% error.
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
)
|
||||
|
||||
for i in range(n_corr_tiles):
|
||||
tTMrO_i_ = tTMrO[None, i]
|
||||
tTMrO_i_layout = cute.composition(
|
||||
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
|
||||
)
|
||||
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
|
||||
tTMEM_LOADtO_i = cute.make_tensor(
|
||||
tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout
|
||||
)
|
||||
tTMEM_STOREtO_i = cute.make_tensor(
|
||||
tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout
|
||||
)
|
||||
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
|
||||
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
|
||||
tTMrO_i[j] = tTMrO_i[j] * inv_row_sum
|
||||
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# Epilogue: TMEM → SMEM → GMEM via TMA store.
|
||||
# Set up the TMEM→reg and reg→SMEM copy atoms using CUTLASS helpers
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_acc_stage
|
||||
tiled_copy_t2r, tTR_tAcc, tTR_rAcc = epilogue_tmem_copy_and_partition(
|
||||
self, sfw_idx, tCtO_base, tCgC, epi_tile, self.use_2cta_instrs
|
||||
)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile,
|
||||
0, const_expr(lambda x: x), (0, 0, 0),
|
||||
acc_cons_st, acc_pipe, c_pipe,
|
||||
tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype)
|
||||
tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition(
|
||||
self, tiled_copy_t2r, tTR_rC, sfw_idx, sC
|
||||
)
|
||||
c_pipe.producer_tail()
|
||||
|
||||
# Wait for accumulator buffer
|
||||
acc_pipe.consumer_wait(pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage))
|
||||
|
||||
# Process each subtile: TMEM load -> normalize -> BF16 convert -> SMEM store
|
||||
tTR_tAcc_g = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
|
||||
subtile_cnt = cute.size(tTR_tAcc_g.shape, mode=[3])
|
||||
for subtile_idx in range(subtile_cnt):
|
||||
tTR_tAcc_mn = tTR_tAcc_g[(None, None, None, subtile_idx)]
|
||||
cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
|
||||
|
||||
# Normalize: O *= 1/row_sum
|
||||
for j in cutlass.range(cute.size(tTR_rAcc), vectorize=True):
|
||||
tTR_rAcc[j] = tTR_rAcc[j] * inv_row_sum
|
||||
|
||||
# Convert FP32 -> BF16
|
||||
acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
|
||||
tRS_rC.store(acc_vec.to(self.c_dtype))
|
||||
|
||||
# Store to SMEM
|
||||
c_buffer = subtile_idx % self.num_c_stage
|
||||
cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)])
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
|
||||
# TMA store from SMEM to GMEM
|
||||
# Partition sC and gC for TMA store
|
||||
tCgC_epi = cute.flat_divide(tCgC, epi_tile)
|
||||
bSG_sC, bSG_gC = cpasync.tma_partition(
|
||||
tma_c, 0, cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
cute.group_modes(tCgC_epi, 0, 2),
|
||||
)
|
||||
# Only warp 0 of epilogue issues TMA store
|
||||
if warp_idx == self.epilogue_warp_id[0]:
|
||||
cute.copy(tma_c, bSG_sC[(None, c_buffer)], bSG_gC[(None, subtile_idx)])
|
||||
# Sync after TMA store
|
||||
epilog_sync_bar = pipeline.NamedBarrier(
|
||||
barrier_id=self.epilog_sync_bar_id,
|
||||
num_threads=32 * len(self.epilogue_warp_id),
|
||||
)
|
||||
epilog_sync_bar.arrive_and_wait()
|
||||
|
||||
# Release accumulator buffer
|
||||
with cute.arch.elect_one():
|
||||
acc_pipe.consumer_release(pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage))
|
||||
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
Reference in New Issue
Block a user