"""Test: Validate that cute.exp works on register tensors in the fused epilogue.

Step 1 of the fused SwiGLU validation. The goal of this step is to confirm that:

1. cute.exp works on register tensors.
2. The element-wise SiLU (x / (1 + exp(-x))) produces correct values.
3. The register tensor can be converted to BF16 and stored to C.

What this file currently does: it runs the standard (unfused) L1 grouped GEMM,
applies SiLU to the BF16 L1 output in PyTorch to build the reference, and
checks that the reference is finite and agrees with the explicit SiLU formula.
Running the fused kernel (fused_swiglu=True with the full SiLU applied in
registers, no gate/up pairing yet) and comparing it against this reference is
the next step and is NOT performed here yet.
"""

import itertools
import sys

import torch

# Make the project-local `dsv4` package importable; this path is specific to
# the development machine this test was written on.
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel')

from dsv4.ops.quantize import (  # noqa: E402
    quantize_weight_to_nvfp4,
    quantize_activation_nvfp4,
)
from dsv4.ops.layouts import (  # noqa: E402
    make_b_k_major,
    assemble_scales_2d_side,
    assemble_scales_3d_side,
)
from dsv4.ops.gemm_runner import (  # noqa: E402
    run_nvfp4_grouped_gemm,
    warmup_compilation,
)


def test_silu_in_registers() -> None:
    """Build and sanity-check the PyTorch SiLU reference for the L1 GEMM output.

    Quantizes random BF16 activations and per-expert L1 (gate+up fused)
    weights to NVFP4, runs the standard grouped GEMM, and applies
    ``torch.nn.functional.silu`` to the BF16 result. The fused-kernel
    comparison against this reference is a later step.

    Raises:
        AssertionError: if the GEMM output contains non-finite values, or if
            ``silu`` disagrees with the explicit formula ``x / (1 + exp(-x))``
            beyond BF16 rounding tolerance.
    """
    device = "cuda"
    num_experts = 4
    hidden = 512
    intermediate = 256
    num_tokens = 32

    torch.manual_seed(42)
    x = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device=device)

    # Create and quantize L1 weights (gate+up fused along dim 1).
    l1_w = torch.randn(
        num_experts, 2 * intermediate, hidden, dtype=torch.bfloat16, device=device
    )
    l1_fp4_list, l1_sf_list, l1_gs_list = [], [], []
    for e in range(num_experts):
        w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(l1_w[e].T)
        l1_fp4_list.append(w_fp4)
        l1_sf_list.append(w_sf)
        l1_gs_list.append(w_gs)

    l1_mat_b = make_b_k_major(torch.stack(l1_fp4_list))
    l1_scale_b = assemble_scales_3d_side(l1_sf_list)
    l1_gs = torch.tensor(l1_gs_list, dtype=torch.float32, device=device)

    # Activation global scale. The 6.0 * 448.0 divisor presumably corresponds
    # to FP4 (E2M1) max times FP8 (E4M3) block-scale max -- confirm against
    # quantize_activation_nvfp4.
    gs_val = x.abs().max().item() / (6.0 * 448.0)
    x_fp4, x_sf = quantize_activation_nvfp4(x, gs_val)

    # Split tokens across experts. Any remainder goes to the first experts, so
    # the last offset always equals num_tokens and no row is silently dropped.
    base, extra = divmod(num_tokens, num_experts)
    tokens_per_expert = [base + (1 if e < extra else 0) for e in range(num_experts)]
    # Cumulative end row of each expert's token range. These are also the
    # offsets passed to the grouped GEMM below.
    expert_ends = list(itertools.accumulate(tokens_per_expert))
    expert_starts = [0, *expert_ends[:-1]]
    # Slice by cumulative start/end rather than index * count; the latter is
    # only correct when every expert has the same token count. Assumes x_sf's
    # first dim indexes tokens, as the per-row slicing here requires.
    scale_a = assemble_scales_2d_side(
        [x_sf[start:end] for start, end in zip(expert_starts, expert_ends)]
    )
    expert_offsets = torch.tensor(expert_ends, dtype=torch.int32, device=device)
    global_scale_a = torch.full(
        (num_experts,), gs_val, dtype=torch.float32, device=device
    )

    # Warmup standard GEMM. The // 2 on the K/N arguments presumably reflects
    # packed FP4 (two values per byte) -- confirm against warmup_compilation.
    warmup_compilation(num_experts, hidden // 2, (2 * intermediate) // 2, device)

    # Run standard L1 GEMM (no SiLU)
    out_bf16 = run_nvfp4_grouped_gemm(
        mat_a=x_fp4,
        mat_b=l1_mat_b,
        scale_a=scale_a,
        scale_b=l1_scale_b,
        expert_offsets=expert_offsets,
        global_scale_a=global_scale_a,
        global_scale_b=l1_gs,
    )

    # Apply SiLU in PyTorch (reference)
    silu_ref = torch.nn.functional.silu(out_bf16)

    print(f"L1 BF16 output shape: {out_bf16.shape}")
    print(f"SiLU reference shape: {silu_ref.shape}")
    print(f"L1 output amax: {out_bf16.abs().amax().item():.4f}")
    print(f"SiLU reference amax: {silu_ref.abs().amax().item():.4f}")
    print()

    # Sanity checks so a pytest run can actually fail: the GEMM must produce
    # finite values, and torch's silu must agree with the explicit formula
    # x / (1 + exp(-x)) listed in the module docstring, evaluated in FP32.
    # For very negative x, exp(-x) overflows to inf and the formula yields -0,
    # which is the correct limit.
    assert torch.isfinite(out_bf16).all(), "L1 GEMM output contains non-finite values"
    out_fp32 = out_bf16.float()
    silu_formula = out_fp32 / (1.0 + torch.exp(-out_fp32))
    # silu_ref is rounded to BF16, so allow BF16-level relative error.
    torch.testing.assert_close(silu_ref.float(), silu_formula, rtol=1.6e-2, atol=1e-5)

    print("Step 1 validation: SiLU in PyTorch on BF16 GEMM output")
    print("Next step: Run fused kernel with SiLU in registers and compare")
    print()
    print("NOTE: The fused kernel with SiLU on the full acc_vec should produce")
    print("the same result as torch.nn.functional.silu on the BF16 output,")
    print("within NVFP4 quantization tolerance (~5e-2).")
    print()
    print("This test validates the SiLU math. The gate/up pairing (Step 2)")
    print("will change which values get SiLU applied (gate only, not up).")


if __name__ == "__main__":
    test_silu_in_registers()