rewrite D2 regression test: match existing Stage D1 test pattern with cute.compile + PV tiles

This commit is contained in:
2026-05-25 17:11:59 +00:00
parent 06cb800242
commit 673825c242

View File

@@ -1,54 +1,166 @@
"""Quick test: n_h=1 regression after grid changes."""
"""
FMHA D2 regression test (matches existing test pattern).
Uses the same cute.compile + PV tile iteration as test_fmha_v3_stage_d1.py.
Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d2_regression.py
"""
import torch
import math
import cutlass
import cutlass.cute as cute
import cuda.bindings.driver as cuda
import cutlass.torch as ct
from cutlass import Float32, BFloat16
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def test():
# NOTE(review): this is the pre-rewrite n_h=1 regression test as rendered by a
# diff view; indentation was lost and the remainder of the old body (softmax
# reference and cosine check) appears interleaved with the newer definitions
# further down. A later `def test()` in this file rebinds the name.
print("=== n_h=1 regression (hd=64, s_k=128) ===")
# Fixed seed so the random inputs are reproducible between runs.
torch.manual_seed(42)
M, s_k, hd = 128, 128, 64
scale = 1.0 / math.sqrt(hd)
# Q/K/V/O are all allocated with a trailing size-1 dimension.
q = torch.randn(M, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
o = torch.zeros(M, hd, 1, dtype=torch.bfloat16, device='cuda')
# presumably normalize=False makes the kernel emit an unnormalized output,
# which is why the result is divided by exp(lse) below — confirm against FmhaKernel
fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=False)
# Wrap the current torch CUDA stream handle for the kernel launch.
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
# Wrap each torch tensor via DLPack and mark its layout dynamic along the leading dim.
q_c = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
k_c = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
v_c = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
o_c = ct.from_dlpack(o).mark_layout_dynamic(leading_dim=ct.get_leading_dim(o))
lse = torch.zeros(M, dtype=torch.float32, device='cuda')
lse_c = ct.from_dlpack(lse).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse))
# Direct call (no explicit cute.compile step in this older version).
fmha(q_c, k_c, v_c, o_c, stream, lse_c)
# External normalization using LSE
# NOTE(review): assumes exp(lse) equals the softmax denominator that matches
# the kernel's unnormalized output — TODO confirm
row_sum = lse.exp()
o_norm = o[:,:,0] / row_sum.unsqueeze(-1)
# Reference
# Scaled Q·K^T in float32; the follow-on reference steps are not contiguous
# with this block in the diff view.
scores = torch.matmul(q[:,:,0].float(), k[:,:,0].float().T) * scale
def reference_fmha(q, k, v, scale):
    """FP32 reference attention for a single head.

    Computes ``softmax(q @ k.T * scale) @ v`` in float32 using a
    max-subtracted (numerically stable) softmax.

    Args:
        q: Query matrix, shape (M, hd).
        k: Key matrix, shape (s_k, hd).
        v: Value matrix, shape (s_k, hd).
        scale: Multiplier applied to the raw ``q @ k.T`` scores.

    Returns:
        A tuple ``(o, lse)``. ``o`` is the attention output cast to
        bfloat16, shape (M, hd). ``lse`` is the float32 log-sum-exp of the
        scaled scores per query row, shape (M, 1).
    """
    scores = torch.matmul(q.float(), k.float().T) * scale
    # Subtract the row max before exponentiating so exp() cannot overflow.
    max_s = scores.max(dim=-1, keepdim=True).values
    exp_s = (scores - max_s).exp()
    sum_s = exp_s.sum(dim=-1, keepdim=True)
    p = exp_s / sum_s
    # v is 2-D here; the stale 3-D `v[:,:,0]` indexing from the previous
    # revision has been dropped.
    o = torch.matmul(p, v.float())
    # Add the max back to recover the true log-sum-exp of the scores.
    return o.to(torch.bfloat16), (sum_s.log() + max_s)
def test_d2_regression():
    """Regression test matching existing Stage D1 pattern.

    Builds random bf16 Q/K/V for m=128 queries, 128 keys and head_dim=64,
    compiles ``FmhaKernel`` once with ``cute.compile`` and launches it once
    per PV column tile. The per-tile outputs are gathered into a float32
    buffer, normalized externally with the LSE tensor, and compared against
    ``reference_fmha`` by cosine similarity.

    Raises:
        AssertionError: If the cosine similarity to the reference is
            below 0.99.
    """
    print("\n=== Regression test (hd=64, s_k=128) ===")
    # Fixed seed so the random inputs are reproducible between runs.
    torch.manual_seed(42)
    m, n_kv, hd = 128, 128, 64
    scale = 1.0 / math.sqrt(hd)

    # Q and K carry a trailing size-1 dimension; V stays 2-D so it can be
    # sliced into column tiles below.
    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(n_kv, hd, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(n_kv, hd, dtype=torch.bfloat16, device='cuda')

    kernel = FmhaKernel(head_dim=hd, s_k=n_kv, use_smem_p=False)
    pv_n_tile = kernel.pv_n_tile
    n_pv_tiles = kernel.n_pv_tiles
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # Compile with first PV tile
    v_tile = v[:, 0:pv_n_tile].contiguous()
    v_kernel = v_tile.unsqueeze(-1)
    c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
    lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
    mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
    mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE)

    # Run PV tiles: each launch consumes one pv_n_tile-wide column slice of V
    # and its output is copied into the matching columns of o_unnorm.
    o_unnorm = torch.zeros(m, hd, dtype=torch.float32, device='cuda')
    for pv in range(n_pv_tiles):
        v_tile = v[:, pv*pv_n_tile:(pv+1)*pv_n_tile].contiguous()
        v_kernel = v_tile.unsqueeze(-1)
        # Output and LSE buffers are reused, so clear them before each launch.
        c_tile.zero_()
        lse_tensor.zero_()
        mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
        mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
        mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
        compiled(mQ, mK, mV, mC, stream, mLSE)
        o_unnorm[:, pv*pv_n_tile:(pv+1)*pv_n_tile] = c_tile[:,:,0].float()

    # External normalization using LSE (the value left by the last launch).
    # NOTE(review): assumes exp(lse) equals the softmax denominator that
    # matches the kernel's unnormalized output — TODO confirm
    lse = lse_tensor[:,0,0]  # (m,)
    row_sum = lse.exp()
    o_norm = o_unnorm / row_sum.unsqueeze(-1)
    o_bf16 = o_norm.to(torch.bfloat16)

    # Reference
    ref, _ = reference_fmha(q[:,:,0], k[:,:,0], v, scale)

    # Compare at bf16 precision, the same dtype the reference returns.
    cos = torch.nn.functional.cosine_similarity(
        o_bf16.flatten().float().unsqueeze(0), ref.flatten().float().unsqueeze(0)
    ).item()
    print(f" cos = {cos:.6f}")
    assert cos >= 0.99, f"cosine too low: {cos}"
    print(" ✅ PASS")
def test_d2_headpacked_128():
    """n_h=128, T=1 (Pro decode): M=128, heads packed into M.

    Draws one query row per head, stacks the 128 heads along the row
    dimension, and runs the compiled ``FmhaKernel`` once per PV column tile
    against shared K/V. The externally normalized output is checked against
    a per-head ``reference_fmha`` result by cosine similarity.

    Raises:
        AssertionError: If the cosine similarity to the reference is
            below 0.99.
    """

    def as_cute(t):
        # DLPack wrapper with the layout marked dynamic along the leading dim.
        return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))

    print("\n=== n_h=128, T=1 (Pro decode, hd=64) ===")
    torch.manual_seed(42)
    n_h, T, s_k, hd = 128, 1, 128, 64
    rows = n_h * T
    scale = 1.0 / math.sqrt(hd)

    # Per-head Q
    q_heads = torch.randn(n_h, T, hd, dtype=torch.bfloat16, device='cuda')
    # Pack heads into M: (n_h*T, hd) → (128, 64, 1)
    q = q_heads.reshape(rows, hd).unsqueeze(-1)
    k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')

    kernel = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=False)
    tile_w = kernel.pv_n_tile
    tile_count = kernel.n_pv_tiles
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # Compile once, using the first column tile of V as the example input.
    v_tile = v[:, 0:tile_w].contiguous().unsqueeze(-1)
    c_tile = torch.zeros(rows, tile_w, 1, dtype=torch.bfloat16, device='cuda')
    lse_tensor = torch.zeros(rows, 1, 1, dtype=torch.float32, device='cuda')
    mQ = as_cute(q)
    mK = as_cute(k)
    mV = as_cute(v_tile)
    mC = as_cute(c_tile)
    mLSE = as_cute(lse_tensor)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE)

    # One launch per column tile; results land in the matching columns.
    o_unnorm = torch.zeros(rows, hd, dtype=torch.float32, device='cuda')
    for tile_idx in range(tile_count):
        col_lo = tile_idx * tile_w
        col_hi = (tile_idx + 1) * tile_w
        v_tile = v[:, col_lo:col_hi].contiguous().unsqueeze(-1)
        c_tile.zero_()
        lse_tensor.zero_()
        mV = as_cute(v_tile)
        mC = as_cute(c_tile)
        mLSE = as_cute(lse_tensor)
        compiled(mQ, mK, mV, mC, stream, mLSE)
        o_unnorm[:, col_lo:col_hi] = c_tile[:,:,0].float()

    # External normalization using the LSE left by the last launch.
    lse = lse_tensor[:,0,0]
    row_sum = lse.exp()
    o_norm = o_unnorm / row_sum.unsqueeze(-1)
    o_bf16 = o_norm.to(torch.bfloat16)

    # Per-head reference
    o_ref = torch.zeros(n_h, T, hd, dtype=torch.bfloat16, device='cuda')
    for head_idx in range(n_h):
        o_ref[head_idx, 0], _ = reference_fmha(q_heads[head_idx], k[:,:,0], v, scale)
    o_ref_flat = o_ref.reshape(rows, hd)

    got_vec = o_bf16.flatten().float().unsqueeze(0)
    want_vec = o_ref_flat.flatten().float().unsqueeze(0)
    cos = torch.nn.functional.cosine_similarity(got_vec, want_vec).item()
    print(f" cos = {cos:.6f}")
    assert cos >= 0.99, f"cosine too low: {cos}"
    print(" ✅ PASS")
def test():
    """Run every D2 test in order, printing a banner before and after."""
    print("=== D2: Head-Packed FMHA ===")
    # Any failing case raises, so the closing banner only prints on success.
    for case in (test_d2_regression, test_d2_headpacked_128):
        case()
    print("\n=== ALL TESTS PASSED ===")
if __name__ == '__main__':