debug: add wide-search diagnostics for n=256 O rescale

This commit is contained in:
2026-05-23 01:02:33 +00:00
parent f026c1824c
commit 08d4af90ca

View File

@@ -356,6 +356,11 @@ class FmhaV3StageCMulti:
acc_scale = Float32(0.0)
row_sum *= acc_scale
# DEBUG: print acc_scale at kt=1 (only from thread 0)
if kt == 1 and sfw_idx == 0:
cute.printf("O_RESCALE kt=1: old_max=%f new_max=%f acc_scale=%f row_sum=%f\n",
old_row_max, row_max_safe, acc_scale, row_sum)
# Pass 2: P = exp2((S - new_max) * log2), accumulate row_sum,
# store BF16 P through the FP32-backed register bridge.
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
@@ -374,9 +379,23 @@ class FmhaV3StageCMulti:
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
# DEBUG: print row_max and row_sum after each tile
if sfw_idx == 0:
cute.printf("SOFTMAX kt=%d: row_max=%f row_sum=%f\n", kt, row_max_safe, row_sum)
# === Per-tile O rescale: O *= acc_scale for kt > 0 ===
# Uses the SAME 2D register tensor pattern as final normalize
# to ensure paired atoms work correctly.
if kt > 0:
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO[None, i]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(
tTMEM_LOADtO.iterator + i * corr_tile_size,
tTMEM_LOADtO.layout,
@@ -385,12 +404,15 @@ class FmhaV3StageCMulti:
tTMEM_STOREtO.iterator + i * corr_tile_size,
tTMEM_STOREtO.layout,
)
tTMrO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.acc_dtype)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO)
cute.arch.fence_view_async_tmem_load()
for k in cutlass.range(cute.size(tTMrO), vectorize=True):
tTMrO[k] = tTMrO[k] * acc_scale
cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STOREtO_i)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
# DEBUG: print first element of O before and after rescale
if i == 0 and sfw_idx == 0:
cute.printf("O_RESCALE before: O[0]=%f acc_scale=%f\n", tTMrO_i[0], acc_scale)
for k in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[k] = tTMrO_i[k] * acc_scale
if i == 0 and sfw_idx == 0:
cute.printf("O_RESCALE after: O[0]=%f\n", tTMrO_i[0])
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
@@ -447,7 +469,7 @@ class FmhaV3StageCMulti:
def test():
torch.manual_seed(42)
for n in [128, 256, 512, 1024]:
for n in [128, 256]:
torch.manual_seed(42)
m, hd = 128, HEAD_DIM
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
@@ -460,9 +482,45 @@ def test():
kf = k[:, :, 0].float()
scale = 1.0 / math.sqrt(hd)
attn = qf @ kf.T * scale
# Per-tile reference breakdown
n_tiles = n // 128
for t in range(n_tiles):
tile_attn = attn[:, t*128:(t+1)*128]
tile_max = tile_attn.max(dim=-1, keepdim=True).values
print(f' Ref tile {t}: max_score={tile_max.max().item():.4f} '
f'mean_score={tile_attn.mean().item():.4f}', flush=True)
# Full softmax reference
attn = torch.softmax(attn, dim=-1)
ref = attn @ v.float()
# Also compute the reference with an online (tile-by-tile) softmax so the
# per-tile acc_scale / row_max / row_sum can be compared against the
# kernel's debug prints.
if n_tiles > 1:
    # Keep the running state on the same device as the scores. Tensors
    # created without device= land on the CPU and cannot be combined with
    # CUDA tensors in torch.max or in arithmetic.
    ref_device = qf.device
    online_o = torch.zeros(m, hd, dtype=torch.float32, device=ref_device)
    online_row_max = torch.full((m, 1), float('-inf'), device=ref_device)
    online_row_sum = torch.zeros(m, 1, device=ref_device)
    for t in range(n_tiles):
        kv_rows = slice(t * 128, (t + 1) * 128)
        tile_scores = qf @ kf[kv_rows].T * scale  # (128, 128)
        old_max = online_row_max.clone()
        new_max = torch.max(old_max, tile_scores.max(dim=-1, keepdim=True).values)
        # Factor that re-bases the previously accumulated O / row_sum onto
        # the new running max; rows with no history contribute nothing.
        seen_before = old_max != float('-inf')
        acc_scale_ref = torch.exp(old_max - new_max)
        acc_scale_ref[~seen_before] = 0.0
        online_o = online_o * acc_scale_ref
        tile_softmax = torch.exp(tile_scores - new_max)
        online_row_sum = online_row_sum * acc_scale_ref + tile_softmax.sum(dim=-1, keepdim=True)
        online_o = online_o + tile_softmax @ v[kv_rows].float()
        online_row_max = new_max
        # On the first tile no row has a previous max, so the selection is
        # empty and .mean() would print nan; report the 0.0 that was
        # actually applied instead.
        if seen_before.any():
            acc_scale_mean = acc_scale_ref[seen_before].mean().item()
        else:
            acc_scale_mean = 0.0
        print(f' Online ref tile {t}: acc_scale_mean={acc_scale_mean:.6f} '
              f'row_max_mean={online_row_max.mean().item():.4f} '
              f'row_sum_mean={online_row_sum.mean().item():.4f}', flush=True)
    # Final normalisation by the accumulated softmax denominator.
    online_o = online_o / online_row_sum
    online_cos = torch.nn.functional.cosine_similarity(
        online_o.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
    ).item()
    print(f' Online softmax ref vs full softmax ref: cos {online_cos:.6f}', flush=True)
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
@@ -484,13 +542,20 @@ def test():
out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
).item()
max_abs = (out - ref).abs().max().item()
n_tiles = n // 128
print(f'FMHA Stage-C Multi n={n} ({n_tiles} kv tiles): '
f'cos {cos:.6f} max_abs {max_abs:.4f} '
f'{"PASS" if cos >= 0.99 else "FAIL"}')
if cos < 0.99:
print(f' out[0,:4]={out[0,:4].tolist()}')
print(f' ref[0,:4]={ref[0,:4].tolist()}')
# Detailed per-row error analysis: one cosine similarity per output row,
# then list the rows that agree least with the reference.
# NOTE(review): assumes out and ref are 2-D (rows, head_dim), as the
# out[0,:8] indexing suggests - confirm against where out is built.
# Reducing over dim=-1 directly yields shape (rows,). Inserting a singleton
# axis first would give (rows, 1), and topk along that size-1 axis raises.
row_cos = torch.nn.functional.cosine_similarity(out, ref, dim=-1)
# Clamp k so the report also works when there are fewer than 5 rows.
worst_rows = row_cos.topk(min(5, row_cos.numel()), largest=False)
print(f' Worst 5 rows: {list(zip(worst_rows.indices.tolist(), worst_rows.values.tolist()))}')
print(f' out[0,:8]={out[0,:8].tolist()}')
print(f' ref[0,:8]={ref[0,:8].tolist()}')
print(f' out range: [{out.min().item():.4f}, {out.max().item():.4f}]')
print(f' ref range: [{ref.min().item():.4f}, {ref.max().item():.4f}]')
print(f' mean abs error: {(out - ref).abs().mean().item():.6f}')
if __name__ == '__main__':