fix: compute the unnormalized softmax from attn_raw (pre-softmax scores), not from the softmax output

This commit is contained in:
2026-05-23 01:36:27 +00:00
parent 7becdaf739
commit d99a90ade5

View File

@@ -440,12 +440,12 @@ def test():
qf = q[:, :, 0].float()
kf = k[:, :, 0].float()
scale = 1.0 / math.sqrt(hd)
attn = qf @ kf.T * scale
attn = torch.softmax(attn, dim=-1)
attn_raw = qf @ kf.T * scale
attn = torch.softmax(attn_raw, dim=-1)
ref = attn @ v.float()
# Also compute the unnormalized PV and row_sum for Python-side normalize
attn_unnorm = torch.exp(attn - attn.max(dim=-1, keepdim=True).values)
# Compute the unnormalized softmax weights (exp of max-shifted raw scores) and their row sums for Python-side normalization
attn_unnorm = torch.exp(attn_raw - attn_raw.max(dim=-1, keepdim=True).values)
row_sum_unnorm = attn_unnorm.sum(dim=-1, keepdim=True)
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))