#!/usr/bin/env python3
"""Verify GEMM output shape — use production weight format.

Flat diagnostic script: quantizes a random weight and a random activation,
runs one NVFP4 grouped GEMM through the project's runner with a single
expert group, and prints the intermediate and output shapes so they can be
compared by eye. Nothing is asserted; the printed shapes are the result.

Requires a CUDA device and the project-local ``dsv4`` package.
"""
import sys

import torch

# Make the project-local kernel package importable before the dsv4 imports.
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel')

from dsv4.ops.gemm_runner import warmup_compilation, run_nvfp4_grouped_gemm
from dsv4.ops.quantize import quantize_to_nvfp4, quantize_activation_nvfp4
from dsv4.ops.layouts import (
    make_b_k_major,
    interleave_l1_weights,
    pad_and_swizzle_single,
    ceil_div as cutedsl_ceil_div,
    assemble_scales_3d_side,
)

device = "cuda:0"

M = 128  # activation rows; also used as the single expert offset below
K = 7168
N = 6144  # gate+up
K_packed = K // 2
N_packed = N // 2

# Create weight in PRODUCTION format:
# (N, K) BF16 → quantize → (N_packed, K_packed) float4
# NOTE(review): the packed shapes in these comments are the original author's
# annotations; the shapes printed below are the ground truth — confirm.
torch.manual_seed(42)
w_bf16 = torch.randn(N, K, dtype=torch.bfloat16, device=device) * 0.1
w_fp4, w_sf, w_gs = quantize_to_nvfp4(w_bf16)  # (N_packed, K_packed) float4
print(f"w_fp4 shape: {tuple(w_fp4.shape)} dtype={w_fp4.dtype}")

# Production path:
# (N_packed, K_packed) → (1, K_packed, N_packed) → interleave → make_b_k_major
if w_fp4.dtype == torch.uint8:
    # Reinterpret raw bytes as the packed FP4 dtype without copying.
    w_fp4 = w_fp4.view(torch.float4_e2m1fn_x2)
# Add a leading dim of size 1 and swap the last two dims.
w_ekn = w_fp4.unsqueeze(0).permute(0, 2, 1).contiguous()  # (1, K_packed, N_packed)
print(f"w_ekn shape (after permute): {tuple(w_ekn.shape)}")
w_ekn = interleave_l1_weights(w_ekn)
mat_b = make_b_k_major(w_ekn)
print(f"mat_b shape: {tuple(mat_b.shape)} dtype={mat_b.dtype}")

# Activation
x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) * 0.1
# Only the third return value (used as the global scale) is kept here; the
# activation itself is quantized by the dedicated activation routine.
_, _, x_gs = quantize_to_nvfp4(x_bf16)
x_fp4, x_sf = quantize_activation_nvfp4(x_bf16, x_gs)

# Warmup
warmup_compilation(1, K_packed, N_packed, device)

# Scales
padded_offsets = torch.tensor([M], dtype=torch.int32, device=device)
K_sf = cutedsl_ceil_div(K, 16)
padded_cols = cutedsl_ceil_div(K_sf, 4) * 4  # round K_sf up to a multiple of 4
# Zero-filled FP8 buffer; the scale values are copied into its leading columns.
scale_a_buf = torch.zeros(
    M, padded_cols, dtype=torch.float16, device=device
).to(torch.float8_e4m3fn)
scale_a_buf[:M, :x_sf.shape[1]] = x_sf
scale_a = pad_and_swizzle_single(scale_a_buf).reshape(M, padded_cols)
scale_b = assemble_scales_3d_side([w_sf])
# NOTE(review): x_gs / w_gs are passed as torch.full fill values — confirm the
# quantizer returns something torch.full accepts as a scalar.
gsa = torch.full((1,), x_gs, dtype=torch.float32, device=device)
gsb = torch.full((1,), w_gs, dtype=torch.float32, device=device)

# Pad activation
# Allocate as uint8 and copy through uint8 views of both tensors.
x_padded = torch.zeros(
    M, K_packed, dtype=torch.uint8, device=device
).view(torch.float4_e2m1fn_x2)
x_padded.view(torch.uint8)[:M] = x_fp4.view(torch.uint8)

out = run_nvfp4_grouped_gemm(
    mat_a=x_padded,
    mat_b=mat_b,
    scale_a=scale_a,
    scale_b=scale_b,
    expert_offsets=padded_offsets,
    global_scale_a=gsa,
    global_scale_b=gsb,
)

print(f"\nGEMM output: shape={tuple(out.shape)} dtype={out.dtype}")
print(f"N (BF16) = {N}, N_packed = {N_packed}")
print(f"n_dim = mat_b.shape[2] = {mat_b.shape[2]}")
print(f"Output columns = {out.shape[1]}")