test: correct sink bias reference — denominator-only, no V contribution

This commit is contained in:
2026-06-03 04:57:37 +00:00
parent 9574a9dc2e
commit d518fcb82a

View File

@@ -439,32 +439,30 @@ def main():
continue
o_prod = o_prod_4d.squeeze(0) # (n_h, T, hd)
# 7. Reference: dequantize mixed KV to BF16, run SDPA WITH sink bias
# 7. Reference: dequantize mixed KV to BF16, run reference with sink bias
nope_dequant = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float()
kv_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1) # (N, hd)
k_4d = kv_full.unsqueeze(0).unsqueeze(0).expand(1, 1, -1, -1) # (1, 1, N, hd)
v_4d = k_4d.clone()
if sink_bias is not None:
# Reference with sink bias: compute attention scores manually and add sink as logit bias
# DSV4 sink merge = single softmax over [S_comp, S_swa + attn_sink]
# The sink bias is added as a logit to the LAST position (the sink position)
# In our SWA-only case, the sink is the last SWA entry.
# Formula: attn_weights = softmax(Q @ K^T * scale + sink_bias_on_last_pos)
# For per-head sink: sink_bias shape is (n_h,)
# DSV4 sink is denominator-only: O = sum(P*V) / (sum(P) + exp(sb))
# where P = softmax(QK*scale). The sink has NO V contribution.
# Reference: compute O_no_sink, then scale by correction factor.
q_ref = q_4d.float() # (1, H, T, hd)
k_ref = k_4d.float().expand(1, n_h, -1, -1) # (1, H, N, hd)
v_ref = v_4d.float().expand(1, n_h, -1, -1) # (1, H, N, hd)
scores = torch.matmul(q_ref, k_ref.transpose(-2, -1)) * scale # (1, H, T, N)
# Add sink bias to the last position
# sink_bias: (n_h,) → (1, n_h, 1, 1) broadcast to (1, n_h, T, N) on last N dim
# The production kernel adds sink_bias / scale_softmax to the last SWA position
# Actually, the sink bias is added as-is to the logit at the sink position
sb_4d = sink_bias.reshape(1, n_h, 1, 1) # (1, H, 1, 1)
# The sink position is the LAST position in the KV sequence
scores[:, :, :, -1:] = scores[:, :, :, -1:] + sb_4d
# Softmax
attn_weights = F.softmax(scores, dim=-1)
o_ref_4d = torch.matmul(attn_weights, v_ref).bfloat16() # (1, H, T, hd)
# O_no_sink = softmax(scores) @ V
O_no_sink = F.softmax(scores, dim=-1) @ v_ref # (1, H, T, hd)
# Correction: O_with_sink = O_no_sink * Z / (Z + exp(sb))
# Z = sum(exp(scores - max)) per head, but more conveniently:
# Z / (Z + exp(sb)) = 1 / (1 + exp(sb) / Z) = 1 / (1 + exp(sb - log(Z)))
# log(Z) = logsumexp(scores)
lse = torch.logsumexp(scores, dim=-1, keepdim=True) # (1, H, T, 1)
# sb shape: (n_h,) → (1, n_h, 1, 1)
sb_4d = sink_bias.reshape(1, n_h, 1, 1)
correction = 1.0 / (1.0 + torch.exp(sb_4d - lse))
o_ref_4d = (O_no_sink * correction).bfloat16() # (1, H, T, hd)
else:
o_ref_4d = F.scaled_dot_product_attention(q_4d, k_4d, v_4d, scale=scale) # (1, H, T, hd)
o_ref = o_ref_4d.squeeze(0) # (n_h, T, hd)