diff --git a/tests/unit/test_fmha_v3_stage_d5b.py b/tests/unit/test_fmha_v3_stage_d5b.py index 8e981091..a8d9df1f 100644 --- a/tests/unit/test_fmha_v3_stage_d5b.py +++ b/tests/unit/test_fmha_v3_stage_d5b.py @@ -195,8 +195,12 @@ def test(): # Use the per-row reference lse for proper merge. # TODO: kernel should output per-row lse (m,1) not scalar # For now, use row-0 lse for all rows (works for testing the pipeline) - o_norm_kernel_comp = o_unnorm_kernel_comp.float() / torch.exp(lse_comp[0, 0]) - o_norm_kernel_swa = o_unnorm_kernel_swa.float() / torch.exp(lse_swa[0, 0]) + # NOTE: This gives wrong results for rows 1-127 since they have different LSE. + # Compare only row 0 for correctness. + lse_comp_per_row = lse_comp[:, 0] # (m,) — reference per-row LSE + lse_swa_per_row = lse_swa[:, 0] # (m,) — reference per-row LSE + o_norm_kernel_comp = o_unnorm_kernel_comp.float() / torch.exp(lse_comp_per_row.unsqueeze(1)) + o_norm_kernel_swa = o_unnorm_kernel_swa.float() / torch.exp(lse_swa_per_row.unsqueeze(1)) exp_lse_kern_comp = torch.exp(lse_comp_val) exp_lse_kern_swa = torch.exp(lse_swa_val)