debug: add wide-search diagnostics for n=256 O rescale
This commit is contained in:
@@ -356,6 +356,11 @@ class FmhaV3StageCMulti:
|
||||
acc_scale = Float32(0.0)
|
||||
row_sum *= acc_scale
|
||||
|
||||
# DEBUG: print acc_scale at kt=1 (only from thread 0)
|
||||
if kt == 1 and sfw_idx == 0:
|
||||
cute.printf("O_RESCALE kt=1: old_max=%f new_max=%f acc_scale=%f row_sum=%f\n",
|
||||
old_row_max, row_max_safe, acc_scale, row_sum)
|
||||
|
||||
# Pass 2: P = exp2((S - new_max) * log2), accumulate row_sum,
|
||||
# store BF16 P through the FP32-backed register bridge.
|
||||
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
|
||||
@@ -374,9 +379,23 @@ class FmhaV3StageCMulti:
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# DEBUG: print row_max and row_sum after each tile
|
||||
if sfw_idx == 0:
|
||||
cute.printf("SOFTMAX kt=%d: row_max=%f row_sum=%f\n", kt, row_max_safe, row_sum)
|
||||
|
||||
# === Per-tile O rescale: O *= acc_scale for kt > 0 ===
|
||||
# Uses the SAME 2D register tensor pattern as final normalize
|
||||
# to ensure paired atoms work correctly.
|
||||
if kt > 0:
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
)
|
||||
for i in range(n_corr_tiles):
|
||||
tTMrO_i_ = tTMrO[None, i]
|
||||
tTMrO_i_layout = cute.composition(
|
||||
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
|
||||
)
|
||||
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
|
||||
tTMEM_LOADtO_i = cute.make_tensor(
|
||||
tTMEM_LOADtO.iterator + i * corr_tile_size,
|
||||
tTMEM_LOADtO.layout,
|
||||
@@ -385,12 +404,15 @@ class FmhaV3StageCMulti:
|
||||
tTMEM_STOREtO.iterator + i * corr_tile_size,
|
||||
tTMEM_STOREtO.layout,
|
||||
)
|
||||
- tTMrO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.acc_dtype)
|
||||
- cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO)
|
||||
- cute.arch.fence_view_async_tmem_load()
|
||||
- for k in cutlass.range(cute.size(tTMrO), vectorize=True):
|
||||
- tTMrO[k] = tTMrO[k] * acc_scale
|
||||
- cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STOREtO_i)
|
||||
+ cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
|
||||
+ # DEBUG: print first element of O before and after rescale
|
||||
+ if i == 0 and sfw_idx == 0:
|
||||
+ cute.printf("O_RESCALE before: O[0]=%f acc_scale=%f\n", tTMrO_i[0], acc_scale)
|
||||
+ for k in cutlass.range(cute.size(tTMrO_i), vectorize=True):
|
||||
+ tTMrO_i[k] = tTMrO_i[k] * acc_scale
|
||||
+ if i == 0 and sfw_idx == 0:
|
||||
+ cute.printf("O_RESCALE after: O[0]=%f\n", tTMrO_i[0])
|
||||
+ cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
si_handle.release()
|
||||
@@ -447,7 +469,7 @@ class FmhaV3StageCMulti:
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
- for n in [128, 256, 512, 1024]:
|
||||
+ for n in [128, 256]:
|
||||
torch.manual_seed(42)
|
||||
m, hd = 128, HEAD_DIM
|
||||
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
@@ -460,9 +482,45 @@ def test():
|
||||
kf = k[:, :, 0].float()
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
attn = qf @ kf.T * scale
|
||||
|
||||
# Per-tile reference breakdown
|
||||
n_tiles = n // 128
|
||||
for t in range(n_tiles):
|
||||
tile_attn = attn[:, t*128:(t+1)*128]
|
||||
tile_max = tile_attn.max(dim=-1, keepdim=True).values
|
||||
print(f' Ref tile {t}: max_score={tile_max.max().item():.4f} '
|
||||
f'mean_score={tile_attn.mean().item():.4f}', flush=True)
|
||||
|
||||
# Full softmax reference
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
ref = attn @ v.float()
|
||||
|
||||
# Also compute reference with online softmax to verify acc_scale
|
||||
if n_tiles > 1:
|
||||
# Simulate online softmax: process tiles sequentially
|
||||
online_o = torch.zeros(m, hd, dtype=torch.float32)
|
||||
online_row_max = torch.full((m, 1), float('-inf'))
|
||||
online_row_sum = torch.zeros(m, 1)
|
||||
for t in range(n_tiles):
|
||||
tile_scores = qf @ kf[t*128:(t+1)*128].T * scale # (128, 128)
|
||||
old_max = online_row_max.clone()
|
||||
new_max = torch.max(old_max, tile_scores.max(dim=-1, keepdim=True).values)
|
||||
acc_scale_ref = torch.exp(old_max - new_max)
|
||||
acc_scale_ref[old_max == float('-inf')] = 0.0
|
||||
online_o = online_o * acc_scale_ref
|
||||
tile_softmax = torch.exp(tile_scores - new_max)
|
||||
online_row_sum = online_row_sum * acc_scale_ref + tile_softmax.sum(dim=-1, keepdim=True)
|
||||
online_o = online_o + tile_softmax @ v[t*128:(t+1)*128].float()
|
||||
online_row_max = new_max
|
||||
print(f' Online ref tile {t}: acc_scale_mean={acc_scale_ref[old_max != float("-inf")].mean().item():.6f} '
|
||||
f'row_max_mean={online_row_max.mean().item():.4f} '
|
||||
f'row_sum_mean={online_row_sum.mean().item():.4f}', flush=True)
|
||||
online_o = online_o / online_row_sum
|
||||
online_cos = torch.nn.functional.cosine_similarity(
|
||||
online_o.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
|
||||
).item()
|
||||
print(f' Online softmax ref vs full softmax ref: cos {online_cos:.6f}', flush=True)
|
||||
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
|
||||
@@ -484,13 +542,20 @@ def test():
|
||||
out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
|
||||
).item()
|
||||
max_abs = (out - ref).abs().max().item()
|
||||
n_tiles = n // 128
|
||||
print(f'FMHA Stage-C Multi n={n} ({n_tiles} kv tiles): '
|
||||
f'cos {cos:.6f} max_abs {max_abs:.4f} '
|
||||
f'{"PASS" if cos >= 0.99 else "FAIL"}')
|
||||
if cos < 0.99:
|
||||
- print(f' out[0,:4]={out[0,:4].tolist()}')
|
||||
- print(f' ref[0,:4]={ref[0,:4].tolist()}')
|
||||
+ # Detailed per-row error analysis
|
||||
+ row_cos = torch.nn.functional.cosine_similarity(
|
||||
+ out.unsqueeze(1), ref.unsqueeze(1), dim=-1)
|
||||
+ worst_rows = row_cos.topk(5, largest=False)
|
||||
+ print(f' Worst 5 rows: {[(i.item(), c.item()) for i, c in zip(worst_rows.indices, worst_rows.values)]}')
|
||||
+ print(f' out[0,:8]={out[0,:8].tolist()}')
|
||||
+ print(f' ref[0,:8]={ref[0,:8].tolist()}')
|
||||
+ print(f' out range: [{out.min().item():.4f}, {out.max().item():.4f}]')
|
||||
+ print(f' ref range: [{ref.min().item():.4f}, {ref.max().item():.4f}]')
|
||||
+ print(f' mean abs error: {(out - ref).abs().mean().item():.6f}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user