diff --git a/tests/debug_wo_a.py b/tests/debug_wo_a.py deleted file mode 100644 index cf64d627..00000000 --- a/tests/debug_wo_a.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Minimal debug: verify wo_a grouped matmul reference is correct.""" -import torch -import torch.nn.functional as F - -torch.cuda.set_device(0) -torch.manual_seed(42) - -# Small dimensions for debugging -G, HPG, HD, OR = 2, 4, 128, 64 -GI = HPG * HD -T = 4 -DEVICE = "cuda:0" - -o = torch.randn(T, G*HPG, HD, dtype=torch.bfloat16, device=DEVICE) * 2.0 -w = torch.randn(G*OR, GI, dtype=torch.bfloat16, device=DEVICE) * 0.1 - -# Reference: per-group matmul -o_g = o.reshape(T, G, GI) -z_ref = torch.empty(T, G, OR, dtype=torch.bfloat16, device=DEVICE) -for g in range(G): - z_ref[:, g, :] = o_g[:, g, :] @ w[g*OR:(g+1)*OR, :].T - -print(f"z_ref shape={z_ref.shape} amax={z_ref.amax():.4f}") - -# Now test the CuTeDSL runner -from cutedsl.wo_a_grouped_linear import CuTeDSLNvfp4WoA - -runner = CuTeDSLNvfp4WoA( - n_local_groups=G, heads_per_group=HPG, head_dim=HD, - o_lora_rank=OR, max_num_tokens=8, device=DEVICE, -) -runner.set_bf16_weight(w) -runner.finalize_weights() -runner._ensure_initialized() -runner.compute_activation_global_scale(o) - -with torch.no_grad(): - z_out = runner.run(o) - -cos = F.cosine_similarity(z_ref.flatten().unsqueeze(0).float(), z_out.flatten().unsqueeze(0).float()).item() -print(f"cosine={cos:.6f} amax_ref={z_ref.amax():.4f} amax_out={z_out.amax():.4f}") diff --git a/tests/debug_wo_a2.py b/tests/debug_wo_a2.py deleted file mode 100644 index 7144d0e1..00000000 --- a/tests/debug_wo_a2.py +++ /dev/null @@ -1,104 +0,0 @@ -"""Debug: compare NVFP4 grouped GEMM output element-by-element.""" -import torch -import torch.nn.functional as F -import sys, os -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from cutedsl.bridge import quantize_weight_to_nvfp4, quantize_to_nvfp4 - -torch.cuda.set_device(0) -torch.manual_seed(42) - -G, HPG, HD, OR = 2, 4, 128, 64 -GI = HPG * HD # 512 -T = 4 -DEVICE = "cuda:0" - -o = torch.randn(T, G*HPG, HD, dtype=torch.bfloat16, device=DEVICE) * 2.0 -w = torch.randn(G*OR, GI, dtype=torch.bfloat16, device=DEVICE) * 0.1 - -# Reference: per-group BF16 matmul -o_g = o.reshape(T, G, GI) -z_ref = torch.empty(T, G, OR, dtype=torch.bfloat16, device=DEVICE) -for g in range(G): - z_ref[:, g, :] = o_g[:, g, :] @ w[g*OR:(g+1)*OR, :].T - -# Test: quantize/dequantize each weight group and compare -print("=== Weight quantization test ===") -for g in range(G): - w_g = w[g*OR:(g*OR+OR), :] # (OR, GI) - w_gt = w_g.T # (GI, OR) for quantize_weight_to_nvfp4 - w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(w_gt) - - # Dequantize to BF16 for reference - E2M1_LUT = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6., -0., -0.5, -1., -1.5, -2., -3., -4., -6.], - dtype=torch.float32, device=DEVICE) - packed = w_fp4.view(torch.uint8) - lower = E2M1_LUT[(packed & 0x0F).long()] - upper = E2M1_LUT[((packed >> 4) & 0x0F).long()] - K, N = w_gt.shape - unpacked = torch.empty(K, N, dtype=torch.float32, device=DEVICE) - unpacked[:, 0::2] = lower - unpacked[:, 1::2] = upper - K_sf = w_sf.shape[0] - sf_expanded = w_sf.float().repeat_interleave(16, dim=0)[:K, :] - w_dequant = (unpacked * sf_expanded * w_gs).to(torch.bfloat16) - - # Compare - cos = F.cosine_similarity(w_gt.flatten().unsqueeze(0).float(), w_dequant.flatten().unsqueeze(0).float()).item() - print(f" Group {g}: weight quant cos={cos:.6f} w_gt amax={w_gt.amax():.4f} w_dequant amax={w_dequant.amax():.4f}") - -# Test: activation quantization -print("\n=== Activation quantization test ===") -o_flat = o_g.reshape(T * G, GI) -x_fp4, x_sf, gs = quantize_to_nvfp4(o_flat) -# Dequant -packed = x_fp4.view(torch.uint8) -E2M1_LUT = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6., -0., -0.5, -1., -1.5, -2., -3., -4., -6.], - dtype=torch.float32, device=DEVICE) -lower = E2M1_LUT[(packed & 0x0F).long()] -upper = E2M1_LUT[((packed >> 4) & 0x0F).long()] -unpacked = torch.empty(T*G, GI, dtype=torch.float32, device=DEVICE) -unpacked[:, 0::2] = lower -unpacked[:, 1::2] = upper -K_sf = x_sf.shape[1] -sf_expanded = x_sf.float().repeat_interleave(16, dim=1)[:T*G, :GI] -x_dequant = (unpacked * sf_expanded * gs).to(torch.bfloat16) -cos = F.cosine_similarity(o_flat.flatten().unsqueeze(0).float(), x_dequant.flatten().unsqueeze(0).float()).item() -print(f" Activation quant cos={cos:.6f} gs={gs:.6f}") - -# Test: the FULL pipeline — quantize weight and activation, then BF16 matmul -print("\n=== Full pipeline (quantize → dequantize → BF16 matmul) ===") -z_qdq = torch.empty(T, G, OR, dtype=torch.bfloat16, device=DEVICE) -for g in range(G): - w_g = w[g*OR:(g*OR+OR), :].T # (GI, OR) - w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(w_g) - # Dequant - packed = w_fp4.view(torch.uint8) - lower = E2M1_LUT[(packed & 0x0F).long()] - upper = E2M1_LUT[((packed >> 4) & 0x0F).long()] - K, N = w_g.shape - unpacked = torch.empty(K, N, dtype=torch.float32, device=DEVICE) - unpacked[:, 0::2] = lower - unpacked[:, 1::2] = upper - K_sf = w_sf.shape[0] - sf_expanded = w_sf.float().repeat_interleave(16, dim=0)[:K, :] - w_dequant = (unpacked * sf_expanded * w_gs).to(torch.bfloat16) - - # Quantize activation for this group - act = o_g[:, g, :] # (T, GI) - a_fp4, a_sf, a_gs = quantize_to_nvfp4(act) - packed = a_fp4.view(torch.uint8) - lower = E2M1_LUT[(packed & 0x0F).long()] - upper = E2M1_LUT[((packed >> 4) & 0x0F).long()] - unpacked = torch.empty(T, GI, dtype=torch.float32, device=DEVICE) - unpacked[:, 0::2] = lower - unpacked[:, 1::2] = upper - K_sf = a_sf.shape[1] - sf_expanded = a_sf.float().repeat_interleave(16, dim=1)[:T, :GI] - a_dequant = (unpacked * sf_expanded * a_gs).to(torch.bfloat16) - - z_qdq[:, g, :] = a_dequant @ w_dequant - -cos = F.cosine_similarity(z_ref.flatten().unsqueeze(0).float(), z_qdq.flatten().unsqueeze(0).float()).item() -print(f" QDQ vs BF16: cosine={cos:.6f}") diff --git a/tests/debug_wo_a3.py b/tests/debug_wo_a3.py deleted file mode 100644 index 8245f422..00000000 --- a/tests/debug_wo_a3.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Debug: diagnose wo_a grouped GEMM issue step by step.""" -import torch -import torch.nn.functional as F -import sys, os -sys.path.insert(0, "/root/nvfp4-megamoe-kernel") - -from cutedsl.wo_a_grouped_linear import CuTeDSLNvfp4WoA -from cutedsl.bridge import quantize_weight_to_nvfp4, quantize_to_nvfp4, quantize_activation_nvfp4 - -torch.cuda.set_device(0) -torch.manual_seed(42) - -# Small dimensions -G, HPG, HD, OR = 2, 4, 128, 64 -GI = HPG * HD # 512 -T = 4 -DEVICE = "cuda:0" - -o = torch.randn(T, G*HPG, HD, dtype=torch.bfloat16, device=DEVICE) * 2.0 -w = torch.randn(G*OR, GI, dtype=torch.bfloat16, device=DEVICE) * 0.1 - -# Reference: per-group BF16 matmul -o_g = o.reshape(T, G, GI) -z_ref = torch.empty(T, G, OR, dtype=torch.bfloat16, device=DEVICE) -for g in range(G): - z_ref[:, g, :] = o_g[:, g, :] @ w[g*OR:(g+1)*OR, :].T -print(f"z_ref amax={z_ref.amax():.4f} shape={z_ref.shape}") -print(f"z_ref[0, 0, :8] = {z_ref[0, 0, :8]}") - -# Step 1: verify weight quantization per-group -print("\n=== Weight quant ===") -for g in range(G): - w_g = w[g*OR:(g+1)*OR, :].T # (GI, OR) - w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(w_g) - print(f" Group {g}: w_g shape={w_g.shape} w_fp4 shape={w_fp4.shape} w_sf shape={w_sf.shape} gs={w_gs:.6f}") - -# Step 2: test runner directly (bypass custom op) -runner = CuTeDSLNvfp4WoA( - n_local_groups=G, heads_per_group=HPG, head_dim=HD, - o_lora_rank=OR, max_num_tokens=8, device=DEVICE, -) -runner.set_bf16_weight(w) -runner.finalize_weights() -runner._ensure_initialized() - -# Compute activation gs -with torch.no_grad(): - _, _, gs = quantize_to_nvfp4(o_g[:, 0, :]) # use first group's activation -print(f"\nActivation gs from sample: {gs:.6f}") -print(f"Runner gs: {runner._activation_global_scale:.6f}") - -runner._activation_global_scale = gs # use the right one - -# Call _run_impl directly -with torch.no_grad(): - z_out = runner._run_impl(o) -print(f"\nz_out shape={z_out.shape} amax={z_out.amax():.4f}") -print(f"z_out[0, 0, :8] = {z_out[0, 0, :8]}") - -# Per-group comparison -for g in range(G): - cos = F.cosine_similarity(z_ref[:, g, :].flatten().unsqueeze(0).float(), - z_out[:, g, :].flatten().unsqueeze(0).float()).item() - print(f" Group {g}: cosine={cos:.6f} ref_amax={z_ref[:, g, :].amax():.4f} out_amax={z_out[:, g, :].amax():.4f}") - -cos = F.cosine_similarity(z_ref.flatten().unsqueeze(0).float(), z_out.flatten().unsqueeze(0).float()).item() -print(f"\nOverall cosine={cos:.6f}")