From d511ebe387adf85f3e760f801cd58feb6ce4d197 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 00:34:38 +0000 Subject: [PATCH] Debug: add row_sum/inv_row_sum printf at final normalize --- tests/unit/test_fmha_v3_stage_c.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 95d9a4fc..be944adb 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -332,6 +332,11 @@ class FmhaV3StageCMulti: cute.arch.fence_view_async_tmem_load() # Pass 1: update row_max (in log2-domain, fused with scale). + # Compute O rescale factor and update row_sum. + # At kt=0, old_row_max is -inf, so acc_scale = 0 — but + # row_sum starts at 0 too, so row_sum *= 0 is fine (0*0=0). + # The O rescale (O *= acc_scale) must be skipped at kt=0 + # because it would zero out the first tile's contribution. old_row_max = row_max frg_cnt = 4 frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt @@ -398,6 +403,7 @@ class FmhaV3StageCMulti: # === Final O normalization: O *= 1/row_sum === inv_row_sum = Float32(1.0) / row_sum + cute.printf("FINAL row_sum=%f inv_row_sum=%f\n", row_sum, inv_row_sum) tTMrO = cute.make_rmem_tensor( (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype