test: add sink bias to reference SDPA in decode FMHA comparison
This commit is contained in:
@@ -439,12 +439,34 @@ def main():
|
||||
continue
|
||||
o_prod = o_prod_4d.squeeze(0) # (n_h, T, hd)
|
||||
|
||||
# 7. Reference: dequantize mixed KV to BF16, run SDPA
|
||||
# 7. Reference: dequantize mixed KV to BF16, run SDPA WITH sink bias
|
||||
nope_dequant = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float()
|
||||
kv_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1) # (N, hd)
|
||||
k_4d = kv_full.unsqueeze(0).unsqueeze(0).expand(1, 1, -1, -1) # (1, 1, N, hd)
|
||||
v_4d = k_4d.clone()
|
||||
o_ref_4d = F.scaled_dot_product_attention(q_4d, k_4d, v_4d, scale=scale) # (1, H, T, hd)
|
||||
if sink_bias is not None:
|
||||
# Reference with sink bias: compute attention scores manually and add sink as logit bias
|
||||
# DSV4 sink merge = single softmax over [S_comp, S_swa + attn_sink]
|
||||
# The sink bias is added as a logit to the LAST position (the sink position)
|
||||
# In our SWA-only case, the sink is the last SWA entry.
|
||||
# Formula: attn_weights = softmax(Q @ K^T * scale + sink_bias_on_last_pos)
|
||||
# For per-head sink: sink_bias shape is (n_h,)
|
||||
q_ref = q_4d.float() # (1, H, T, hd)
|
||||
k_ref = k_4d.float().expand(1, n_h, -1, -1) # (1, H, N, hd)
|
||||
v_ref = v_4d.float().expand(1, n_h, -1, -1) # (1, H, N, hd)
|
||||
scores = torch.matmul(q_ref, k_ref.transpose(-2, -1)) * scale # (1, H, T, N)
|
||||
# Add sink bias to the last position
|
||||
# sink_bias: (n_h,) → (1, n_h, 1, 1) broadcast to (1, n_h, T, N) on last N dim
|
||||
# The production kernel adds sink_bias / scale_softmax to the last SWA position
|
||||
# Actually, the sink bias is added as-is to the logit at the sink position
|
||||
sb_4d = sink_bias.reshape(1, n_h, 1, 1) # (1, H, 1, 1)
|
||||
# The sink position is the LAST position in the KV sequence
|
||||
scores[:, :, :, -1:] = scores[:, :, :, -1:] + sb_4d
|
||||
# Softmax
|
||||
attn_weights = F.softmax(scores, dim=-1)
|
||||
o_ref_4d = torch.matmul(attn_weights, v_ref).bfloat16() # (1, H, T, hd)
|
||||
else:
|
||||
o_ref_4d = F.scaled_dot_product_attention(q_4d, k_4d, v_4d, scale=scale) # (1, H, T, hd)
|
||||
o_ref = o_ref_4d.squeeze(0) # (n_h, T, hd)
|
||||
|
||||
# 8. Compare
|
||||
|
||||
Reference in New Issue
Block a user