"""Test: uniform FP4 + uniform SF, different from all-ones. If all E2M1 values are the same (e.g. value 3 = 1.5) and all SF=1.0, then x = 1.5 for all elements, w = 1.5 for all elements. GEMM output = (1.5^2) * K = 2.25 * 32 = 72.0 for every element. """ import torch, sys sys.path.insert(0, 'src') from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import cutlass_nvfp4_blockscaled_gemm device = "cuda" M, N, K = 1, 32, 32 # Create packed FP4 where every nibble = 3 (E2M1 value 1.5) # Packing: (nibbles[..., 1] << 4) | nibbles[..., 0] # For both nibbles = 3: byte = (3 << 4) | 3 = 0x33 byte_val = (3 << 4) | 3 # 0x33 x_fp4 = torch.full((M, K // 2), byte_val, dtype=torch.int8, device=device) w_fp4 = torch.full((K // 2, N), byte_val, dtype=torch.int8, device=device) # Uniform SF = 1.0 x_sf = torch.ones(M, K // 16, dtype=torch.float8_e4m3fn, device=device) w_sf = torch.ones(K // 16, N, dtype=torch.float8_e4m3fn, device=device) out = cutlass_nvfp4_blockscaled_gemm(x_fp4, x_sf, w_fp4, w_sf, M, N, K, alpha=1.0) # Reference: all x = 1.5, all w = 1.5, output = 1.5 * 1.5 * 32 = 72.0 print(f"NVFP4 output first 8: {out[0, :8].tolist()}") print(f"Expected: 72.0 for all elements") print(f"Actual mean: {out.float().mean().item():.4f}") print(f"All same? {torch.allclose(out, out[0,0].expand_as(out), atol=0.01)}")