fix: use attn_raw (the pre-softmax scores) for the unnormalized softmax computation
This commit is contained in:
@@ -440,12 +440,12 @@ def test():
|
||||
qf = q[:, :, 0].float()
|
||||
kf = k[:, :, 0].float()
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
attn = qf @ kf.T * scale
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
attn_raw = qf @ kf.T * scale
|
||||
attn = torch.softmax(attn_raw, dim=-1)
|
||||
ref = attn @ v.float()
|
||||
|
||||
# Also compute the unnormalized PV and row_sum for Python-side normalize
|
||||
attn_unnorm = torch.exp(attn - attn.max(dim=-1, keepdim=True).values)
|
||||
# Compute the unnormalized softmax weights exp(attn_raw - row max) and their row sums for Python-side normalization
|
||||
attn_unnorm = torch.exp(attn_raw - attn_raw.max(dim=-1, keepdim=True).values)
|
||||
row_sum_unnorm = attn_unnorm.sum(dim=-1, keepdim=True)
|
||||
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
|
||||
Reference in New Issue
Block a user