diff --git a/tests/unit/test_decode_fmha_layer.py b/tests/unit/test_decode_fmha_layer.py index 99869977..f9e8fca2 100644 --- a/tests/unit/test_decode_fmha_layer.py +++ b/tests/unit/test_decode_fmha_layer.py @@ -439,12 +439,34 @@ def main(): continue o_prod = o_prod_4d.squeeze(0) # (n_h, T, hd) - # 7. Reference: dequantize mixed KV to BF16, run SDPA + # 7. Reference: dequantize mixed KV to BF16, run SDPA WITH sink bias nope_dequant = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float() kv_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1) # (N, hd) k_4d = kv_full.unsqueeze(0).unsqueeze(0).expand(1, 1, -1, -1) # (1, 1, N, hd) v_4d = k_4d.clone() - o_ref_4d = F.scaled_dot_product_attention(q_4d, k_4d, v_4d, scale=scale) # (1, H, T, hd) + if sink_bias is not None: + # Reference with sink bias: compute attention scores manually and add sink as logit bias + # DSV4 sink merge = single softmax over [S_comp, S_swa + attn_sink] + # The sink bias is added as a logit to the LAST position (the sink position) + # In our SWA-only case, the sink is the last SWA entry. + # Formula: attn_weights = softmax(Q @ K^T * scale + sink_bias_on_last_pos) + # For per-head sink: sink_bias shape is (n_h,) + q_ref = q_4d.float() # (1, H, T, hd) + k_ref = k_4d.float().expand(1, n_h, -1, -1) # (1, H, N, hd) + v_ref = v_4d.float().expand(1, n_h, -1, -1) # (1, H, N, hd) + scores = torch.matmul(q_ref, k_ref.transpose(-2, -1)) * scale # (1, H, T, N) + # Add sink bias to the last position + # sink_bias: (n_h,) → (1, n_h, 1, 1) broadcast to (1, n_h, T, N) on last N dim + # The production kernel adds sink_bias / scale_softmax to the last SWA position + # Actually, the sink bias is added as-is to the logit at the sink position + sb_4d = sink_bias.reshape(1, n_h, 1, 1) # (1, H, 1, 1) + # The sink position is the LAST position in the KV sequence + scores[:, :, :, -1:] = scores[:, :, :, -1:] + sb_4d + # Softmax + attn_weights = F.softmax(scores, dim=-1) + o_ref_4d = torch.matmul(attn_weights, v_ref).bfloat16() # (1, H, T, hd) + else: + o_ref_4d = F.scaled_dot_product_attention(q_4d, k_4d, v_4d, scale=scale) # (1, H, T, hd) o_ref = o_ref_4d.squeeze(0) # (n_h, T, hd) # 8. Compare