# File: nvfp4-megamoe-kernel/tests/unit/test_d1_3_unnorm_debug.py
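"""
D1.3 SMEM-P: debug why head_dim (hd) > 64 fails.

Runs FmhaKernel at hd=64 (TMEM-P baseline and SMEM-P), hd=128 and hd=256
(SMEM-P), and compares each configuration against an FP32 oracle twice:
  * normalize=False: raw, un-normalized O plus LSE, to separate errors in the
    P write / PV product from errors in the final O normalization;
  * normalize=True: the normalized output.
"""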
import torch, math
import cutlass, cutlass.cute as cute
from cutlass import Float32, BFloat16
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def _to_cute(t):
    """Wrap torch tensor ``t`` as a CuTe tensor with a dynamic layout.

    The leading dimension is read from ``t`` itself, so one helper serves
    every operand passed to the kernel.
    """
    return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))


def test_unnormalized(hd, use_smem_p, s_k=128):
    """Test with normalize=False to get raw O + LSE, isolate the P write error.

    Builds random Q/K/V, computes an FP32 reference, then runs ``FmhaKernel``
    twice -- once with ``normalize=False`` and once with ``normalize=True`` --
    and prints similarity diagnostics for both. Nothing is asserted.

    Args:
        hd: Q/K head dimension; the V/output width is ``min(hd, 256)``.
        use_smem_p: Forwarded to ``FmhaKernel``; also selects the printed
            label ("SMEM-P" if true, "TMEM-P" otherwise).
        s_k: Number of key/value rows.
    """
    n_q = 128  # number of query rows
    pv_n = min(hd, 256)
    q = torch.randn(n_q, hd, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(s_k, pv_n, dtype=torch.bfloat16, device='cuda')
    c = torch.zeros(n_q, pv_n, 1, dtype=torch.bfloat16, device='cuda')
    # NOTE(review): the kernel runs on n_q=128 query rows but LSE gets a single
    # float32 slot. If FmhaKernel writes one LSE per row, this buffer is too
    # small -- confirm the expected mLSE shape against the kernel.
    lse = torch.zeros(1, dtype=torch.float32, device='cuda')

    # FP32 oracle.
    qf = q[:, :, 0].float()
    kf = k[:, :, 0].float()
    vf = v.float()
    scale = 1.0 / math.sqrt(hd)
    attn = qf @ kf.T * scale
    ref = torch.softmax(attn, dim=-1) @ vf  # normalized reference
    # Per-row natural-log LSE of the scaled scores, printed for comparison
    # with the kernel's LSE output.
    ref_lse = torch.logsumexp(attn, dim=-1)

    kern = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=use_smem_p, normalize=False)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    # Give V the same trailing unit dimension as Q/K/C; keep the view alive
    # in a local for the duration of the launch.
    v_tile = v.unsqueeze(-1)
    mQ = _to_cute(q)
    mK = _to_cute(k)
    mV = _to_cute(v_tile)
    mC = _to_cute(c)
    mLSE = _to_cute(lse)
    mode = "SMEM-P" if use_smem_p else "TMEM-P"
    print(f'Compiling hd={hd} {mode} normalize=False...', flush=True)
    compiled = cute.compile(kern, mQ, mK, mV, mC, stream, mLSE)
    compiled(mQ, mK, mV, mC, stream, mLSE)
    torch.cuda.synchronize()
    out = c[:, :, 0].float()
    lse_val = lse.item()

    # Scale-invariant check: cosine similarity computed per query row ignores
    # any positive per-row scaling, so it validates P @ V regardless of how
    # (or whether) the kernel normalized each row or which LSE convention it uses.
    row_cos = torch.nn.functional.cosine_similarity(out, ref, dim=-1)
    print(f' hd={hd} {mode} unnorm per-row cos: '
          f'min={row_cos.min().item():.6f} mean={row_cos.mean().item():.6f}')

    # Assumed kernel convention: O_unnorm = exp(lse) * O_norm, so
    # O_norm = O_unnorm / exp(lse). A single scalar is applied to every row,
    # while the reference normalizers generally differ per row (see ref_lse),
    # so this whole-matrix comparison is only a rough check.
    if lse_val != 0 and math.isfinite(lse_val):
        out_norm = out / math.exp(lse_val)
        cos = torch.nn.functional.cosine_similarity(
            out_norm.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
        ).item()
        max_abs = (out_norm - ref).abs().max().item()
        print(f' hd={hd} {mode} unnorm: cos={cos:.6f} max_abs={max_abs:.6f}')
        print(f' LSE={lse_val:.6f} exp(lse)={math.exp(lse_val):.6f}')
        print(f' out range: [{out.min().item():.4f}, {out.max().item():.4f}]')
        print(f' ref range: [{ref.min().item():.4f}, {ref.max().item():.4f}]')
    else:
        print(f' hd={hd} {mode} unnorm: INVALID LSE={lse_val}')
        print(f' out has NaN: {torch.isnan(out).any().item()}')
        print(f' out range: [{out.min().item():.4f}, {out.max().item():.4f}]')
    print(f' ref LSE (natural log, per row): row0={ref_lse[0].item():.6f} '
          f'range=[{ref_lse.min().item():.6f}, {ref_lse.max().item():.6f}]')

    # Also test normalized.
    kern2 = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=use_smem_p, normalize=True)
    c2 = torch.zeros(n_q, pv_n, 1, dtype=torch.bfloat16, device='cuda')
    mC2 = _to_cute(c2)
    compiled2 = cute.compile(kern2, mQ, mK, mV, mC2, stream)
    compiled2(mQ, mK, mV, mC2, stream)
    torch.cuda.synchronize()
    out2 = c2[:, :, 0].float()
    cos2 = torch.nn.functional.cosine_similarity(
        out2.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
    ).item()
    print(f' hd={hd} {mode} normalized: cos={cos2:.6f}')
    print()


# This is a parameterized debug driver (no fixtures, no asserts), not a pytest
# test. Without this flag pytest would collect it from tests/unit and fail on
# the missing 'hd' / 'use_smem_p' fixtures.
test_unnormalized.__test__ = False
if __name__ == '__main__':
    print("=== SMEM-P Debug: Unnormalized vs Normalized ===\n")
    # (head_dim, use_smem_p) pairs, run in order: the hd=64 TMEM-P baseline
    # first, then SMEM-P at hd=64, 128 and 256.
    debug_cases = (
        (64, False),
        (64, True),
        (128, True),
        (256, True),
    )
    for case_hd, case_smem_p in debug_cases:
        test_unnormalized(case_hd, use_smem_p=case_smem_p)