auto: pre-test push for test_gemm_1group.py
This commit is contained in:
56
test_gemm_1group.py
Normal file
56
test_gemm_1group.py
Normal file
@@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test: run_nvfp4_grouped_gemm with num_groups=1 on different GPUs."""
|
||||
import torch
|
||||
from dsv4.ops.gemm_runner import run_nvfp4_grouped_gemm, warmup_nvfp4_compilation
|
||||
from dsv4.ops.quantize import quantize_nvfp4_gpu
|
||||
from dsv4.ops.layouts import make_b_k_major, assemble_scales_3d_side
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Create a simple BF16 weight and quantize
|
||||
M, N, K = 1, 3072, 7168
|
||||
|
||||
for gpu in [0, 1]:
|
||||
torch.cuda.set_device(gpu)
|
||||
dev = f"cuda:{gpu}"
|
||||
|
||||
# Weight
|
||||
w = torch.randn(N, K, dtype=torch.bfloat16, device=dev)
|
||||
from dsv4.ops.quantize import quantize_weight_to_nvfp4
|
||||
w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(w)
|
||||
gsb = torch.tensor([1.0], dtype=torch.float32, device=dev) # simplified gsb
|
||||
|
||||
# K-major layout
|
||||
w_km = make_b_k_major(w_fp4.unsqueeze(0)) # (1, K_sf, N)
|
||||
w_sf_3d = assemble_scales_3d_side(w_sf.unsqueeze(0)) # (1, K_sf_padded, N)
|
||||
|
||||
# Activation
|
||||
x = torch.randn(M, K, dtype=torch.bfloat16, device=dev)
|
||||
gsa = 1.0 / (6.0 * 448.0)
|
||||
x_fp4, x_sf = quantize_nvfp4_gpu(x, gsa)
|
||||
|
||||
# Expert offsets
|
||||
padded_rows = 128
|
||||
expert_offsets = torch.tensor([padded_rows], dtype=torch.int32, device=dev)
|
||||
|
||||
# Output
|
||||
out = torch.zeros(padded_rows, N, dtype=torch.bfloat16, device=dev)
|
||||
|
||||
# Global scales
|
||||
gsa_buf = torch.tensor([gsa], dtype=torch.float32, device=dev)
|
||||
|
||||
# Run GEMM with 1 group
|
||||
run_nvfp4_grouped_gemm(
|
||||
mat_a=x_fp4[:padded_rows],
|
||||
scale_a=x_sf[:padded_rows],
|
||||
mat_b=w_km,
|
||||
scale_b=w_sf_3d,
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=gsa_buf,
|
||||
global_scale_b=gsb,
|
||||
out=out,
|
||||
num_groups=1,
|
||||
)
|
||||
|
||||
has_nan = torch.isnan(out[:M]).any().item()
|
||||
print(f"GPU {gpu}: |out|={out[:M].abs().max().item() if not has_nan else 'NaN'} has_nan={has_nan} shape={out[:M].shape}")
|
||||
Reference in New Issue
Block a user