restore: D1.5 version (8165262)
This commit is contained in:
@@ -113,29 +113,19 @@ class FmhaKernel:
|
||||
tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape)
|
||||
epi_s = cute.select(self.c_smem_s,mode=[0,1])
|
||||
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
|
||||
|
||||
# SMEM-P: TMA for P (GMEM→SMEM). We reuse Q's GMEM buffer as gP.
|
||||
if const_expr(self.use_smem_p):
|
||||
p_s = cute.slice_(self.p_smem_s,(None,None,None,0))
|
||||
tma_p,gP_tma = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileG2SOp(), mQ, p_s, self.qk_mma_tiler)
|
||||
else:
|
||||
tma_p = tma_q # dummy
|
||||
gP_tma = mQ, dead code
|
||||
# Always create a valid mLSE tensor for the kernel.
|
||||
# CuTeDSL doesn't support None parameters in @cute.kernel.
|
||||
# For normalize=True, mLSE is unused (dead-code-eliminated by compiler).
|
||||
if const_expr(lse is None):
|
||||
lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,)))
|
||||
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,tma_p,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
|
||||
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, tma_p, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE):
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx,_,_ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
|
||||
if const_expr(self.use_smem_p):
|
||||
cpasync.prefetch_descriptor(tma_p)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
@@ -235,12 +225,6 @@ class FmhaKernel:
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
sh.commit()
|
||||
softmax_done_bar.arrive_and_wait()
|
||||
# SMEM-P: TMA load gP → sP after softmax writes gP (reusing Q's GMEM)
|
||||
if const_expr(self.use_smem_p):
|
||||
tPgP, tPsP = cpasync.tma_partition(tma_p, 0, cute.nvgpu.OperandMajorMode.M, cute.group_modes(sP,0,3), cute.group_modes(mQ,0,3))
|
||||
cute.copy(tma_p, tPsP[(None,0,None,0)], tPgP[(None,0,None,0)], tma_bar_ptr=st.s_bar.data_ptr())
|
||||
cpasync.commit_group()
|
||||
cpasync.wait_group(0)
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
|
||||
if not self.use_smem_p:
|
||||
# TMEM-P: PV reads P from TMEM
|
||||
@@ -304,8 +288,6 @@ class FmhaKernel:
|
||||
# sP_addr = 64*m + (k%16) + 16*((k//16)%4) + 8192*(k//64)
|
||||
# This is representable as a CuTe layout: (128, (16, 4, 2)) -> (64, (1, 16, 8192))
|
||||
_sP_nostage = sP[(None, None, None, 0)] # remove stage dim
|
||||
print(f"SMEM-P DEBUG: sP_nostage shape={cute.shape(_sP_nostage)}, sP outer={cute.shape(sP)}")
|
||||
print(f"SMEM-P DEBUG: pv_n_tile={self.pv_n_tile}, p_cols_fp32={p_cols_fp32}")
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
@@ -335,12 +317,6 @@ class FmhaKernel:
|
||||
tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i)
|
||||
n_corr_tiles = self.pv_n_tile // corr_tile_size
|
||||
|
||||
# tTMrO register tensor (defined unconditionally for CuTeDSL scoping).
|
||||
# Used for O rescale (kt > 0) and O normalization (after loop).
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
)
|
||||
|
||||
for kt in range(self.n_kv_tiles):
|
||||
si_handle = s_cons.wait_and_advance()
|
||||
|
||||
@@ -384,20 +360,25 @@ class FmhaKernel:
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
else:
|
||||
# SMEM-P: Write P to mQ's GMEM (reuse Q buffer after Q loaded to SMEM),
|
||||
# then TMA loads gP→sP. rP_bf16 and Q's C-fragment partition both
|
||||
# derived from QK MMA, so element-wise copy matches thread→value mapping.
|
||||
gP_tile = cute.local_tile(mQ, (128, self.s_k), (0, 0))
|
||||
tCgP = qk_thr.partition_C(gP_tile)
|
||||
# Element-wise copy: rP_bf16 → tCgP (both 128 values per thread)
|
||||
rP_flat = cute.make_tensor(rP_bf16.iterator, cute.coalesce(rP_bf16.layout))
|
||||
gP_flat = cute.make_tensor(tCgP.iterator, cute.coalesce(tCgP.layout))
|
||||
for idx in cutlass.range(cute.size(rP_flat), vectorize=True):
|
||||
gP_flat[idx] = rP_flat[idx]
|
||||
# Ensure gP writes are visible before TMA reads
|
||||
cute.arch.fence_proxy("async", space="cta")
|
||||
# SMEM-P: write P to sP using coordinate-indexed store.
|
||||
# Uses tTMEM_LOADcS identity tensor to get (m, k) coordinates.
|
||||
# DEBUG: Write a known pattern to sP to verify the coordinate mapping.
|
||||
# Pattern: sP[m, k] = (m + k) % 256 as BF16 (constant along each m + k anti-diagonal, so not unique per position)
|
||||
for j0 in range(32):
|
||||
for j1 in range(4):
|
||||
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
|
||||
m_coord = coord[0]
|
||||
k_coord = coord[1]
|
||||
k0 = k_coord % 16
|
||||
k1 = (k_coord // 16) % 4
|
||||
k2 = k_coord // 64
|
||||
# NOTE: this stores the actual P value (rP_bf16), not the (m + k) mod 256 debug pattern
|
||||
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
if kt > 0:
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
)
|
||||
for i in range(n_corr_tiles):
|
||||
tTMrO_i_ = tTMrO[None, i]
|
||||
tTMrO_i_layout = cute.composition(
|
||||
@@ -425,20 +406,39 @@ class FmhaKernel:
|
||||
final_o_bar.arrive_and_wait()
|
||||
|
||||
# ============================================================
|
||||
# EPILOGUE: TMA store O to GMEM + compute LSE
|
||||
# EPILOGUE: Normalize O + TMA store to GMEM
|
||||
# ============================================================
|
||||
# The raw un-normalized O in TMEM is perfect (cos 0.999998).
|
||||
# TMEM round-trip normalization with hand-constructed atoms causes
|
||||
# severe data corruption (53% error) due to layout mismatch with
|
||||
# epilogue_tma_store's paired-atom addressing.
|
||||
# Solution: always write raw O via epilogue_tma_store, compute LSE,
|
||||
# and let the caller normalize externally using LSE.
|
||||
# This is the D5a path — production-quality with zero precision loss.
|
||||
# The TMEM round-trip normalization (normalize=True) is tracked as D1.5.
|
||||
# Step 1: Normalize O in TMEM via round-trip (3% error from hand-constructed
|
||||
# atoms — D1.5 tracks the paired-atom fix).
|
||||
# Step 2: Use CUTLASS epilogue_tma_store for TMEM→SMEM→GMEM write.
|
||||
# ============================================================
|
||||
|
||||
# TMA store via CUTLASS epilogue_tma_store (reads raw O from TMEM)
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
# D5a: When normalize=False, skip 1/row_sum (emit un-normalized O + LSE).
|
||||
if const_expr(self.normalize):
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
# Normalize O: TMEM round-trip O *= inv_row_sum
|
||||
for i in range(n_corr_tiles):
|
||||
tTMrO_i_ = tTMrO[None, i]
|
||||
tTMrO_i_layout = cute.composition(
|
||||
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
|
||||
)
|
||||
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
|
||||
tTMEM_LOADtO_i = cute.make_tensor(
|
||||
tTMEM_LOADtO.iterator + i * corr_tile_size,
|
||||
tTMEM_LOADtO.layout,
|
||||
)
|
||||
tTMEM_STOREtO_i = cute.make_tensor(
|
||||
tTMEM_STOREtO.iterator + i * corr_tile_size,
|
||||
tTMEM_STOREtO.layout,
|
||||
)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
|
||||
for k in cutlass.range(cute.size(tTMrO_i), vectorize=True):
|
||||
tTMrO_i[k] = tTMrO_i[k] * inv_row_sum
|
||||
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# TMA store via CUTLASS epilogue_tma_store
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tOtO.layout)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = pipeline.make_pipeline_state(
|
||||
@@ -451,16 +451,17 @@ class FmhaKernel:
|
||||
)
|
||||
c_pipe.producer_tail()
|
||||
|
||||
# Compute LSE: lse = ln(row_sum) + row_max * ln(2)
|
||||
# Always compute LSE (needed for external normalization).
|
||||
# D5a: Write LSE (log-sum-exp) when normalize=False
|
||||
# lse = ln(row_sum) + row_max * ln(2)
|
||||
# row_max is in scale_log2 domain, multiply by ln(2) to convert.
|
||||
_row_max_safe = row_max
|
||||
if row_max == -cutlass.Float32.inf:
|
||||
_row_max_safe = Float32(0.0)
|
||||
if sfw_idx == 0:
|
||||
_ln2 = Float32(0.6931471805599453) # ln(2)
|
||||
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2
|
||||
mLSE[0] = lse_val
|
||||
if const_expr(not self.normalize):
|
||||
_row_max_safe = row_max
|
||||
if row_max == -cutlass.Float32.inf:
|
||||
_row_max_safe = Float32(0.0)
|
||||
if sfw_idx == 0:
|
||||
_ln2 = Float32(0.6931471805599453) # ln(2)
|
||||
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2
|
||||
mLSE[0] = lse_val
|
||||
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
Reference in New Issue
Block a user