From 690fd77e6c02ab93f814aafb6ce5391cdbd09f99 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 01:17:14 +0000 Subject: [PATCH] diag: inv_row_sum=1.0 to test raw PV, n=128 only --- tests/unit/test_fmha_v3_stage_c.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index f456d8c0..8dde4630 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -400,7 +400,8 @@ class FmhaV3StageCMulti: final_o_bar.arrive_and_wait() # === Final O normalization: O *= 1/row_sum === - inv_row_sum = Float32(1.0) / row_sum + # DIAG: use 1.0 instead of 1/row_sum to test raw PV output + inv_row_sum = Float32(1.0) tTMrO = cute.make_rmem_tensor( (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype @@ -447,7 +448,7 @@ class FmhaV3StageCMulti: def test(): torch.manual_seed(42) - for n in [128, 256]: + for n in [128]: torch.manual_seed(42) m, hd = 128, HEAD_DIM q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')