#!/usr/bin/env python3
"""Test: run_nvfp4_grouped_gemm with num_groups=1 on different GPUs.

Runs the same single-group NVFP4 grouped GEMM on up to two CUDA devices and
prints, per device, the largest absolute output value, whether the output
contains NaNs, and the shape of the inspected output slice.
"""
# Provenance: reconstructed from the whitespace-collapsed patch
#   commit 793f062bbcf6d9ad3ca004bf7de330923e38bdd8
#   author biondizzle, dated Mon, 1 Jun 2026 04:32:29 +0000
#   subject "[PATCH] auto: pre-test push for test_gemm_1group.py"
# which created this file as test_gemm_1group.py.
import torch
from dsv4.ops.gemm_runner import run_nvfp4_grouped_gemm, warmup_nvfp4_compilation
from dsv4.ops.quantize import quantize_nvfp4_gpu, quantize_weight_to_nvfp4
from dsv4.ops.layouts import make_b_k_major, assemble_scales_3d_side

# Problem size: M activation rows, N output features, K reduction dimension.
M, N, K = 1, 3072, 7168

# Number of rows declared for the single group (and allocated for the output).
PADDED_ROWS = 128

# The test compares behaviour across the first two devices at most.
MAX_GPUS = 2


def _run_on_gpu(gpu):
    """Run the 1-group GEMM on CUDA device ``gpu`` and print a summary line."""
    torch.cuda.set_device(gpu)
    dev = f"cuda:{gpu}"

    # Weight: random BF16 matrix, quantized by the project helper.
    w = torch.randn(N, K, dtype=torch.bfloat16, device=dev)
    # The weight's own global scale is deliberately unused; the original
    # script substitutes a unit scale ("simplified gsb").
    w_fp4, w_sf, _w_gs = quantize_weight_to_nvfp4(w)
    gsb = torch.tensor([1.0], dtype=torch.float32, device=dev)

    # K-major layout with a leading group dimension of 1.
    w_km = make_b_k_major(w_fp4.unsqueeze(0))  # (1, K_sf, N) per original note
    w_sf_3d = assemble_scales_3d_side(w_sf.unsqueeze(0))  # (1, K_sf_padded, N) per original note

    # Activation
    x = torch.randn(M, K, dtype=torch.bfloat16, device=dev)
    # NOTE(review): presumably 1 / (FP4 max 6.0 * FP8-E4M3 max 448.0) -- confirm.
    gsa = 1.0 / (6.0 * 448.0)
    x_fp4, x_sf = quantize_nvfp4_gpu(x, gsa)

    # Expert offsets: a single entry for the single group.
    expert_offsets = torch.tensor([PADDED_ROWS], dtype=torch.int32, device=dev)

    # Output buffer, zero-initialised.
    out = torch.zeros(PADDED_ROWS, N, dtype=torch.bfloat16, device=dev)

    # Global scales
    gsa_buf = torch.tensor([gsa], dtype=torch.float32, device=dev)

    # NOTE(review): x has only M rows, so the [:PADDED_ROWS] slices below trim
    # nothing unless quantize_nvfp4_gpu pads its outputs. Verify that mat_a and
    # scale_a really cover the PADDED_ROWS rows declared in expert_offsets.
    run_nvfp4_grouped_gemm(
        mat_a=x_fp4[:PADDED_ROWS],
        scale_a=x_sf[:PADDED_ROWS],
        mat_b=w_km,
        scale_b=w_sf_3d,
        expert_offsets=expert_offsets,
        global_scale_a=gsa_buf,
        global_scale_b=gsb,
        out=out,
        num_groups=1,
    )

    # Only the first M rows correspond to real activations.
    head = out[:M]
    has_nan = torch.isnan(head).any().item()
    peak = "NaN" if has_nan else head.abs().max().item()
    print(f"GPU {gpu}: |out|={peak} has_nan={has_nan} shape={head.shape}")


def main():
    """Run the test on the first ``MAX_GPUS`` visible CUDA devices."""
    torch.manual_seed(42)

    # Do not assume two devices exist: set_device() raises on an invalid index.
    available = torch.cuda.device_count()
    if available == 0:
        raise SystemExit("No CUDA devices available; nothing to test.")
    if available < MAX_GPUS:
        print(f"Only {available} CUDA device(s) visible; testing those.")

    for gpu in range(min(MAX_GPUS, available)):
        _run_on_gpu(gpu)


if __name__ == "__main__":
    main()