test: verify GEMM output shape
This commit is contained in:
66
tests/unit/test_gemm_shape.py
Normal file
66
tests/unit/test_gemm_shape.py
Normal file
@@ -0,0 +1,66 @@
|
||||
#!/usr/bin/env python3
"""Quick test: verify GEMM output shape for NVFP4.

Builds one random activation / weight pair, quantizes both with the
project's NVFP4 helpers, runs ``run_nvfp4_grouped_gemm`` with a single
entry in ``expert_offsets`` and reports the output shape and dtype next to
the expected ``(M, N)`` bfloat16 result.

The script executes at import time (there is no ``main`` guard) and needs a
CUDA device plus the ``dsv4`` package reachable through the path inserted
below.
"""
import sys

import torch

# NOTE(review): absolute, machine-specific path; the script only works where
# the kernel checkout lives at exactly this location.
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel')

from dsv4.ops.gemm_runner import (
    warmup_compilation, run_nvfp4_grouped_gemm,
)
from dsv4.ops.quantize import quantize_to_nvfp4, quantize_activation_nvfp4
from dsv4.ops.layouts import (
    make_b_k_major, interleave_l1_weights,
    pad_and_swizzle_single, ceil_div as cutedsl_ceil_div,
    assemble_scales_3d_side,
)

device = "cuda:0"
M = 128   # activation rows; also the single value placed in expert_offsets
K = 7168
N = 6144  # gate+up combined
# The activation buffer below is allocated as uint8 with K_packed columns and
# viewed as float4_e2m1fn_x2 -- presumably two FP4 values per byte.
K_packed = K // 2
# NOTE(review): N_packed is only passed to warmup_compilation; confirm that
# the runner really expects N halved here.
N_packed = N // 2

# Create random data
x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) * 0.1
w_bf16 = torch.randn(1, K, N, dtype=torch.bfloat16, device=device) * 0.1

# Quantize
# The first call is used only for its third return value (x_gs), which is then
# fed to the activation quantizer.
_, _, x_gs = quantize_to_nvfp4(x_bf16)
x_fp4, x_sf = quantize_activation_nvfp4(x_bf16, x_gs)
# Weights are quantized from the (1, N, K) permutation of w_bf16.
w_bf16_t = w_bf16.permute(0, 2, 1).contiguous()
w_fp4, w_sf, w_gs = quantize_to_nvfp4(w_bf16_t)
if w_fp4.dtype == torch.uint8:
    # Reinterpret the raw bytes as the packed FP4 dtype (no data conversion).
    w_fp4 = w_fp4.view(torch.float4_e2m1fn_x2)
w_fp4_il = interleave_l1_weights(w_fp4)
mat_b = make_b_k_major(w_fp4_il)

print(f"x_fp4 shape: {tuple(x_fp4.shape)} dtype={x_fp4.dtype}")
print(f"mat_b shape: {tuple(mat_b.shape)} dtype={mat_b.dtype}")
print(f"N_packed from mat_b.shape[2]: {mat_b.shape[2]}")

# Run GEMM
warmup_compilation(1, K_packed, N_packed, device)
padded_offsets = torch.tensor([M], dtype=torch.int32, device=device)

# Scale-factor column count: K divided by 16, then rounded up to a multiple
# of 4 (assuming ceil_div is ceiling division, as its name suggests).
K_sf = cutedsl_ceil_div(K, 16)
padded_cols = cutedsl_ceil_div(K_sf, 4) * 4
# Zero buffer is created as float16 and then converted to float8_e4m3fn.
scale_a_buf = torch.zeros(
    M, padded_cols, dtype=torch.float16, device=device
).to(torch.float8_e4m3fn)
scale_a_buf[:M, :x_sf.shape[1]] = x_sf
scale_a = pad_and_swizzle_single(scale_a_buf).reshape(M, padded_cols)
scale_b = assemble_scales_3d_side(w_sf)

# One-element float32 tensors holding the global scales returned by the
# quantizer.
gsa = torch.full((1,), x_gs, dtype=torch.float32, device=device)
gsb = torch.full((1,), w_gs, dtype=torch.float32, device=device)

# Copy the quantized activations byte-for-byte into a zero-initialised buffer
# that carries the packed FP4 dtype.
x_padded = torch.zeros(
    M, K_packed, dtype=torch.uint8, device=device
).view(torch.float4_e2m1fn_x2)
x_padded.view(torch.uint8)[:M] = x_fp4.view(torch.uint8)

out = run_nvfp4_grouped_gemm(
    mat_a=x_padded, mat_b=mat_b,
    scale_a=scale_a, scale_b=scale_b,
    expert_offsets=padded_offsets,
    global_scale_a=gsa, global_scale_b=gsb,
)

# Report what was actually produced. The previous "Got" line hard-coded the
# row count and "BF16", so a wrong leading dimension, rank or dtype would have
# been printed as if it matched.
actual_shape = tuple(out.shape)
print(f"\nGEMM output shape: {actual_shape} dtype={out.dtype}")
print(f"Expected: ({M}, {N}) BF16")
print(f"Got: {actual_shape} {out.dtype}")

# Explicit verdict so a mismatch is not left to eyeballing two lines. This
# only prints (does not raise) to keep import-time behaviour unchanged.
matches = actual_shape == (M, N) and out.dtype == torch.bfloat16
print(f"Result: {'MATCH' if matches else 'MISMATCH'}")
|
||||
Reference in New Issue
Block a user