"""Minimal debug: verify the wo_a grouped matmul reference is correct.

Builds a small random problem, computes a per-group reference

    z[:, g, :] = o_g[:, g, :] @ w[g*OR:(g+1)*OR, :].T,
    where o_g = o.reshape(T, G, HPG * HD),

and compares it with ``CuTeDSLNvfp4WoA.run`` using cosine similarity and
absolute-max magnitudes. Running the script needs a CUDA device and the
project-local ``cutedsl`` package; importing this module only needs torch.
"""

import torch
import torch.nn.functional as F

# Small dimensions for debugging (names mirror the runner's constructor args).
G, HPG, HD, OR = 2, 4, 128, 64  # n_local_groups, heads_per_group, head_dim, o_lora_rank
GI = HPG * HD  # per-group input width after flattening heads
T = 4  # number of tokens in the test batch
MAX_NUM_TOKENS = 8  # passed to the runner as max_num_tokens
DEVICE = "cuda:0"
SEED = 42


def reference_grouped_linear(o, w, n_groups, out_rank):
    """Return the per-group linear projection of ``o`` by ``w``.

    Args:
        o: Activations whose first dim is tokens and whose remaining dims
            flatten to ``n_groups * w.shape[1]`` elements per token
            (here ``(T, G * HPG, HD)``).
        w: Stacked weights of shape ``(n_groups * out_rank, group_in)``;
            rows ``g*out_rank:(g+1)*out_rank`` belong to group ``g``.
        n_groups: Number of groups.
        out_rank: Output width per group.

    Returns:
        Tensor of shape ``(tokens, n_groups, out_rank)`` in the inputs' dtype.

    Raises:
        ValueError: if the shapes of ``o`` and ``w`` are inconsistent with
            ``n_groups`` / ``out_rank``.
    """
    if w.dim() != 2 or w.shape[0] != n_groups * out_rank:
        raise ValueError(
            f"w must be 2-D with {n_groups * out_rank} rows, got shape {tuple(w.shape)}"
        )
    group_in = w.shape[1]
    n_tokens = o.shape[0]
    if o.shape[1:].numel() != n_groups * group_in:
        raise ValueError(
            f"o has {o.shape[1:].numel()} elements per token, "
            f"expected n_groups * group_in = {n_groups * group_in}"
        )
    o_g = o.reshape(n_tokens, n_groups, group_in)
    w_g = w.reshape(n_groups, out_rank, group_in)
    # Batched over groups: z[t, g, r] = sum_i o_g[t, g, i] * w_g[g, r, i]
    return torch.einsum("tgi,goi->tgo", o_g, w_g)


def abs_max(t):
    """Return ``max(|t|)`` as a Python float (the conventional "amax")."""
    return t.abs().amax().item()


def flat_cosine(a, b):
    """Return the cosine similarity of ``a`` and ``b`` viewed as flat vectors.

    Both tensors are flattened in their own element order and compared in
    float32.

    Raises:
        ValueError: if the tensors have different element counts.
    """
    if a.numel() != b.numel():
        raise ValueError(
            f"cannot compare tensors with {a.numel()} and {b.numel()} elements"
        )
    return F.cosine_similarity(
        a.reshape(1, -1).float(), b.reshape(1, -1).float()
    ).item()


def main():
    """Run the reference-vs-CuTeDSL comparison and return the cosine similarity."""
    torch.cuda.set_device(torch.device(DEVICE))
    torch.manual_seed(SEED)

    o = torch.randn(T, G * HPG, HD, dtype=torch.bfloat16, device=DEVICE) * 2.0
    w = torch.randn(G * OR, GI, dtype=torch.bfloat16, device=DEVICE) * 0.1

    # Reference: per-group matmul
    z_ref = reference_grouped_linear(o, w, G, OR)
    print(f"z_ref shape={z_ref.shape} absmax={abs_max(z_ref):.4f}")

    # Project-local import, deferred so the helpers above can be imported
    # without the cutedsl package being available.
    from cutedsl.wo_a_grouped_linear import CuTeDSLNvfp4WoA

    runner = CuTeDSLNvfp4WoA(
        n_local_groups=G,
        heads_per_group=HPG,
        head_dim=HD,
        o_lora_rank=OR,
        max_num_tokens=MAX_NUM_TOKENS,
        device=DEVICE,
    )
    runner.set_bf16_weight(w)
    runner.finalize_weights()
    runner._ensure_initialized()  # NOTE: private runner method, called explicitly here
    runner.compute_activation_global_scale(o)
    with torch.no_grad():
        z_out = runner.run(o)

    # NOTE(review): the flat comparison assumes run() lays out its output in the
    # same (token, group, rank) order as z_ref — TODO confirm; an output with
    # equal element count but a different layout is not rejected here.
    print(f"z_out shape={z_out.shape} finite={bool(torch.isfinite(z_out).all())}")
    cos = flat_cosine(z_ref, z_out)
    print(
        f"cosine={cos:.6f} absmax_ref={abs_max(z_ref):.4f} "
        f"absmax_out={abs_max(z_out):.4f}"
    )
    return cos


if __name__ == "__main__":
    main()