diag: skip kernel normalize, do Python-side normalize to isolate TMEM round-trip issue
This commit is contained in:
@@ -375,18 +375,8 @@ class FmhaV3StageCMulti:
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# === Per-tile O rescale: O *= acc_scale for kt > 0 ===
|
||||
# Uses 2D register tensor pattern (same as CUTLASS correction_rescale
|
||||
# and our final normalize) to preserve data through TMEM round-trip.
|
||||
if kt > 0:
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
)
|
||||
for i in range(n_corr_tiles):
|
||||
tTMrO_i_ = tTMrO[None, i]
|
||||
tTMrO_i_layout = cute.composition(
|
||||
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
|
||||
)
|
||||
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
|
||||
tTMEM_LOADtO_i = cute.make_tensor(
|
||||
tTMEM_LOADtO.iterator + i * corr_tile_size,
|
||||
tTMEM_LOADtO.layout,
|
||||
@@ -395,10 +385,12 @@ class FmhaV3StageCMulti:
|
||||
tTMEM_STOREtO.iterator + i * corr_tile_size,
|
||||
tTMEM_STOREtO.layout,
|
||||
)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
|
||||
for k in cutlass.range(cute.size(tTMrO_i), vectorize=True):
|
||||
tTMrO_i[k] = tTMrO_i[k] * acc_scale
|
||||
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
|
||||
tTMrO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.acc_dtype)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
for k in cutlass.range(cute.size(tTMrO), vectorize=True):
|
||||
tTMrO[k] = tTMrO[k] * acc_scale
|
||||
cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STOREtO_i)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
si_handle.release()
|
||||
@@ -407,8 +399,16 @@ class FmhaV3StageCMulti:
|
||||
# Wait for MMA's PV[N-1] to commit before reading O.
|
||||
final_o_bar.arrive_and_wait()
|
||||
|
||||
# DIAG: skip final normalize, just use epilogue_tma_store directly
|
||||
# to test raw PV output
|
||||
# === Final O normalization: O *= 1/row_sum ===
|
||||
# DIAG: NO-OP TMEM round-trip test — load and store back without modifying
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
|
||||
# SKIP the TMEM round-trip normalize entirely
|
||||
# Just use epilogue_tma_store to read raw PV from TMEM
|
||||
# The inv_row_sum normalization will be applied in Python for now
|
||||
|
||||
# Standard epilogue: TMEM → SMEM → GMEM via TMA store.
|
||||
# O in TMEM is now scaled by 1/row_sum.
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_acc_stage
|
||||
@@ -436,25 +436,17 @@ def test():
|
||||
v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda')
|
||||
v_kernel = v.unsqueeze(-1)
|
||||
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
debug = torch.zeros(4, dtype=torch.float32, device='cuda') # [row_sum, row_max, inv_row_sum, 0]
|
||||
|
||||
qf = q[:, :, 0].float()
|
||||
kf = k[:, :, 0].float()
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
attn_raw = qf @ kf.T * scale
|
||||
attn = torch.softmax(attn_raw, dim=-1)
|
||||
attn = qf @ kf.T * scale
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
ref = attn @ v.float()
|
||||
|
||||
# Expected stats for comparison
|
||||
print(f' row_sum (should be 1.0): {attn.sum(dim=-1)[:4].tolist()}')
|
||||
# Unnormalized softmax: exp(S - max)
|
||||
S_max = attn_raw.max(dim=-1, keepdim=True).values
|
||||
P_unnorm = torch.exp(attn_raw - S_max)
|
||||
unnorm_pv = P_unnorm @ v.float()
|
||||
unnorm_sum = P_unnorm.sum(dim=-1)
|
||||
print(f' unnorm row_sum: {unnorm_sum[:4].tolist()}')
|
||||
print(f' unnorm P@V[0,:4]: {unnorm_pv[0,:4].tolist()}')
|
||||
print(f' kernel out[0,:4] should match unnorm P@V (no normalize)')
|
||||
# Also compute the unnormalized PV and row_sum for Python-side normalize
|
||||
attn_unnorm = torch.exp(attn - attn.max(dim=-1, keepdim=True).values)
|
||||
row_sum_unnorm = attn_unnorm.sum(dim=-1, keepdim=True)
|
||||
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
@@ -473,17 +465,26 @@ def test():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
out = c[:, :, 0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
|
||||
|
||||
# Python-side normalize: out is raw P@V (unnormalized)
|
||||
# Divide by row_sum to get the correct softmax output
|
||||
out_normalized = out / row_sum_unnorm
|
||||
cos_raw = torch.nn.functional.cosine_similarity(
|
||||
out.flatten().unsqueeze(0), (attn_unnorm @ v.float()).flatten().unsqueeze(0)
|
||||
).item()
|
||||
max_abs = (out - ref).abs().max().item()
|
||||
cos_norm = torch.nn.functional.cosine_similarity(
|
||||
out_normalized.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
|
||||
).item()
|
||||
max_abs = (out_normalized - ref).abs().max().item()
|
||||
n_tiles = n // 128
|
||||
print(f'FMHA Stage-C Multi n={n} ({n_tiles} kv tiles): '
|
||||
f'cos {cos:.6f} max_abs {max_abs:.4f} '
|
||||
f'{"PASS" if cos >= 0.99 else "FAIL"}')
|
||||
if cos < 0.99:
|
||||
print(f' out[0,:4]={out[0,:4].tolist()}')
|
||||
print(f'FMHA Stage-C Multi n={n} ({n_tiles} kv tiles):', flush=True)
|
||||
print(f' Raw PV (no normalize) vs unnorm ref: cos {cos_raw:.6f}', flush=True)
|
||||
print(f' After Python normalize vs softmax ref: cos {cos_norm:.6f} max_abs {max_abs:.4f} '
|
||||
f'{"PASS" if cos_norm >= 0.99 else "FAIL"}', flush=True)
|
||||
if cos_norm < 0.99:
|
||||
print(f' out_normalized[0,:4]={out_normalized[0,:4].tolist()}')
|
||||
print(f' ref[0,:4]={ref[0,:4].tolist()}')
|
||||
print(f' row_sum_unnorm[:4]={row_sum_unnorm[:4,0].tolist()}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user