B1 test: fix BF16 reference to use PyTorch SDPA
This commit is contained in:
@@ -36,7 +36,7 @@ def run_mixed_fp8_decode(q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, r
|
||||
|
||||
|
||||
def run_bf16_reference(q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=64):
|
||||
"""Run BF16 reference FMHA by dequantizing FP8 noPE to BF16."""
|
||||
"""Run BF16 reference FMHA using PyTorch SDPA on dequantized KV."""
|
||||
B, H, T, HD = q_bf16.shape
|
||||
NOPE = HD - rope_dim
|
||||
N = k_nope_fp8.shape[0]
|
||||
@@ -53,18 +53,17 @@ def run_bf16_reference(q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rop
|
||||
# V = K for MQA (self-attention decode)
|
||||
v_full = k_full.clone()
|
||||
|
||||
# Run BF16 FMHA
|
||||
from dsv4.kernels.attention.production import dsv4_attention
|
||||
q_3d = q_bf16.squeeze(2) # (B, H, HD) -> need (H, T, HD) per batch
|
||||
results = []
|
||||
for b in range(B):
|
||||
q_b = q_3d[b].transpose(0, 1) # (H, 1, HD) -> (H, T=1, HD)
|
||||
# dsv4_attention expects (n_q, T, hd) or (batch, n_q, T, hd)
|
||||
o_b = dsv4_attention(q_b.unsqueeze(0), k_full.unsqueeze(0).unsqueeze(0),
|
||||
v_full.unsqueeze(0).unsqueeze(0), scale)
|
||||
results.append(o_b)
|
||||
o = torch.cat(results, dim=0) # (B, H, T, HD)
|
||||
return o
|
||||
# Run PyTorch SDPA as reference — inputs upcast to FP32, output cast back to BF16
|
||||
# q: (B, H, 1, HD), k: (1, 1, N, HD), v: (1, 1, N, HD)
|
||||
q_f = q_bf16.float()
|
||||
k_f = k_full.float().unsqueeze(0).unsqueeze(0) # (1, 1, N, HD)
|
||||
v_f = v_full.float().unsqueeze(0).unsqueeze(0) # (1, 1, N, HD)
|
||||
# Expand k, v for all batches
|
||||
if B > 1:
|
||||
k_f = k_f.expand(B, -1, -1, -1)
|
||||
v_f = v_f.expand(B, -1, -1, -1)
|
||||
o = F.scaled_dot_product_attention(q_f, k_f, v_f, scale=scale) # (B, H, 1, HD)
|
||||
return o.bfloat16()
|
||||
|
||||
|
||||
def test_cosine(N_values, H=128, HD=512, rope_dim=64, B=1, seed=42):
|
||||
|
||||
Reference in New Issue
Block a user