Files
nvfp4-megamoe-kernel/tests/unit/test_gemm_shape.py

63 lines
2.6 KiB
Python

#!/usr/bin/env python3
"""Verify GEMM output shape — use production weight format."""
import torch, sys
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel')
from dsv4.ops.gemm_runner import warmup_compilation, run_nvfp4_grouped_gemm
from dsv4.ops.quantize import quantize_to_nvfp4, quantize_activation_nvfp4
from dsv4.ops.layouts import (make_b_k_major, interleave_l1_weights,
pad_and_swizzle_single, ceil_div as cutedsl_ceil_div, assemble_scales_3d_side)
# ---------------------------------------------------------------------------
# Problem size
# ---------------------------------------------------------------------------
device = "cuda:0"
K = 7168
N = 6144  # gate+up
K_packed = K // 2
N_packed = N // 2
# Row count shared by the activation, its scale buffer, the reshape of the
# swizzled scales and the single expert offset further down.
num_rows = 128

# ---------------------------------------------------------------------------
# Weight, built the way the production path builds it
# ---------------------------------------------------------------------------
# Start from an (N, K) BF16 matrix and run it through quantize_to_nvfp4.
# The three return values are kept as w_fp4 / w_sf / w_gs.
# NOTE(review): presumably payload, block scale factors and global scale;
# confirm against dsv4.ops.quantize.
torch.manual_seed(42)
w_bf16 = torch.randn(N, K, dtype=torch.bfloat16, device=device) * 0.1
w_fp4, w_sf, w_gs = quantize_to_nvfp4(w_bf16)
print(f"w_fp4 shape: {tuple(w_fp4.shape)} dtype={w_fp4.dtype}")

# If the quantizer handed back raw bytes, reinterpret them as packed FP4.
if w_fp4.dtype == torch.uint8:
    w_fp4 = w_fp4.view(torch.float4_e2m1fn_x2)

# Add a leading axis of size 1, swap the last two axes, and materialise the
# result contiguously before handing it to the layout helpers.
w_batched = w_fp4.unsqueeze(0)
w_ekn = w_batched.permute(0, 2, 1).contiguous()
print(f"w_ekn shape (after permute): {tuple(w_ekn.shape)}")

# Production layout pipeline: interleave first, then make_b_k_major.
w_ekn = interleave_l1_weights(w_ekn)
mat_b = make_b_k_major(w_ekn)
print(f"mat_b shape: {tuple(mat_b.shape)} dtype={mat_b.dtype}")

# ---------------------------------------------------------------------------
# Activation
# ---------------------------------------------------------------------------
x_bf16 = torch.randn(num_rows, K, dtype=torch.bfloat16, device=device) * 0.1
# Only the third return value is kept here; it is then passed to the
# activation quantizer as its second argument.
_, _, x_gs = quantize_to_nvfp4(x_bf16)
x_fp4, x_sf = quantize_activation_nvfp4(x_bf16, x_gs)

# ---------------------------------------------------------------------------
# Warmup (runs after the activation is quantized, before scales are built)
# ---------------------------------------------------------------------------
warmup_compilation(1, K_packed, N_packed, device)

# ---------------------------------------------------------------------------
# Scales
# ---------------------------------------------------------------------------
padded_offsets = torch.tensor([num_rows], dtype=torch.int32, device=device)

# One scale column per 16 elements of K, with the column count rounded up
# to a multiple of 4.
K_sf = cutedsl_ceil_div(K, 16)
padded_cols = cutedsl_ceil_div(K_sf, 4) * 4

# Zero-filled float8_e4m3fn buffer, allocated as float16 and then converted,
# with x_sf copied into its leading columns.
scale_a_buf = torch.zeros(
    num_rows, padded_cols, dtype=torch.float16, device=device
).to(torch.float8_e4m3fn)
scale_a_buf[:num_rows, :x_sf.shape[1]] = x_sf

swizzled_a = pad_and_swizzle_single(scale_a_buf)
scale_a = swizzled_a.reshape(num_rows, padded_cols)
scale_b = assemble_scales_3d_side([w_sf])

# One-element float32 tensors holding the two global scales.
gsa = torch.full((1,), x_gs, dtype=torch.float32, device=device)
gsb = torch.full((1,), w_gs, dtype=torch.float32, device=device)

# ---------------------------------------------------------------------------
# Activation buffer
# ---------------------------------------------------------------------------
# Fill a zeroed uint8 buffer with the quantized activation bytes, then
# reinterpret that same storage as packed FP4 for the kernel.
x_padded_bytes = torch.zeros(
    num_rows, K_packed, dtype=torch.uint8, device=device
)
x_padded_bytes[:num_rows] = x_fp4.view(torch.uint8)
x_padded = x_padded_bytes.view(torch.float4_e2m1fn_x2)

# ---------------------------------------------------------------------------
# Run the grouped GEMM and report what came back
# ---------------------------------------------------------------------------
out = run_nvfp4_grouped_gemm(
    mat_a=x_padded,
    mat_b=mat_b,
    scale_a=scale_a,
    scale_b=scale_b,
    expert_offsets=padded_offsets,
    global_scale_a=gsa,
    global_scale_b=gsb,
)

print(f"\nGEMM output: shape={tuple(out.shape)} dtype={out.dtype}")
print(f"N (BF16) = {N}, N_packed = {N_packed}")
print(f"n_dim = mat_b.shape[2] = {mat_b.shape[2]}")
print(f"Output columns = {out.shape[1]}")