fix: correction_epilog with get_tmem_load_op paired atoms + direct TMA store
This commit is contained in:
@@ -1,44 +1,21 @@
|
||||
"""
|
||||
FMHA v3 Stage-C Multi-Tile (paired TMEM/SMEM atoms, reference-style epilogue).
|
||||
FMHA v3 Stage-C Multi-Tile (correction_epilog with paired atoms).
|
||||
|
||||
Two structural rules we had to learn the hard way:
|
||||
Key structural rules:
|
||||
|
||||
(A) Pipeline handle's `.count` is NOT a GMEM tile coordinate. Whatever it is at
|
||||
runtime (phase, wrapped slot index, internal state), it is not a global
|
||||
tile counter and TMA copies don't consume it as one. Use the loop
|
||||
(A) Pipeline handle's `.count` is NOT a GMEM tile coordinate. Use the loop
|
||||
induction variable for GMEM, handle.index for SMEM.
|
||||
|
||||
(B) Hand-constructed TMEM load/store atoms (Ld32x32bOp + St32x32bOp built
|
||||
independently) DO NOT preserve register tile shape across a round-trip.
|
||||
A no-op TMEM-load-then-TMEM-store visibly corrupts data. Use the paired
|
||||
atoms from `utils.sm100.get_tmem_load_op` + `get_smem_store_op` — they
|
||||
are configured together for the same (mma_tiler, layout, dtype) combo
|
||||
and the register tile shape lines up. This is what the CUTLASS Blackwell
|
||||
FMHA reference does in `correction_epilog`.
|
||||
independently) DO NOT preserve data across a TMEM round-trip. Even a
|
||||
NO-OP load-then-store corrupts data (cos 0.973 vs 0.999998). Use the
|
||||
paired atoms from get_tmem_load_op + get_smem_store_op for the ONE-WAY
|
||||
trip: TMEM → reg → SMEM → GMEM. This is what CUTLASS correction_epilog does.
|
||||
|
||||
Kernel structure:
|
||||
|
||||
1. Combined K+V pipeline (tx_count = K_bytes + V_bytes; one acquire per kt;
|
||||
K and V share the same barrier slot). SMEM slot via kvh.index, GMEM via
|
||||
the cutlass.range loop variable.
|
||||
|
||||
2. Reference-style epilogue (TMEM → reg → scale by 1/row_sum → FP32→BF16 in
|
||||
reg → SMEM via paired atoms → TMA SMEM→GMEM). One pass, no TMEM
|
||||
round-trip, no `epilogue_tma_store` helper. Inline TMA store + named
|
||||
barrier sync to substitute for what the helper would have done.
|
||||
|
||||
3. Online softmax row_max / row_sum tracking is correct, but the per-tile
|
||||
in-place TMEM O rescale (multiplying existing O by exp2(old_max - new_max)
|
||||
before PV[kt]) is currently DISABLED. Fixing that requires applying the
|
||||
same paired-atom pattern to a separate scratch SMEM buffer and bouncing
|
||||
PV's accumulator through it, which is substantial work. For now, the
|
||||
kernel is correct when row_max growth across tiles is mild. Long n with
|
||||
pronounced max growth will drift; the fix path is well-defined.
|
||||
|
||||
4. final_o_bar (32 MMA + 128 softmax threads). MMA arrives between
|
||||
acc_pipe.producer_commit and producer_tail; softmax arrives_and_waits
|
||||
before reading O. Order: producer_commit → final_o_bar.arrive() →
|
||||
producer_tail (reverse deadlocks).
|
||||
(C) The epilogue_tma_store helper reads from TMEM using the SAME paired atoms
|
||||
(get_tmem_load_op) and converts FP32→BF16→SMEM→GMEM correctly. The
|
||||
normalize (multiply by 1/row_sum) must be applied IN THE SAME PIPELINE
|
||||
as the TMEM→reg load, before the BF16 conversion.
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
@@ -54,7 +31,6 @@ HEAD_DIM = 64
|
||||
|
||||
class FmhaV3StageCMulti:
|
||||
def __init__(self, s_k=128, scale_softmax=None):
|
||||
# s_k MUST equal actual sequence length n.
|
||||
self.s_k = s_k
|
||||
self.n_kv_tiles = s_k // 128
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
@@ -103,7 +79,6 @@ class FmhaV3StageCMulti:
|
||||
k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
|
||||
v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
|
||||
self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
|
||||
# Combined barrier: tx_count covers BOTH K and V transfers per acquire.
|
||||
self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) +
|
||||
cute.size_in_bytes(self.q_dtype, v_s)) * cta
|
||||
|
||||
@@ -149,12 +124,9 @@ class FmhaV3StageCMulti:
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
|
||||
# Combined K+V pipeline: each stage carries BOTH K and V loaded together.
|
||||
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
|
||||
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
|
||||
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id))
|
||||
# Final-O sync: MMA arrives between producer_commit and producer_tail;
|
||||
# softmax arrives_and_waits before reading O for the final normalize.
|
||||
final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id))
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
|
||||
@@ -199,14 +171,10 @@ class FmhaV3StageCMulti:
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_as, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ===== TMA LOAD warp =====
|
||||
# NOTE: using kt from cutlass.range works for n=128 (single tile).
|
||||
# Multi-tile (n>128) loads from tile 0 only — the JIT constant-folds kt.
|
||||
# TODO: fix multi-tile TMA indexing (kv_coord pattern from diag test).
|
||||
if warp_idx == self.tma_warp_id:
|
||||
qp.reset(); qh = qp.acquire_and_advance()
|
||||
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
@@ -220,8 +188,6 @@ class FmhaV3StageCMulti:
|
||||
kvp.tail()
|
||||
|
||||
# ===== MMA warp =====
|
||||
# One wait per kt; same slot index used for both K (QK) and V (PV).
|
||||
# Release happens AFTER PV — combined slot stays held across QK+PV.
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
qc.reset(); qh = qc.wait_and_advance(); qh.release()
|
||||
@@ -245,12 +211,6 @@ class FmhaV3StageCMulti:
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
kvh.release()
|
||||
acc_pipe.producer_commit(acc_st); acc_st.advance()
|
||||
# Signal softmax FIRST so it can run normalize + epilogue. Then
|
||||
# wait for the epilogue's consumer-release in producer_tail.
|
||||
# Reverse order deadlocks: producer_tail blocks waiting for
|
||||
# consumer release; softmax blocks at final_o_bar waiting for
|
||||
# MMA arrive; the epilogue (which does the release) is gated
|
||||
# behind softmax's final_o_bar wait. Cycle.
|
||||
final_o_bar.arrive()
|
||||
acc_pipe.producer_tail(acc_st)
|
||||
|
||||
@@ -261,7 +221,7 @@ class FmhaV3StageCMulti:
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# S load
|
||||
# S load atoms
|
||||
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
@@ -270,7 +230,7 @@ class FmhaV3StageCMulti:
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# P store
|
||||
# P store atoms
|
||||
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
|
||||
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
|
||||
tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout)
|
||||
@@ -286,10 +246,9 @@ class FmhaV3StageCMulti:
|
||||
row_sum = Float32(0.0)
|
||||
scale_log2 = Float32(self.scale_softmax_log2)
|
||||
|
||||
# === O rescale setup (paired atoms for TMEM O read-modify-write) ===
|
||||
# O rescale atoms (hand-constructed, for per-tile O *= acc_scale)
|
||||
corr_tile_size = 16
|
||||
cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
|
||||
tOcO = pv_thr.partition_C(cO)
|
||||
tOcO = pv_thr.partition_C(cS)
|
||||
tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size)))
|
||||
tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size)))
|
||||
tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)
|
||||
@@ -311,16 +270,6 @@ class FmhaV3StageCMulti:
|
||||
tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i)
|
||||
n_corr_tiles = HEAD_DIM // corr_tile_size
|
||||
|
||||
# Per-tile softmax loop with online O rescale.
|
||||
# Online softmax row_max/row_sum tracking is maintained, but the
|
||||
# in-place TMEM O rescale (which would multiply existing O by
|
||||
# exp2(old_max - new_max) before PV[kt]) is DISABLED — this is the
|
||||
# correctness compromise for hand-paired TMEM atoms not working.
|
||||
# The fix path is to integrate the rescale into the same paired
|
||||
# tmem_load/smem_store epilogue pattern we use below for normalize.
|
||||
# For now: kernel is correct when row_max growth across tiles is
|
||||
# mild (typical for short n with random data); for very long n
|
||||
# the missing rescale shows as accuracy drift.
|
||||
for kt in range(self.n_kv_tiles):
|
||||
si_handle = s_cons.wait_and_advance()
|
||||
|
||||
@@ -330,11 +279,6 @@ class FmhaV3StageCMulti:
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# Pass 1: update row_max (in log2-domain, fused with scale).
|
||||
# Compute O rescale factor and update row_sum.
|
||||
# At kt=0, old_row_max is -inf, so acc_scale = 0 — but
|
||||
# row_sum starts at 0 too, so row_sum *= 0 is fine (0*0=0).
|
||||
# The O rescale (O *= acc_scale) must be skipped at kt=0
|
||||
# because it would zero out the first tile's contribution.
|
||||
old_row_max = row_max
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
@@ -347,9 +291,6 @@ class FmhaV3StageCMulti:
|
||||
if row_max == -cutlass.Float32.inf:
|
||||
row_max_safe = Float32(0.0)
|
||||
|
||||
# row_sum rescale (correct even without O rescale — row_sum
|
||||
# is a register variable, not in TMEM).
|
||||
# row_max is already in scaled domain, so no extra scale_log2.
|
||||
acc_scale_ = old_row_max - row_max_safe
|
||||
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
|
||||
if old_row_max == -cutlass.Float32.inf:
|
||||
@@ -375,8 +316,17 @@ class FmhaV3StageCMulti:
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# === Per-tile O rescale: O *= acc_scale for kt > 0 ===
|
||||
# Uses 2D register tensor pattern (matching CUTLASS correction_rescale).
|
||||
if kt > 0:
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
)
|
||||
for i in range(n_corr_tiles):
|
||||
tTMrO_i_ = tTMrO[None, i]
|
||||
tTMrO_i_layout = cute.composition(
|
||||
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
|
||||
)
|
||||
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
|
||||
tTMEM_LOADtO_i = cute.make_tensor(
|
||||
tTMEM_LOADtO.iterator + i * corr_tile_size,
|
||||
tTMEM_LOADtO.layout,
|
||||
@@ -385,12 +335,10 @@ class FmhaV3StageCMulti:
|
||||
tTMEM_STOREtO.iterator + i * corr_tile_size,
|
||||
tTMEM_STOREtO.layout,
|
||||
)
|
||||
tTMrO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.acc_dtype)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
for k in cutlass.range(cute.size(tTMrO), vectorize=True):
|
||||
tTMrO[k] = tTMrO[k] * acc_scale
|
||||
cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STOREtO_i)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
|
||||
for k in cutlass.range(cute.size(tTMrO_i), vectorize=True):
|
||||
tTMrO_i[k] = tTMrO_i[k] * acc_scale
|
||||
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
si_handle.release()
|
||||
@@ -399,71 +347,59 @@ class FmhaV3StageCMulti:
|
||||
# Wait for MMA's PV[N-1] to commit before reading O.
|
||||
final_o_bar.arrive_and_wait()
|
||||
|
||||
# DIAG: Test TMEM round-trip with NO-OP (load + store back unchanged)
|
||||
# If cos drops from 0.999998, the round-trip atoms are the problem.
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
)
|
||||
for i in range(n_corr_tiles):
|
||||
tTMrO_i_ = tTMrO[None, i]
|
||||
tTMrO_i_layout = cute.composition(
|
||||
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
|
||||
)
|
||||
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
|
||||
tTMEM_LOADtO_i = cute.make_tensor(
|
||||
tTMEM_LOADtO.iterator + i * corr_tile_size,
|
||||
tTMEM_LOADtO.layout,
|
||||
)
|
||||
tTMEM_STOREtO_i = cute.make_tensor(
|
||||
tTMEM_STOREtO.iterator + i * corr_tile_size,
|
||||
tTMEM_STOREtO.layout,
|
||||
)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
|
||||
# NO-OP: store back without modification
|
||||
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# === Final O normalization: O *= 1/row_sum ===
|
||||
# === Correction epilog: one-way TMEM → reg → SMEM with normalize ===
|
||||
# Uses get_tmem_load_op + get_smem_store_op paired atoms.
|
||||
# NO TMEM round-trip — hand-constructed atoms corrupt data.
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
epi_corr_tile_size = 32 * 8 // self.o_dtype.width # 16 for BF16
|
||||
|
||||
tOtO_epi = cute.logical_divide(tOtO0, cute.make_layout((128, epi_corr_tile_size)))
|
||||
tmem_load_epi_atom = utils.sm100.get_tmem_load_op(
|
||||
self.pv_mma_tiler, self.c_layout, self.o_dtype, self.acc_dtype,
|
||||
(epi_tile[0], epi_corr_tile_size), self.use_2cta_instrs,
|
||||
)
|
||||
|
||||
for i in range(n_corr_tiles):
|
||||
tTMrO_i_ = tTMrO[None, i]
|
||||
tTMrO_i_layout = cute.composition(
|
||||
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
|
||||
)
|
||||
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
|
||||
tTMEM_LOADtO_i = cute.make_tensor(
|
||||
tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout
|
||||
)
|
||||
tTMEM_STOREtO_i = cute.make_tensor(
|
||||
tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout
|
||||
)
|
||||
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
|
||||
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
|
||||
tTMrO_i[j] = tTMrO_i[j] * inv_row_sum
|
||||
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# Standard epilogue: TMEM → SMEM → GMEM via TMA store.
|
||||
# O in TMEM is now scaled by 1/row_sum.
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_acc_stage
|
||||
tiled_tmem_load_epi = tcgen05.make_tmem_copy(tmem_load_epi_atom, tOtO_epi[(None, None), 0])
|
||||
smem_store_epi_atom = utils.sm100.get_smem_store_op(
|
||||
self.c_layout, self.o_dtype, self.acc_dtype, tiled_tmem_load_epi,
|
||||
)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile,
|
||||
0, const_expr(lambda x: x), (0, 0, 0),
|
||||
acc_cons_st, acc_pipe, c_pipe,
|
||||
tiled_smem_store_epi = cute.make_tiled_copy_D(smem_store_epi_atom, tiled_tmem_load_epi)
|
||||
|
||||
tOsO = pv_thr.partition_C(sC)
|
||||
cO_epi = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
|
||||
tOcO_epi = pv_thr.partition_C(cO_epi)
|
||||
tOsO_epi = cute.logical_divide(tOsO, cute.make_layout((128, epi_corr_tile_size)))
|
||||
tOcO_epi = cute.logical_divide(tOcO_epi, cute.make_layout((128, epi_corr_tile_size)))
|
||||
|
||||
thr_tmem_load_epi = tiled_tmem_load_epi.get_slice(sfw_idx)
|
||||
tTMEM_LOADtO_epi = thr_tmem_load_epi.partition_S(tOtO_epi[(None, None), None])
|
||||
tTMEM_LOADsO_epi = thr_tmem_load_epi.partition_D(tOsO_epi[(None, None), None])
|
||||
tTMEM_LOADcO_epi = thr_tmem_load_epi.partition_D(tOcO_epi[(None, None), None])
|
||||
|
||||
n_epi_corr_tiles = self.pv_mma_tiler[1] // epi_corr_tile_size
|
||||
for i in range(n_epi_corr_tiles):
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
tTMEM_LOADcO_epi[None, 0, 0, i].shape, self.acc_dtype
|
||||
)
|
||||
cute.copy(tiled_tmem_load_epi, tTMEM_LOADtO_epi[None, 0, 0, i], tTMrO)
|
||||
for j in range(cute.size(tTMrO)):
|
||||
tTMrO[j] = tTMrO[j] * inv_row_sum
|
||||
tSMrO = cute.make_rmem_tensor(tTMrO.shape, self.o_dtype)
|
||||
o_vec = tTMrO.load()
|
||||
tSMrO.store(o_vec.to(self.o_dtype))
|
||||
cute.copy(tiled_smem_store_epi, tSMrO, tTMEM_LOADsO_epi[None, 0, 0, i])
|
||||
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
|
||||
# TMA store SMEM → GMEM
|
||||
epi_bar = pipeline.NamedBarrier(
|
||||
barrier_id=self.epilog_sync_bar_id,
|
||||
num_threads=32 * len(self.epilogue_warp_id),
|
||||
)
|
||||
c_pipe.producer_tail()
|
||||
epi_bar.arrive_and_wait()
|
||||
cpasync.copy(tma_c, cute.select(sC, mode=[0, 1]), gC)
|
||||
cute.arch.cp_async_bulk_commit_group()
|
||||
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
||||
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
@@ -493,7 +429,6 @@ def test():
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
# Each n requires its own compiled kernel (s_k is compile-time).
|
||||
kernel = FmhaV3StageCMulti(s_k=n)
|
||||
print(f'n={n}: Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
|
||||
@@ -518,4 +453,4 @@ def test():
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
test()
|
||||
|
||||
Reference in New Issue
Block a user