test: correct sink bias reference — denominator-only, no V contribution
This commit is contained in:
@@ -439,32 +439,30 @@ def main():
|
||||
continue
|
||||
o_prod = o_prod_4d.squeeze(0) # (n_h, T, hd)
|
||||
|
||||
# 7. Reference: dequantize mixed KV to BF16, run SDPA WITH sink bias
|
||||
# 7. Reference: dequantize mixed KV to BF16, run reference with sink bias
|
||||
nope_dequant = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float()
|
||||
kv_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1) # (N, hd)
|
||||
k_4d = kv_full.unsqueeze(0).unsqueeze(0).expand(1, 1, -1, -1) # (1, 1, N, hd)
|
||||
v_4d = k_4d.clone()
|
||||
if sink_bias is not None:
|
||||
# Reference with sink bias: compute attention scores manually and add sink as logit bias
|
||||
# DSV4 sink merge = single softmax over [S_comp, S_swa + attn_sink]
|
||||
# The sink bias is added as a logit to the LAST position (the sink position)
|
||||
# In our SWA-only case, the sink is the last SWA entry.
|
||||
# Formula: attn_weights = softmax(Q @ K^T * scale + sink_bias_on_last_pos)
|
||||
# For per-head sink: sink_bias shape is (n_h,)
|
||||
# DSV4 sink is denominator-only: O = sum(P*V) / (sum(P) + exp(sb))
|
||||
# where P = softmax(QK*scale). The sink has NO V contribution.
|
||||
# Reference: compute O_no_sink, then scale by correction factor.
|
||||
q_ref = q_4d.float() # (1, H, T, hd)
|
||||
k_ref = k_4d.float().expand(1, n_h, -1, -1) # (1, H, N, hd)
|
||||
v_ref = v_4d.float().expand(1, n_h, -1, -1) # (1, H, N, hd)
|
||||
scores = torch.matmul(q_ref, k_ref.transpose(-2, -1)) * scale # (1, H, T, N)
|
||||
# Add sink bias to the last position
|
||||
# sink_bias: (n_h,) → (1, n_h, 1, 1) broadcast to (1, n_h, T, N) on last N dim
|
||||
# The production kernel adds sink_bias / scale_softmax to the last SWA position
|
||||
# Actually, the sink bias is added as-is to the logit at the sink position
|
||||
sb_4d = sink_bias.reshape(1, n_h, 1, 1) # (1, H, 1, 1)
|
||||
# The sink position is the LAST position in the KV sequence
|
||||
scores[:, :, :, -1:] = scores[:, :, :, -1:] + sb_4d
|
||||
# Softmax
|
||||
attn_weights = F.softmax(scores, dim=-1)
|
||||
o_ref_4d = torch.matmul(attn_weights, v_ref).bfloat16() # (1, H, T, hd)
|
||||
# O_no_sink = softmax(scores) @ V
|
||||
O_no_sink = F.softmax(scores, dim=-1) @ v_ref # (1, H, T, hd)
|
||||
# Correction: O_with_sink = O_no_sink * Z / (Z + exp(sb))
|
||||
# Z = sum(exp(scores - max)) per head, but more conveniently:
|
||||
# Z / (Z + exp(sb)) = 1 / (1 + exp(sb) / Z) = 1 / (1 + exp(sb - log(Z)))
|
||||
# log(Z) = logsumexp(scores)
|
||||
lse = torch.logsumexp(scores, dim=-1, keepdim=True) # (1, H, T, 1)
|
||||
# sb shape: (n_h,) → (1, n_h, 1, 1)
|
||||
sb_4d = sink_bias.reshape(1, n_h, 1, 1)
|
||||
correction = 1.0 / (1.0 + torch.exp(sb_4d - lse))
|
||||
o_ref_4d = (O_no_sink * correction).bfloat16() # (1, H, T, hd)
|
||||
else:
|
||||
o_ref_4d = F.scaled_dot_product_attention(q_4d, k_4d, v_4d, scale=scale) # (1, H, T, hd)
|
||||
o_ref = o_ref_4d.squeeze(0) # (n_h, T, hd)
|
||||
|
||||
Reference in New Issue
Block a user