#!/usr/bin/env python3
"""Quick test: verify GEMM output shape for NVFP4.

Builds a random BF16 activation matrix and a BF16 weight tensor with a
leading dimension of 1, quantizes both, prepares the scale tensors, and
runs ``run_nvfp4_grouped_gemm`` once.  Intermediate and output shapes are
printed, and the script fails with ``AssertionError`` when the output is
not an ``(M, N)`` bfloat16 tensor.

Needs a CUDA device (``cuda:0``) and the project-local ``dsv4`` package.
"""
import sys

import torch

# NOTE(review): machine-specific absolute path; the dsv4 imports below
# only resolve where this checkout exists (or dsv4 is otherwise installed).
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel')

from dsv4.ops.gemm_runner import (  # noqa: E402
    warmup_compilation,
    run_nvfp4_grouped_gemm,
)
from dsv4.ops.quantize import (  # noqa: E402
    quantize_to_nvfp4,
    quantize_activation_nvfp4,
)
from dsv4.ops.layouts import (  # noqa: E402
    make_b_k_major,
    interleave_l1_weights,
    pad_and_swizzle_single,
    ceil_div as cutedsl_ceil_div,
    assemble_scales_3d_side,
)

device = "cuda:0"
M = 128   # activation rows; also the single entry of the offsets tensor
K = 7168
N = 6144  # gate+up combined
# Presumably two FP4 values share one byte (float4_e2m1fn_x2) - confirm.
K_packed = K // 2
N_packed = N // 2

# Create random data
x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) * 0.1
w_bf16 = torch.randn(1, K, N, dtype=torch.bfloat16, device=device) * 0.1

# Quantize.  The first call is only used for its third return value,
# which is then passed to the activation quantizer.
_, _, x_gs = quantize_to_nvfp4(x_bf16)
x_fp4, x_sf = quantize_activation_nvfp4(x_bf16, x_gs)

# Weights are quantized from the (1, N, K) transposed layout.
w_bf16_t = w_bf16.permute(0, 2, 1).contiguous()
w_fp4, w_sf, w_gs = quantize_to_nvfp4(w_bf16_t)
if w_fp4.dtype == torch.uint8:
    # Reinterpret the raw bytes as the packed FP4 dtype (no data change).
    w_fp4 = w_fp4.view(torch.float4_e2m1fn_x2)
w_fp4_il = interleave_l1_weights(w_fp4)
mat_b = make_b_k_major(w_fp4_il)

print(f"x_fp4 shape: {tuple(x_fp4.shape)} dtype={x_fp4.dtype}")
print(f"mat_b shape: {tuple(mat_b.shape)} dtype={mat_b.dtype}")
print(f"N_packed from mat_b.shape[2]: {mat_b.shape[2]}")

# Run GEMM
warmup_compilation(1, K_packed, N_packed, device)
padded_offsets = torch.tensor([M], dtype=torch.int32, device=device)

# Activation scales: ceil(K / 16) columns, rounded up to a multiple of 4,
# copied into a zero-filled float8_e4m3fn buffer before swizzling.
K_sf = cutedsl_ceil_div(K, 16)
padded_cols = cutedsl_ceil_div(K_sf, 4) * 4
scale_a_buf = torch.zeros(
    M, padded_cols, dtype=torch.float16, device=device
).to(torch.float8_e4m3fn)
scale_a_buf[:M, :x_sf.shape[1]] = x_sf
scale_a = pad_and_swizzle_single(scale_a_buf).reshape(M, padded_cols)
scale_b = assemble_scales_3d_side(w_sf)

# Global scales as one-element float32 tensors.
gsa = torch.full((1,), x_gs, dtype=torch.float32, device=device)
gsb = torch.full((1,), w_gs, dtype=torch.float32, device=device)

# Copy the quantized activations byte-wise into a zero-initialised buffer
# that carries the packed FP4 dtype.
x_padded = torch.zeros(
    M, K_packed, dtype=torch.uint8, device=device
).view(torch.float4_e2m1fn_x2)
x_padded.view(torch.uint8)[:M] = x_fp4.view(torch.uint8)

out = run_nvfp4_grouped_gemm(
    mat_a=x_padded, mat_b=mat_b,
    scale_a=scale_a, scale_b=scale_b,
    expert_offsets=padded_offsets,
    global_scale_a=gsa, global_scale_b=gsb,
)

expected_shape = (M, N)
expected_dtype = torch.bfloat16
actual_shape = tuple(out.shape)

print(f"\nGEMM output shape: {actual_shape} dtype={out.dtype}")
print(f"Expected: ({M}, {N}) BF16")
# Report what was actually produced instead of echoing the expectation.
print(f"Got: {actual_shape} dtype={out.dtype}")

# Fail only after every diagnostic line has been printed.
if actual_shape != expected_shape or out.dtype != expected_dtype:
    raise AssertionError(
        f"GEMM output mismatch: expected {expected_shape} {expected_dtype}, "
        f"got {actual_shape} {out.dtype}"
    )