DEBUG: O load+store NO-OP to verify TMEM copy correctness

This commit is contained in:
2026-05-22 19:12:43 +00:00
parent c2fe1bffb5
commit 5a0b575bfe

View File

@@ -310,22 +310,20 @@ class FmhaV3RealSoftmax:
si_handle.release()
softmax_done_bar.arrive()
# Final O normalization: O = O / row_sum
if row_sum != Float32(0.0):
inv_row_sum = Float32(1.0) / row_sum
n_corr = HEAD_DIM // corr_tile_size
for ci in range(n_corr):
tTMrO_fn = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype)
tTMEM_LOAD_OtO_ci = cute.make_tensor(
tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout
)
tTMEM_STORE_OtO_ci = cute.make_tensor(
tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout
)
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_fn)
for j in cutlass.range(cute.size(tTMrO_fn), vectorize=True):
tTMrO_fn[j] = tTMrO_fn[j] * inv_row_sum
cute.copy(tiled_tmem_store_o, tTMrO_fn, tTMEM_STORE_OtO_ci)
# Final O normalization: load, NO-OP write (debug copy)
# If this matches unnormalized output, the TMEM copy is correct.
n_corr = HEAD_DIM // corr_tile_size
for ci in range(n_corr):
tTMrO_fn = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype)
tTMEM_LOAD_OtO_ci = cute.make_tensor(
tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout
)
tTMEM_STORE_OtO_ci = cute.make_tensor(
tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout
)
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_fn)
# NO-OP: write back without modification
cute.copy(tiled_tmem_store_o, tTMrO_fn, tTMEM_STORE_OtO_ci)
# Epilogue: TMEM -> SMEM -> GMEM via TMA store
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
@@ -348,13 +346,14 @@ def test():
v_kernel = v.unsqueeze(-1)
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
# Reference: proper softmax
# Reference: unnormalized (just softmax numerators @ V, no 1/row_sum)
qf = q[:, :, 0].float()
kf = k[:, :, 0].float()
scale = 1.0 / math.sqrt(hd)
attn = qf @ kf.T * scale
attn = torch.softmax(attn, dim=-1)
ref = attn @ v.float()
attn_max = attn.max(dim=-1, keepdim=True)[0]
attn_unnorm = torch.exp(attn - attn_max)
ref = attn_unnorm @ v.float()
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))