diff --git a/tests/unit/test_decode_fmha_layer.py b/tests/unit/test_decode_fmha_layer.py index f9e8fca2..28cb1110 100644 --- a/tests/unit/test_decode_fmha_layer.py +++ b/tests/unit/test_decode_fmha_layer.py @@ -439,32 +439,30 @@ def main(): continue o_prod = o_prod_4d.squeeze(0) # (n_h, T, hd) - # 7. Reference: dequantize mixed KV to BF16, run SDPA WITH sink bias + # 7. Reference: dequantize mixed KV to BF16, run reference with sink bias nope_dequant = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float() kv_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1) # (N, hd) k_4d = kv_full.unsqueeze(0).unsqueeze(0).expand(1, 1, -1, -1) # (1, 1, N, hd) v_4d = k_4d.clone() if sink_bias is not None: - # Reference with sink bias: compute attention scores manually and add sink as logit bias - # DSV4 sink merge = single softmax over [S_comp, S_swa + attn_sink] - # The sink bias is added as a logit to the LAST position (the sink position) - # In our SWA-only case, the sink is the last SWA entry. - # Formula: attn_weights = softmax(Q @ K^T * scale + sink_bias_on_last_pos) - # For per-head sink: sink_bias shape is (n_h,) + # DSV4 sink is denominator-only: O = sum(P*V) / (sum(P) + exp(sb)) + # where P = softmax(QK*scale). The sink has NO V contribution. + # Reference: compute O_no_sink, then scale by correction factor. q_ref = q_4d.float() # (1, H, T, hd) k_ref = k_4d.float().expand(1, n_h, -1, -1) # (1, H, N, hd) v_ref = v_4d.float().expand(1, n_h, -1, -1) # (1, H, N, hd) scores = torch.matmul(q_ref, k_ref.transpose(-2, -1)) * scale # (1, H, T, N) - # Add sink bias to the last position - # sink_bias: (n_h,) → (1, n_h, 1, 1) broadcast to (1, n_h, T, N) on last N dim - # The production kernel adds sink_bias / scale_softmax to the last SWA position - # Actually, the sink bias is added as-is to the logit at the sink position - sb_4d = sink_bias.reshape(1, n_h, 1, 1) # (1, H, 1, 1) - # The sink position is the LAST position in the KV sequence - scores[:, :, :, -1:] = scores[:, :, :, -1:] + sb_4d - # Softmax - attn_weights = F.softmax(scores, dim=-1) - o_ref_4d = torch.matmul(attn_weights, v_ref).bfloat16() # (1, H, T, hd) + # O_no_sink = softmax(scores) @ V + O_no_sink = F.softmax(scores, dim=-1) @ v_ref # (1, H, T, hd) + # Correction: O_with_sink = O_no_sink * Z / (Z + exp(sb)) + # Z = sum(exp(scores - max)) per head, but more conveniently: + # Z / (Z + exp(sb)) = 1 / (1 + exp(sb) / Z) = 1 / (1 + exp(sb - log(Z))) + # log(Z) = logsumexp(scores) + lse = torch.logsumexp(scores, dim=-1, keepdim=True) # (1, H, T, 1) + # sb shape: (n_h,) → (1, n_h, 1, 1) + sb_4d = sink_bias.reshape(1, n_h, 1, 1) + correction = 1.0 / (1.0 + torch.exp(sb_4d - lse)) + o_ref_4d = (O_no_sink * correction).bfloat16() # (1, H, T, hd) else: o_ref_4d = F.scaled_dot_product_attention(q_4d, k_4d, v_4d, scale=scale) # (1, H, T, hd) o_ref = o_ref_4d.squeeze(0) # (n_h, T, hd)