restore: D1.5 version (8165262)

This commit is contained in:
2026-05-24 03:45:58 +00:00
parent 7b0b3ad1fc
commit 65663226b8

View File

@@ -113,29 +113,19 @@ class FmhaKernel:
tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape)
epi_s = cute.select(self.c_smem_s,mode=[0,1])
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
# SMEM-P: TMA for P (GMEM→SMEM). We reuse Q's GMEM buffer as gP.
if const_expr(self.use_smem_p):
p_s = cute.slice_(self.p_smem_s,(None,None,None,0))
tma_p,gP_tma = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileG2SOp(), mQ, p_s, self.qk_mma_tiler)
else:
tma_p = tma_q # dummy
gP_tma = mQ, dead code
# Always create a valid mLSE tensor for the kernel.
# CuTeDSL doesn't support None parameters in @cute.kernel.
# For normalize=True, mLSE is unused (dead-code-eliminated by compiler).
if const_expr(lse is None):
lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,)))
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,tma_p,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, tma_p, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE):
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx,_,_ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
if const_expr(self.use_smem_p):
cpasync.prefetch_descriptor(tma_p)
@cute.struct
class SS:
@@ -235,12 +225,6 @@ class FmhaKernel:
cute.arch.fence_view_async_tmem_store()
sh.commit()
softmax_done_bar.arrive_and_wait()
# SMEM-P: TMA load gP → sP after softmax writes gP (reusing Q's GMEM)
if const_expr(self.use_smem_p):
tPgP, tPsP = cpasync.tma_partition(tma_p, 0, cute.nvgpu.OperandMajorMode.M, cute.group_modes(sP,0,3), cute.group_modes(mQ,0,3))
cute.copy(tma_p, tPsP[(None,0,None,0)], tPgP[(None,0,None,0)], tma_bar_ptr=st.s_bar.data_ptr())
cpasync.commit_group()
cpasync.wait_group(0)
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
if not self.use_smem_p:
# TMEM-P: PV reads P from TMEM
@@ -304,8 +288,6 @@ class FmhaKernel:
# sP_addr = 64*m + (k%16) + 16*((k//16)%4) + 8192*(k//64)
# This is representable as a CuTe layout: (128, (16, 4, 2)) -> (64, (1, 16, 8192))
_sP_nostage = sP[(None, None, None, 0)] # remove stage dim
print(f"SMEM-P DEBUG: sP_nostage shape={cute.shape(_sP_nostage)}, sP outer={cute.shape(sP)}")
print(f"SMEM-P DEBUG: pv_n_tile={self.pv_n_tile}, p_cols_fp32={p_cols_fp32}")
row_max = -Float32.inf
row_sum = Float32(0.0)
@@ -335,12 +317,6 @@ class FmhaKernel:
tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i)
n_corr_tiles = self.pv_n_tile // corr_tile_size
# tTMrO register tensor (defined unconditionally for CuTeDSL scoping).
# Used for O rescale (kt > 0) and O normalization (after loop).
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
for kt in range(self.n_kv_tiles):
si_handle = s_cons.wait_and_advance()
@@ -384,20 +360,25 @@ class FmhaKernel:
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: Write P to mQ's GMEM (reuse Q buffer after Q loaded to SMEM),
# then TMA loads gP→sP. rP_bf16 and Q's C-fragment partition both
# derived from QK MMA, so element-wise copy matches thread→value mapping.
gP_tile = cute.local_tile(mQ, (128, self.s_k), (0, 0))
tCgP = qk_thr.partition_C(gP_tile)
# Element-wise copy: rP_bf16 → tCgP (both 128 values per thread)
rP_flat = cute.make_tensor(rP_bf16.iterator, cute.coalesce(rP_bf16.layout))
gP_flat = cute.make_tensor(tCgP.iterator, cute.coalesce(tCgP.layout))
for idx in cutlass.range(cute.size(rP_flat), vectorize=True):
gP_flat[idx] = rP_flat[idx]
# Ensure gP writes are visible before TMA reads
cute.arch.fence_proxy("async", space="cta")
# SMEM-P: write P to sP using coordinate-indexed store.
# Uses tTMEM_LOADcS identity tensor to get (m, k) coordinates.
# DEBUG: Write a known pattern to sP to verify the coordinate mapping.
# Pattern: sP[m, k] = (m + k) % 256 as BF16 (unique per position)
for j0 in range(32):
for j1 in range(4):
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
m_coord = coord[0]
k_coord = coord[1]
k0 = k_coord % 16
k1 = (k_coord // 16) % 4
k2 = k_coord // 64
# Debug: write (m + k) mod 256 instead of actual P value
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
cute.arch.fence_proxy("async.shared", space="cta")
if kt > 0:
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO[None, i]
tTMrO_i_layout = cute.composition(
@@ -425,20 +406,39 @@ class FmhaKernel:
final_o_bar.arrive_and_wait()
# ============================================================
# EPILOGUE: TMA store O to GMEM + compute LSE
# EPILOGUE: Normalize O + TMA store to GMEM
# ============================================================
# The raw un-normalized O in TMEM is perfect (cos 0.999998).
# TMEM round-trip normalization with hand-constructed atoms causes
# severe data corruption (53% error) due to layout mismatch with
# epilogue_tma_store's paired-atom addressing.
# Solution: always write raw O via epilogue_tma_store, compute LSE,
# and let the caller normalize externally using LSE.
# This is the D5a path — production-quality with zero precision loss.
# The TMEM round-trip normalization (normalize=True) is tracked as D1.5.
# Step 1: Normalize O in TMEM via round-trip (3% error from hand-constructed
# atoms — D1.5 tracks the paired-atom fix).
# Step 2: Use CUTLASS epilogue_tma_store for TMEM→SMEM→GMEM write.
# ============================================================
# TMA store via CUTLASS epilogue_tma_store (reads raw O from TMEM)
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
# D5a: When normalize=False, skip 1/row_sum (emit un-normalized O + LSE).
if const_expr(self.normalize):
inv_row_sum = Float32(1.0) / row_sum
# Normalize O: TMEM round-trip O *= inv_row_sum
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO[None, i]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(
tTMEM_LOADtO.iterator + i * corr_tile_size,
tTMEM_LOADtO.layout,
)
tTMEM_STOREtO_i = cute.make_tensor(
tTMEM_STOREtO.iterator + i * corr_tile_size,
tTMEM_STOREtO.layout,
)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
for k in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[k] = tTMrO_i[k] * inv_row_sum
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
# TMA store via CUTLASS epilogue_tma_store
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tOtO.layout)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = pipeline.make_pipeline_state(
@@ -451,16 +451,17 @@ class FmhaKernel:
)
c_pipe.producer_tail()
# Compute LSE: lse = ln(row_sum) + row_max * ln(2)
# Always compute LSE (needed for external normalization).
# D5a: Write LSE (log-softmax) when normalize=False
# lse = ln(row_sum) + row_max * ln(2)
# row_max is in scale_log2 domain, multiply by ln(2) to convert.
_row_max_safe = row_max
if row_max == -cutlass.Float32.inf:
_row_max_safe = Float32(0.0)
if sfw_idx == 0:
_ln2 = Float32(0.6931471805599453) # ln(2)
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2
mLSE[0] = lse_val
if const_expr(not self.normalize):
_row_max_safe = row_max
if row_max == -cutlass.Float32.inf:
_row_max_safe = Float32(0.0)
if sfw_idx == 0:
_ln2 = Float32(0.6931471805599453) # ln(2)
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2
mLSE[0] = lse_val
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)