Files
nvfp4-megamoe-kernel/tests/debug_wo_a.py
2026-05-19 02:39:55 +00:00

42 lines
1.2 KiB
Python

"""Minimal debug: verify wo_a grouped matmul reference is correct."""
import torch
import torch.nn.functional as F
torch.cuda.set_device(0)
torch.manual_seed(42)

# Small dimensions for debugging. These mirror the runner's constructor
# arguments further down in this script: G -> n_local_groups,
# HPG -> heads_per_group, HD -> head_dim, OR -> o_lora_rank.
G, HPG, HD, OR = 2, 4, 128, 64
GI = HPG * HD  # per-group input width (all heads of one group concatenated)
T = 4  # number of tokens
DEVICE = "cuda:0"

# o: (T, G*HPG, HD) bf16 activations; w: (G*OR, GI) bf16 weights, with the
# OR output rows of each group stacked along dim 0.
o = torch.randn(T, G * HPG, HD, dtype=torch.bfloat16, device=DEVICE) * 2.0
w = torch.randn(G * OR, GI, dtype=torch.bfloat16, device=DEVICE) * 0.1

# Reference: per-group matmul. Group g multiplies its GI-wide slice of o by
# rows [g*OR, (g+1)*OR) of w. Inputs are upcast and the result is stored in
# float32 so the reference itself carries no bf16 output rounding.
o_g = o.reshape(T, G, GI)
z_ref = torch.empty(T, G, OR, dtype=torch.float32, device=DEVICE)
for g in range(G):
    w_g = w[g * OR:(g + 1) * OR, :].float()
    z_ref[:, g, :] = o_g[:, g, :].float() @ w_g.T
# amax is reported as max |z| (magnitude), not the signed maximum.
print(f"z_ref shape={z_ref.shape} amax={z_ref.abs().amax():.4f}")
# Now test the CuTeDSL runner against the reference computed above.
# Imported here (not at the top) so the pure-PyTorch reference can run and be
# inspected even if the CuTeDSL package fails to import.
from cutedsl.wo_a_grouped_linear import CuTeDSLNvfp4WoA

runner = CuTeDSLNvfp4WoA(
    n_local_groups=G, heads_per_group=HPG, head_dim=HD,
    o_lora_rank=OR, max_num_tokens=8, device=DEVICE,
)
runner.set_bf16_weight(w)
runner.finalize_weights()
runner._ensure_initialized()
runner.compute_activation_global_scale(o)
with torch.no_grad():
    z_out = runner.run(o)

# Both tensors are flattened for the comparison, so a mismatched element
# count (e.g. if the output were padded to max_num_tokens rows) would otherwise
# show up only as an opaque broadcasting error. Fail loudly with both shapes.
if z_out.numel() != z_ref.numel():
    raise RuntimeError(
        f"runner output has {z_out.numel()} elements (shape {tuple(z_out.shape)}), "
        f"expected {z_ref.numel()} (shape {tuple(z_ref.shape)})"
    )
ref_flat = z_ref.float().flatten()
out_flat = z_out.float().flatten()
cos = F.cosine_similarity(ref_flat, out_flat, dim=0).item()
# Cosine similarity ignores a uniform scale error (e.g. a wrong global scale),
# so also report the largest absolute elementwise difference.
max_abs_err = (ref_flat - out_flat).abs().amax().item()
# amax values are magnitudes (max |z|), not signed maxima.
print(
    f"cosine={cos:.6f} amax_ref={z_ref.abs().amax():.4f} "
    f"amax_out={z_out.abs().amax():.4f} max_abs_err={max_abs_err:.4f}"
)