B1 test: fix BF16 reference to use PyTorch SDPA

This commit is contained in:
2026-06-03 00:07:38 +00:00
parent a51d19a7fc
commit 0cea0b33ff

View File

@@ -36,7 +36,7 @@ def run_mixed_fp8_decode(q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, r
def run_bf16_reference(q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=64):
"""Run BF16 reference FMHA by dequantizing FP8 noPE to BF16."""
"""Run BF16 reference FMHA using PyTorch SDPA on dequantized KV."""
B, H, T, HD = q_bf16.shape
NOPE = HD - rope_dim
N = k_nope_fp8.shape[0]
@@ -53,18 +53,17 @@ def run_bf16_reference(q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rop
# V = K for MQA (self-attention decode)
v_full = k_full.clone()
# Run BF16 FMHA
from dsv4.kernels.attention.production import dsv4_attention
q_3d = q_bf16.squeeze(2) # (B, H, HD) -> need (H, T, HD) per batch
results = []
for b in range(B):
q_b = q_3d[b].transpose(0, 1) # (H, 1, HD) -> (H, T=1, HD)
# dsv4_attention expects (n_q, T, hd) or (batch, n_q, T, hd)
o_b = dsv4_attention(q_b.unsqueeze(0), k_full.unsqueeze(0).unsqueeze(0),
v_full.unsqueeze(0).unsqueeze(0), scale)
results.append(o_b)
o = torch.cat(results, dim=0) # (B, H, T, HD)
return o
# Run PyTorch SDPA as reference — FP32 math, exact result
# q: (B, H, 1, HD), k: (1, 1, N, HD), v: (1, 1, N, HD)
q_f = q_bf16.float()
k_f = k_full.float().unsqueeze(0).unsqueeze(0) # (1, 1, N, HD)
v_f = v_full.float().unsqueeze(0).unsqueeze(0) # (1, 1, N, HD)
# Expand k, v for all batches
if B > 1:
k_f = k_f.expand(B, -1, -1, -1)
v_f = v_f.expand(B, -1, -1, -1)
o = F.scaled_dot_product_attention(q_f, k_f, v_f, scale=scale) # (B, H, 1, HD)
return o.bfloat16()
def test_cosine(N_values, H=128, HD=512, rope_dim=64, B=1, seed=42):