diag: skip kernel normalize, do Python-side normalize to isolate TMEM round-trip issue

This commit is contained in:
2026-05-23 01:35:18 +00:00
parent 1698f01308
commit 1d397c8b67

View File

@@ -375,18 +375,8 @@ class FmhaV3StageCMulti:
cute.arch.fence_view_async_tmem_store()
# === Per-tile O rescale: O *= acc_scale for kt > 0 ===
# Uses 2D register tensor pattern (same as CUTLASS correction_rescale
# and our final normalize) to preserve data through TMEM round-trip.
if kt > 0:
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO[None, i]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(
tTMEM_LOADtO.iterator + i * corr_tile_size,
tTMEM_LOADtO.layout,
@@ -395,10 +385,12 @@ class FmhaV3StageCMulti:
tTMEM_STOREtO.iterator + i * corr_tile_size,
tTMEM_STOREtO.layout,
)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
for k in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[k] = tTMrO_i[k] * acc_scale
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
tTMrO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.acc_dtype)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO)
cute.arch.fence_view_async_tmem_load()
for k in cutlass.range(cute.size(tTMrO), vectorize=True):
tTMrO[k] = tTMrO[k] * acc_scale
cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
@@ -407,8 +399,16 @@ class FmhaV3StageCMulti:
# Wait for MMA's PV[N-1] to commit before reading O.
final_o_bar.arrive_and_wait()
# DIAG: skip final normalize, just use epilogue_tma_store directly
# to test raw PV output
# === Final O normalization: O *= 1/row_sum ===
# DIAG: NO-OP TMEM round-trip test — load and store back without modifying
inv_row_sum = Float32(1.0) / row_sum
# SKIP the TMEM round-trip normalize entirely
# Just use epilogue_tma_store to read raw PV from TMEM
# The inv_row_sum normalization will be applied in Python for now
# Standard epilogue: TMEM → SMEM → GMEM via TMA store.
# O in TMEM is now scaled by 1/row_sum.
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage
@@ -436,25 +436,17 @@ def test():
v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda')
v_kernel = v.unsqueeze(-1)
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
debug = torch.zeros(4, dtype=torch.float32, device='cuda') # [row_sum, row_max, inv_row_sum, 0]
qf = q[:, :, 0].float()
kf = k[:, :, 0].float()
scale = 1.0 / math.sqrt(hd)
attn_raw = qf @ kf.T * scale
attn = torch.softmax(attn_raw, dim=-1)
attn = qf @ kf.T * scale
attn = torch.softmax(attn, dim=-1)
ref = attn @ v.float()
# Expected stats for comparison
print(f' row_sum (should be 1.0): {attn.sum(dim=-1)[:4].tolist()}')
# Unnormalized softmax: exp(S - max)
S_max = attn_raw.max(dim=-1, keepdim=True).values
P_unnorm = torch.exp(attn_raw - S_max)
unnorm_pv = P_unnorm @ v.float()
unnorm_sum = P_unnorm.sum(dim=-1)
print(f' unnorm row_sum: {unnorm_sum[:4].tolist()}')
print(f' unnorm P@V[0,:4]: {unnorm_pv[0,:4].tolist()}')
print(f' kernel out[0,:4] should match unnorm P@V (no normalize)')
# Also compute the unnormalized PV and row_sum for Python-side normalize
attn_unnorm = torch.exp(attn - attn.max(dim=-1, keepdim=True).values)
row_sum_unnorm = attn_unnorm.sum(dim=-1, keepdim=True)
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
@@ -473,17 +465,26 @@ def test():
torch.cuda.synchronize()
out = c[:, :, 0].float()
cos = torch.nn.functional.cosine_similarity(
out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
# Python-side normalize: out is raw P@V (unnormalized)
# Divide by row_sum to get the correct softmax output
out_normalized = out / row_sum_unnorm
cos_raw = torch.nn.functional.cosine_similarity(
out.flatten().unsqueeze(0), (attn_unnorm @ v.float()).flatten().unsqueeze(0)
).item()
max_abs = (out - ref).abs().max().item()
cos_norm = torch.nn.functional.cosine_similarity(
out_normalized.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
).item()
max_abs = (out_normalized - ref).abs().max().item()
n_tiles = n // 128
print(f'FMHA Stage-C Multi n={n} ({n_tiles} kv tiles): '
f'cos {cos:.6f} max_abs {max_abs:.4f} '
f'{"PASS" if cos >= 0.99 else "FAIL"}')
if cos < 0.99:
print(f' out[0,:4]={out[0,:4].tolist()}')
print(f'FMHA Stage-C Multi n={n} ({n_tiles} kv tiles):', flush=True)
print(f' Raw PV (no normalize) vs unnorm ref: cos {cos_raw:.6f}', flush=True)
print(f' After Python normalize vs softmax ref: cos {cos_norm:.6f} max_abs {max_abs:.4f} '
f'{"PASS" if cos_norm >= 0.99 else "FAIL"}', flush=True)
if cos_norm < 0.99:
print(f' out_normalized[0,:4]={out_normalized[0,:4].tolist()}')
print(f' ref[0,:4]={ref[0,:4].tolist()}')
print(f' row_sum_unnorm[:4]={row_sum_unnorm[:4,0].tolist()}')
if __name__ == '__main__':