test: verify GEMM output shape

This commit is contained in:
2026-06-02 08:41:22 +00:00
parent f01d3f3eac
commit 40fb49d670

View File

@@ -0,0 +1,66 @@
#!/usr/bin/env python3
"""Quick test: verify GEMM output shape for NVFP4."""
import torch
import sys
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel')
from dsv4.ops.gemm_runner import (
warmup_compilation, run_nvfp4_grouped_gemm,
)
from dsv4.ops.quantize import quantize_to_nvfp4, quantize_activation_nvfp4
from dsv4.ops.layouts import (
make_b_k_major, interleave_l1_weights,
pad_and_swizzle_single, ceil_div as cutedsl_ceil_div,
assemble_scales_3d_side,
)
device = "cuda:0"
K = 7168
N = 6144 # gate+up combined
K_packed = K // 2
N_packed = N // 2
# Create random data
x_bf16 = torch.randn(128, K, dtype=torch.bfloat16, device=device) * 0.1
w_bf16 = torch.randn(1, K, N, dtype=torch.bfloat16, device=device) * 0.1
# Quantize
_, _, x_gs = quantize_to_nvfp4(x_bf16)
x_fp4, x_sf = quantize_activation_nvfp4(x_bf16, x_gs)
w_bf16_t = w_bf16.permute(0, 2, 1).contiguous()
w_fp4, w_sf, w_gs = quantize_to_nvfp4(w_bf16_t)
if w_fp4.dtype == torch.uint8:
w_fp4 = w_fp4.view(torch.float4_e2m1fn_x2)
w_fp4_il = interleave_l1_weights(w_fp4)
mat_b = make_b_k_major(w_fp4_il)
print(f"x_fp4 shape: {tuple(x_fp4.shape)} dtype={x_fp4.dtype}")
print(f"mat_b shape: {tuple(mat_b.shape)} dtype={mat_b.dtype}")
print(f"N_packed from mat_b.shape[2]: {mat_b.shape[2]}")
# Run GEMM
warmup_compilation(1, K_packed, N_packed, device)
padded_offsets = torch.tensor([128], dtype=torch.int32, device=device)
K_sf = cutedsl_ceil_div(K, 16)
padded_cols = cutedsl_ceil_div(K_sf, 4) * 4
scale_a_buf = torch.zeros(128, padded_cols, dtype=torch.float16, device=device).to(torch.float8_e4m3fn)
scale_a_buf[:128, :x_sf.shape[1]] = x_sf
scale_a = pad_and_swizzle_single(scale_a_buf).reshape(128, padded_cols)
scale_b = assemble_scales_3d_side(w_sf)
gsa = torch.full((1,), x_gs, dtype=torch.float32, device=device)
gsb = torch.full((1,), w_gs, dtype=torch.float32, device=device)
x_padded = torch.zeros(128, K_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2)
x_padded.view(torch.uint8)[:128] = x_fp4.view(torch.uint8)
out = run_nvfp4_grouped_gemm(
mat_a=x_padded, mat_b=mat_b,
scale_a=scale_a, scale_b=scale_b,
expert_offsets=padded_offsets,
global_scale_a=gsa, global_scale_b=gsb,
)
print(f"\nGEMM output shape: {tuple(out.shape)} dtype={out.dtype}")
print(f"Expected: (128, {N}) BF16")
print(f"Got: (128, {out.shape[1]}) BF16")