From 0cea0b33ffcdc38732f962c4bb608cfffcc04571 Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Wed, 3 Jun 2026 00:07:38 +0000
Subject: [PATCH] B1 test: fix BF16 reference to use PyTorch SDPA

---
 tests/unit/test_fmha_mixed_fp8_cosine.py | 25 ++++++++++++------------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/tests/unit/test_fmha_mixed_fp8_cosine.py b/tests/unit/test_fmha_mixed_fp8_cosine.py
index 5ab87e67..d97b8f95 100644
--- a/tests/unit/test_fmha_mixed_fp8_cosine.py
+++ b/tests/unit/test_fmha_mixed_fp8_cosine.py
@@ -36,7 +36,7 @@ def run_mixed_fp8_decode(q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, r
 
 
 def run_bf16_reference(q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=64):
-    """Run BF16 reference FMHA by dequantizing FP8 noPE to BF16."""
+    """Run BF16 reference FMHA using PyTorch SDPA on dequantized KV."""
     B, H, T, HD = q_bf16.shape
     NOPE = HD - rope_dim
     N = k_nope_fp8.shape[0]
@@ -53,18 +53,17 @@ def run_bf16_reference(q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rop
     # V = K for MQA (self-attention decode)
     v_full = k_full.clone()
 
-    # Run BF16 FMHA
-    from dsv4.kernels.attention.production import dsv4_attention
-    q_3d = q_bf16.squeeze(2)  # (B, H, HD) -> need (H, T, HD) per batch
-    results = []
-    for b in range(B):
-        q_b = q_3d[b].transpose(0, 1)  # (H, 1, HD) -> (H, T=1, HD)
-        # dsv4_attention expects (n_q, T, hd) or (batch, n_q, T, hd)
-        o_b = dsv4_attention(q_b.unsqueeze(0), k_full.unsqueeze(0).unsqueeze(0),
-                             v_full.unsqueeze(0).unsqueeze(0), scale)
-        results.append(o_b)
-    o = torch.cat(results, dim=0)  # (B, H, T, HD)
-    return o
+    # Run PyTorch SDPA as reference — FP32 math, exact result
+    # q: (B, H, 1, HD), k: (1, 1, N, HD), v: (1, 1, N, HD)
+    q_f = q_bf16.float()
+    k_f = k_full.float().unsqueeze(0).unsqueeze(0)  # (1, 1, N, HD)
+    v_f = v_full.float().unsqueeze(0).unsqueeze(0)  # (1, 1, N, HD)
+    # Expand k, v for all batches
+    if B > 1:
+        k_f = k_f.expand(B, -1, -1, -1)
+        v_f = v_f.expand(B, -1, -1, -1)
+    o = F.scaled_dot_product_attention(q_f, k_f, v_f, scale=scale)  # (B, H, 1, HD)
+    return o.bfloat16()
 
 
 def test_cosine(N_values, H=128, HD=512, rope_dim=64, B=1, seed=42):