Files
nvfp4-megamoe-kernel/tests/unit/test_d15_rescale_debug.py

117 lines
6.6 KiB
Python

"""
D1.5 Debug: Test s_k=256 in-kernel rescale with diagnostics.
Minimal test to isolate the TMEM round-trip vs barrier issue.
"""
import torch, math
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def reference_attention(q, k, v, scale, normalize=False):
    """Compute a float32 PyTorch reference for the attention output.

    With ``S = scale * Q @ K^T``, this returns ``exp(S - rowmax(S)) @ V``. That
    is the *unnormalized* output that ``test`` compares against kernels built
    with ``normalize=False``. Pass ``normalize=True`` to divide by the row sums
    instead, which yields the standard ``softmax(S) @ V``.

    Args:
        q: queries, shape (..., m, d).
        k: keys, shape (..., n, d).
        v: values, shape (..., n, d_v).
        scale: multiplier applied to the raw ``Q @ K^T`` scores.
        normalize: if True, return softmax-normalized attention.

    Returns:
        float32 tensor of shape (..., m, d_v).
    """
    # transpose(-2, -1) (rather than .T) also works for batched inputs.
    scores = q.float() @ k.float().transpose(-2, -1) * scale
    # Subtract the per-row max before exponentiating to keep exp() in range.
    weights = torch.exp(scores - scores.amax(dim=-1, keepdim=True))
    if normalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights @ v.float()
def _to_cute(t):
    """Wrap torch tensor ``t`` as a CuTe tensor whose layout is marked dynamic.

    The leading dimension passed to ``mark_layout_dynamic`` is whatever
    ``ct.get_leading_dim`` reports for ``t``.
    """
    return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))


def _run_fmha(kernel, q, k, v_tile, stream):
    """Compile ``kernel`` for the given inputs, launch it once and return its outputs.

    Args:
        kernel: an ``FmhaKernel`` instance.
        q: query tensor, shape (m, head_dim, 1) as built in ``test``.
        k: key tensor, shape (s_k, head_dim, 1).
        v_tile: values already sliced to the kernel's PV tile,
            shape (s_k, pv_n_tile, 1).
        stream: CUDA stream handle passed through to the kernel.

    Returns:
        ``(out, lse, row_sums)`` as freshly zero-initialized torch tensors of
        shape (m, pv_n_tile, 1) bf16, (m, 1, 1) fp32 and (m, 1, 1) fp32 that
        the kernel wrote into. The device is synchronized before returning.
    """
    m = q.shape[0]
    n_cols = v_tile.shape[1]
    out = torch.zeros(m, n_cols, 1, dtype=torch.bfloat16, device=q.device)
    lse = torch.zeros(m, 1, 1, dtype=torch.float32, device=q.device)
    row_sums = torch.zeros(m, 1, 1, dtype=torch.float32, device=q.device)
    m_q, m_k, m_v = _to_cute(q), _to_cute(k), _to_cute(v_tile)
    m_out, m_lse, m_rs = _to_cute(out), _to_cute(lse), _to_cute(row_sums)
    compiled = cute.compile(kernel, m_q, m_k, m_v, m_out, stream, m_lse, row_sums=m_rs)
    compiled(m_q, m_k, m_v, m_out, stream, m_lse, row_sums=m_rs)
    torch.cuda.synchronize()
    return out, lse, row_sums


def _cosine(a, b):
    """Return the cosine similarity of ``a`` and ``b`` flattened to float vectors."""
    return torch.nn.functional.cosine_similarity(
        a.float().flatten().unsqueeze(0), b.float().flatten().unsqueeze(0)
    ).item()


def test():
    """Compare the FMHA kernel against a PyTorch reference for s_k=128 and s_k=256.

    Runs three diagnostics and prints the results (it does not assert):

    1. s_k=128 baseline (single kernel launch).
    2. s_k=256 using the kernel's in-kernel rescale path.
    3. s_k=256 split into two s_k=128 launches whose outputs are merged in
       Python from each half's LSE and row sums. This checks whether a
       mismatch in (2) comes from the in-kernel rescale.

    All kernels are built with ``normalize=False``. Tests 1 and 2 are therefore
    compared against the unnormalized reference ``exp(S - rowmax(S)) @ V``.
    The merged result of Test 3 is normalized, so it is compared against
    ``softmax(S) @ V``.
    """
    hd = 64
    m = 128
    scale = 1.0 / math.sqrt(hd)
    torch.manual_seed(42)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    q2d = q[:, :, 0]

    # Test 1: s_k=128 baseline
    k1 = torch.randn(128, hd, 1, dtype=torch.bfloat16, device='cuda')
    v1 = torch.randn(128, hd, dtype=torch.bfloat16, device='cuda')
    kernel1 = FmhaKernel(head_dim=hd, s_k=128, use_smem_p=False, normalize=False)
    pv_n_tile = kernel1.pv_n_tile
    # Only the first pv_n_tile columns of V are handed to the kernel. The
    # output buffer and the reference therefore use that same V slice; the
    # full V would mismatch whenever pv_n_tile < head_dim.
    v1t = v1[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
    c1, _, _ = _run_fmha(kernel1, q, k1, v1t, stream)
    ref1 = reference_attention(q2d, k1[:, :, 0], v1t[:, :, 0], scale)
    cos1 = _cosine(c1[:, :, 0], ref1)
    print(f's_k=128 baseline: cos={cos1:.6f} {"PASS" if cos1 >= 0.999 else "FAIL"}', flush=True)

    # Test 2: s_k=256 with in-kernel rescale
    k2 = torch.randn(256, hd, 1, dtype=torch.bfloat16, device='cuda')
    v2 = torch.randn(256, hd, dtype=torch.bfloat16, device='cuda')
    kernel2 = FmhaKernel(head_dim=hd, s_k=256, use_smem_p=False, normalize=False)
    v2t = v2[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
    c2, lse2, rs2 = _run_fmha(kernel2, q, k2, v2t, stream)
    v2_cols = v2t[:, :, 0]
    ref2 = reference_attention(q2d, k2[:, :, 0], v2_cols, scale)
    out2 = c2[:, :, 0].float()
    cos2 = _cosine(out2, ref2)
    # Per-element stats
    diff2 = (out2 - ref2).abs()
    max_rel_err = (diff2 / ref2.abs().clamp(min=1e-6)).max().item()
    print(f's_k=256 in-kernel rescale: cos={cos2:.6f} max_rel_err={max_rel_err:.4f} {"PASS" if cos2 >= 0.999 else "FAIL"}', flush=True)
    # Print LSE and row_sums for debugging
    print(f' LSE range: [{lse2[:, 0, 0].min().item():.4f}, {lse2[:, 0, 0].max().item():.4f}]', flush=True)
    print(f' row_sums range: [{rs2[:, 0, 0].min().item():.4f}, {rs2[:, 0, 0].max().item():.4f}]', flush=True)

    # Test 3: Python KV merge for comparison (two s_k=128 halves)
    v2_0 = v2[:128, 0:pv_n_tile].contiguous().unsqueeze(-1)
    v2_1 = v2[128:, 0:pv_n_tile].contiguous().unsqueeze(-1)
    c_s0, lse_s0, rs_s0 = _run_fmha(kernel1, q, k2[:128], v2_0, stream)
    c_s1, lse_s1, rs_s1 = _run_fmha(kernel1, q, k2[128:], v2_1, stream)
    o0 = c_s0[:, :, 0].float()
    o1 = c_s1[:, :, 0].float()
    r0 = rs_s0[:, 0, 0].float()
    r1 = rs_s1[:, 0, 0].float()
    l0 = lse_s0[:, 0, 0].float()
    l1 = lse_s1[:, 0, 0].float()
    # Assumes row_sums holds each half's softmax denominator and LSE is its
    # natural-log log-sum-exp (the original merge formula relies on this too).
    o0_norm = o0 / r0.unsqueeze(1).clamp(min=1e-30)
    o1_norm = o1 / r1.unsqueeze(1).clamp(min=1e-30)
    # Weight each half by exp(LSE). Subtracting the larger LSE first leaves
    # the ratio unchanged but keeps exp() from overflowing.
    l_max = torch.maximum(l0, l1)
    w0 = torch.exp(l0 - l_max).unsqueeze(1)
    w1 = torch.exp(l1 - l_max).unsqueeze(1)
    oracle = (w0 * o0_norm + w1 * o1_norm) / (w0 + w1)
    # The merged oracle is softmax-normalized, so compare it with a normalized
    # reference rather than the unnormalized ref2.
    scores2 = q2d.float() @ k2[:, :, 0].float().T * scale
    ref2_norm = torch.softmax(scores2, dim=-1) @ v2_cols.float()
    cos_oracle = _cosine(oracle, ref2_norm)
    print(f'Python KV merge: cos={cos_oracle:.6f}', flush=True)
# Run the diagnostics directly when this file is executed as a script.
if __name__ == '__main__':
    test()