diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index f456d8c0..8dde4630 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -400,7 +400,8 @@ class FmhaV3StageCMulti: final_o_bar.arrive_and_wait() # === Final O normalization: O *= 1/row_sum === - inv_row_sum = Float32(1.0) / row_sum + # DIAG: use 1.0 instead of 1/row_sum to test raw PV output + inv_row_sum = Float32(1.0) tTMrO = cute.make_rmem_tensor( (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype @@ -447,7 +448,7 @@ class FmhaV3StageCMulti: def test(): torch.manual_seed(42) - for n in [128, 256]: + for n in [128]: torch.manual_seed(42) m, hd = 128, HEAD_DIM q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')