From d99a90ade54b891c3ff8f351b4be0e078e0eb159 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 01:36:27 +0000 Subject: [PATCH] fix: use attn_raw (not softmax'd) for unnorm computation --- tests/unit/test_fmha_v3_stage_c.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index e2056dae..eacbe5f7 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -440,12 +440,12 @@ def test(): qf = q[:, :, 0].float() kf = k[:, :, 0].float() scale = 1.0 / math.sqrt(hd) - attn = qf @ kf.T * scale - attn = torch.softmax(attn, dim=-1) + attn_raw = qf @ kf.T * scale + attn = torch.softmax(attn_raw, dim=-1) ref = attn @ v.float() - # Also compute the unnormalized PV and row_sum for Python-side normalize - attn_unnorm = torch.exp(attn - attn.max(dim=-1, keepdim=True).values) + # Compute unnormalized softmax and row_sum for Python-side normalize + attn_unnorm = torch.exp(attn_raw - attn_raw.max(dim=-1, keepdim=True).values) row_sum_unnorm = attn_unnorm.sum(dim=-1, keepdim=True) mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))