diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 5f01bb75..4a6ec02e 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -356,6 +356,11 @@ class FmhaV3StageCMulti: acc_scale = Float32(0.0) row_sum *= acc_scale + # DEBUG: print acc_scale at kt=1 (only from thread 0) + if kt == 1 and sfw_idx == 0: + cute.printf("O_RESCALE kt=1: old_max=%f new_max=%f acc_scale=%f row_sum=%f\n", + old_row_max, row_max_safe, acc_scale, row_sum) + # Pass 2: P = exp2((S - new_max) * log2), accumulate row_sum, # store BF16 P through the FP32-backed register bridge. rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype) @@ -374,9 +379,23 @@ class FmhaV3StageCMulti: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() + # DEBUG: print row_max and row_sum after each tile + if sfw_idx == 0: + cute.printf("SOFTMAX kt=%d: row_max=%f row_sum=%f\n", kt, row_max_safe, row_sum) + # === Per-tile O rescale: O *= acc_scale for kt > 0 === + # Uses the SAME 2D register tensor pattern as final normalize + # to ensure paired atoms work correctly. if kt > 0: + tTMrO = cute.make_rmem_tensor( + (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype + ) for i in range(n_corr_tiles): + tTMrO_i_ = tTMrO[None, i] + tTMrO_i_layout = cute.composition( + tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) + ) + tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) tTMEM_LOADtO_i = cute.make_tensor( tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout, @@ -385,12 +404,15 @@ class FmhaV3StageCMulti: tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout, ) - tTMrO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.acc_dtype) - cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO) - cute.arch.fence_view_async_tmem_load() - for k in cutlass.range(cute.size(tTMrO), vectorize=True): - tTMrO[k] = tTMrO[k] * acc_scale - cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STOREtO_i) + cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) + # DEBUG: print first element of O before and after rescale + if i == 0 and sfw_idx == 0: + cute.printf("O_RESCALE before: O[0]=%f acc_scale=%f\n", tTMrO_i[0], acc_scale) + for k in cutlass.range(cute.size(tTMrO_i), vectorize=True): + tTMrO_i[k] = tTMrO_i[k] * acc_scale + if i == 0 and sfw_idx == 0: + cute.printf("O_RESCALE after: O[0]=%f\n", tTMrO_i[0]) + cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) cute.arch.fence_view_async_tmem_store() si_handle.release() @@ -447,7 +469,7 @@ class FmhaV3StageCMulti: def test(): torch.manual_seed(42) - for n in [128, 256, 512, 1024]: + for n in [128, 256]: torch.manual_seed(42) m, hd = 128, HEAD_DIM q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') @@ -460,9 +482,45 @@ def test(): kf = k[:, :, 0].float() scale = 1.0 / math.sqrt(hd) attn = qf @ kf.T * scale + + # Per-tile reference breakdown + n_tiles = n // 128 + for t in range(n_tiles): + tile_attn = attn[:, t*128:(t+1)*128] + tile_max = tile_attn.max(dim=-1, keepdim=True).values + print(f' Ref tile {t}: max_score={tile_max.max().item():.4f} ' + f'mean_score={tile_attn.mean().item():.4f}', flush=True) + + # Full softmax reference attn = torch.softmax(attn, dim=-1) ref = attn @ v.float() + # Also compute reference with online softmax to verify acc_scale + if n_tiles > 1: + # Simulate online softmax: process tiles sequentially + online_o = torch.zeros(m, hd, dtype=torch.float32) + online_row_max = torch.full((m, 1), float('-inf')) + online_row_sum = torch.zeros(m, 1) + for t in range(n_tiles): + tile_scores = qf @ kf[t*128:(t+1)*128].T * scale # (128, 128) + old_max = online_row_max.clone() + new_max = torch.max(old_max, tile_scores.max(dim=-1, keepdim=True).values) + acc_scale_ref = torch.exp(old_max - new_max) + acc_scale_ref[old_max == float('-inf')] = 0.0 + online_o = online_o * acc_scale_ref + tile_softmax = torch.exp(tile_scores - new_max) + online_row_sum = online_row_sum * acc_scale_ref + tile_softmax.sum(dim=-1, keepdim=True) + online_o = online_o + tile_softmax @ v[t*128:(t+1)*128].float() + online_row_max = new_max + print(f' Online ref tile {t}: acc_scale_mean={acc_scale_ref[old_max != float("-inf")].mean().item():.6f} ' + f'row_max_mean={online_row_max.mean().item():.4f} ' + f'row_sum_mean={online_row_sum.mean().item():.4f}', flush=True) + online_o = online_o / online_row_sum + online_cos = torch.nn.functional.cosine_similarity( + online_o.flatten().unsqueeze(0), ref.flatten().unsqueeze(0) + ).item() + print(f' Online softmax ref vs full softmax ref: cos {online_cos:.6f}', flush=True) + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) @@ -484,13 +542,20 @@ def test(): out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0) ).item() max_abs = (out - ref).abs().max().item() - n_tiles = n // 128 print(f'FMHA Stage-C Multi n={n} ({n_tiles} kv tiles): ' f'cos {cos:.6f} max_abs {max_abs:.4f} ' f'{"PASS" if cos >= 0.99 else "FAIL"}') if cos < 0.99: - print(f' out[0,:4]={out[0,:4].tolist()}') - print(f' ref[0,:4]={ref[0,:4].tolist()}') + # Detailed per-row error analysis + row_cos = torch.nn.functional.cosine_similarity( + out.unsqueeze(1), ref.unsqueeze(1), dim=-1) + worst_rows = row_cos.topk(5, largest=False) + print(f' Worst 5 rows: {[(i.item(), c.item()) for i, c in zip(worst_rows.indices, worst_rows.values)]}') + print(f' out[0,:8]={out[0,:8].tolist()}') + print(f' ref[0,:8]={ref[0,:8].tolist()}') + print(f' out range: [{out.min().item():.4f}, {out.max().item():.4f}]') + print(f' ref range: [{ref.min().item():.4f}, {ref.max().item():.4f}]') + print(f' mean abs error: {(out - ref).abs().mean().item():.6f}') if __name__ == '__main__':