# Listing metadata captured with this file: 68 lines, 2.4 KiB, Python.
"""Debug: diagnose wo_a grouped GEMM issue step by step."""

import os
import sys

import torch
import torch.nn.functional as F

# The ``cutedsl`` package is resolved from a source checkout, so its root must
# be at the front of sys.path before the imports below. The default location
# matches the original setup; set NVFP4_MEGAMOE_KERNEL_ROOT to run elsewhere.
sys.path.insert(0, os.environ.get("NVFP4_MEGAMOE_KERNEL_ROOT", "/root/nvfp4-megamoe-kernel"))

from cutedsl.wo_a_grouped_linear import CuTeDSLNvfp4WoA
from cutedsl.bridge import quantize_weight_to_nvfp4, quantize_to_nvfp4, quantize_activation_nvfp4
# Single source of truth for the target GPU: the CUDA current device and every
# tensor allocated below use the same device string.
DEVICE = "cuda:0"
torch.cuda.set_device(DEVICE)
torch.manual_seed(42)

# Small dimensions (names match the CuTeDSLNvfp4WoA constructor arguments):
#   G   -> n_local_groups, HPG -> heads_per_group,
#   HD  -> head_dim,       OR  -> o_lora_rank
G, HPG, HD, OR = 2, 4, 128, 64
# Per-group input width: the HPG heads of one group laid side by side.
GI = HPG * HD  # 512
# Number of token rows in the activation tensor.
T = 4
# Inputs: activations o are (T, G*HPG, HD); the weight w stacks the G
# per-group matrices along dim 0 as (G*OR, GI), so rows [g*OR, (g+1)*OR)
# belong to group g.
o = torch.randn(T, G * HPG, HD, dtype=torch.bfloat16, device=DEVICE) * 2.0
w = torch.randn(G * OR, GI, dtype=torch.bfloat16, device=DEVICE) * 0.1

# Reference: per-group matmul z[:, g, :] = o_g[:, g, :] @ W_g^T.
# Operands are upcast and the result kept in FP32: a BF16 GEMM may use
# reduced-precision reductions and rounds its output, which would add error
# to the baseline the NVFP4 kernel is judged against.
o_g = o.reshape(T, G, GI)
z_ref = torch.empty(T, G, OR, dtype=torch.float32, device=DEVICE)
for g in range(G):
    w_grp = w[g * OR:(g + 1) * OR, :]  # (OR, GI) slice for group g
    z_ref[:, g, :] = o_g[:, g, :].float() @ w_grp.float().T
# "amax" means the maximum absolute value; plain .amax() would report the
# signed maximum, so take abs() first.
print(f"z_ref amax={z_ref.abs().amax():.4f} shape={z_ref.shape}")
print(f"z_ref[0, 0, :8] = {z_ref[0, 0, :8]}")
# Step 1: verify weight quantization per-group.
# Quantize each group's weight slice on its own and print the resulting
# tensor shapes and global scale for inspection.
print("\n=== Weight quant ===")
for g in range(G):
    # Group g's (OR, GI) rows, transposed to (GI, OR).
    # NOTE(review): assumes quantize_weight_to_nvfp4 expects the (GI, OR)
    # orientation; .T is a non-contiguous view -- confirm the quantizer
    # honours strides (or pass .contiguous()).
    w_g = w[g * OR:(g + 1) * OR, :].T
    w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(w_g)
    print(f" Group {g}: w_g shape={w_g.shape} w_fp4 shape={w_fp4.shape} w_sf shape={w_sf.shape} gs={w_gs:.6f}")
# Step 2: test runner directly (bypass custom op).
# Build the grouped wo_a runner with the same dimensions as the reference.
runner = CuTeDSLNvfp4WoA(
    n_local_groups=G, heads_per_group=HPG, head_dim=HD,
    o_lora_rank=OR, max_num_tokens=8, device=DEVICE,
)
# Hand over the full stacked (G*OR, GI) BF16 weight, then run the runner's
# own finalization/initialization steps before calling it directly below.
runner.set_bf16_weight(w)
runner.finalize_weights()
runner._ensure_initialized()
# Compute the activation global scale.
# The runner exposes one scalar _activation_global_scale (presumably shared by
# every group), so calibrate it on the activations of all G groups -- rows of
# width GI -- rather than on group 0 alone. A scale fitted to one group may not
# cover larger-magnitude values in another group and would confound the
# per-group comparison below.
with torch.no_grad():
    _, _, gs = quantize_to_nvfp4(o_g.reshape(-1, GI))
print(f"\nActivation gs from sample: {gs:.6f}")
print(f"Runner gs: {runner._activation_global_scale:.6f}")

# Replace the runner's own scale with the calibrated one so the comparison
# isolates the grouped GEMM from scale calibration.
runner._activation_global_scale = gs
# Call _run_impl directly on the un-reshaped (T, G*HPG, HD) activations.
with torch.no_grad():
    z_out = runner._run_impl(o)
# "amax" = maximum absolute value (abs() first; plain .amax() is signed).
print(f"\nz_out shape={z_out.shape} amax={z_out.abs().amax():.4f}")
# The per-group comparison indexes z_out exactly like z_ref, i.e. (T, G, OR);
# report a layout mismatch explicitly instead of failing on an opaque index.
if z_out.shape != z_ref.shape:
    raise RuntimeError(
        f"z_out shape {tuple(z_out.shape)} does not match z_ref shape "
        f"{tuple(z_ref.shape)}; expected (T, G, OR)"
    )
print(f"z_out[0, 0, :8] = {z_out[0, 0, :8]}")
# Per-group comparison of the kernel output against the reference.


def _cosine(a, b):
    """Return the cosine similarity of two tensors, both flattened, as a float.

    Both sides are upcast to FP32 and treated as a single row vector each.
    """
    return F.cosine_similarity(a.flatten().unsqueeze(0).float(),
                               b.flatten().unsqueeze(0).float()).item()


for g in range(G):
    ref_g, out_g = z_ref[:, g, :], z_out[:, g, :]
    cos = _cosine(ref_g, out_g)
    # amax = maximum absolute value of each side.
    print(f" Group {g}: cosine={cos:.6f} ref_amax={ref_g.abs().amax():.4f} out_amax={out_g.abs().amax():.4f}")

cos = _cosine(z_ref, z_out)
print(f"\nOverall cosine={cos:.6f}")