#!/usr/bin/env python3
"""
CuTeDSL NVFP4 GEMM test — verify the reference kernel works with our data.

Uses NVIDIA's ScaledGroupedGemmKernel from the CUTLASS CuTeDSL examples
with NVFP4 (Float4E2M1FN + Float8E4M3FN, sf_vec_size=16).

This runs one grouped GEMM in the "2Dx3D" scenario: the rows of A(tokens, K)
are split into per-expert groups and each group is multiplied by its expert's
B[e](K, N), producing C(tokens, N). Scale factors are padded/swizzled with the
helpers from ``torch_scaled_grouped_mm``.

Exit status is 0 when the kernel output is close to the BF16 reference
(cosine similarity > 0.95) and 1 otherwise.
"""
import sys
import os
import math
import torch

# Add repo root so 'from cutedsl.kernel...' works
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, REPO_ROOT)

import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
import cutlass.utils.blockscaled_layout as blockscaled_utils

from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
    ScaledGroupedGemmKernel,
    pad_and_swizzle_single,
    assemble_raw_scales_2d3d_2d_side,
    assemble_raw_scales_2d3d_3d_side,
    cat_byte_reinterpretable_tensors,
    stack_byte_reinterpretable_tensors,
    offs_to_group_sizes,
)


# ── Helpers ────────────────────────────────────────────────────────────

# Non-negative E2M1 magnitudes; the list index is the 3-bit code (bit 3 = sign).
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def ceil_div(a, b):
    """Return ceil(a / b) for integers with positive ``b``."""
    return (a + b - 1) // b


def round_up(a, b):
    """Round ``a`` up to the next multiple of ``b``."""
    return ceil_div(a, b) * b


def _encode_e2m1(x_scaled):
    """Return the 4-bit E2M1 code (uint8, 0-15) of the nearest magnitude per element.

    Bit 3 carries the sign. Magnitudes above 6.0 saturate to 6.0. Exact ties
    resolve to the smaller magnitude because ``argmin`` returns the first
    minimum (this is not round-half-to-even).
    """
    magnitudes = torch.tensor(E2M1_MAGNITUDES, dtype=torch.float32, device=x_scaled.device)
    distances = (x_scaled.abs().unsqueeze(-1) - magnitudes).abs()
    indices = distances.argmin(dim=-1)
    return torch.where(x_scaled < 0, indices + 8, indices).to(torch.uint8)


def quantize_bf16_to_nvfp4(x_bf16, block_size=16):
    """Quantize a tensor to NVFP4 (E2M1 values + E4M3 block scales + global scale).

    Quantization runs along the last dimension. The float32 global scale maps
    the tensor amax to 6.0 * 448.0 (E2M1 max times E4M3 max), so per-block
    scales fit in float8_e4m3fn. Every ``block_size`` consecutive elements of
    the last dim share one block scale.

    Args:
        x_bf16: Input tensor (converted to float32 internally).
        block_size: Elements per block scale. Must be a positive even integer,
            because two E2M1 values are packed into each byte.

    Returns:
        (x_fp4, block_scales, global_scale) where:
          x_fp4: torch.float4_e2m1fn_x2 of shape (..., padded_last_dim // 2),
              with padded_last_dim = round_up(last_dim, block_size). Byte i
              holds element 2i in its low nibble and 2i+1 in its high nibble.
              Padding elements (if any) encode zero.
          block_scales: float8_e4m3fn of shape (..., ceil_div(last_dim, block_size)).
          global_scale: 0-dim float32 tensor.

    Raises:
        ValueError: if ``block_size`` is not a positive even integer.
    """
    if block_size <= 0 or block_size % 2 != 0:
        raise ValueError(f"block_size must be a positive even integer, got {block_size}")

    x_f32 = x_bf16.float()
    amax = x_f32.abs().max().clamp(min=1e-8).float()
    global_scale = amax / (6.0 * 448.0)
    x_norm = x_f32 / global_scale

    # Zero-pad the last dim to a whole number of blocks.
    last_dim = x_norm.shape[-1]
    n_blocks = ceil_div(last_dim, block_size)
    padded_last_dim = n_blocks * block_size
    if padded_last_dim != last_dim:
        x_norm = torch.nn.functional.pad(x_norm, (0, padded_last_dim - last_dim))

    # Per-block scale: the block amax maps to the largest E2M1 magnitude (6.0).
    x_blocks = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)
    block_amax = x_blocks.abs().amax(dim=-1).clamp(min=1e-8)
    block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)

    # Divide by the FP8-rounded scale so dequantization uses the same value.
    x_scaled = x_blocks / block_scale.float().unsqueeze(-1).clamp(min=1e-8)
    nibbles = _encode_e2m1(x_scaled)

    # Pack pairs: byte = (odd_nibble << 4) | even_nibble
    packed = (nibbles[..., 1::2] << 4) | nibbles[..., ::2]

    # Reshape as uint8 first, then reinterpret: one byte = two FP4 values.
    # The packed width covers the padded last dim (last_dim // 2 would not
    # match the element count whenever padding was applied).
    packed_shape = list(x_bf16.shape)
    packed_shape[-1] = padded_last_dim // 2
    x_fp4 = packed.reshape(packed_shape).view(torch.float4_e2m1fn_x2)

    sf_shape = list(x_bf16.shape[:-1]) + [n_blocks]
    block_scale = block_scale.reshape(sf_shape)
    return x_fp4, block_scale, global_scale


def dequantize_nvfp4(x_fp4, block_scales, global_scale):
    """Dequantize NVFP4 back to BF16 for reference comparison.

    Inverse of ``quantize_bf16_to_nvfp4``. It unpacks two E2M1 codes per byte
    (low nibble first), then multiplies by the per-block FP8 scale and the
    global scale. The output's last dim is twice the packed last dim, so it
    includes any padding added during quantization.
    """
    e2m1_lut = torch.tensor(
        list(E2M1_MAGNITUDES) + [-m for m in E2M1_MAGNITUDES],
        dtype=torch.float32,
        device=x_fp4.device,
    )
    raw = x_fp4.view(torch.uint8)
    unpacked = torch.empty(
        *raw.shape[:-1], raw.shape[-1] * 2, dtype=torch.float32, device=x_fp4.device
    )
    unpacked[..., ::2] = e2m1_lut[(raw & 0x0F).long()]
    unpacked[..., 1::2] = e2m1_lut[((raw >> 4) & 0x0F).long()]

    # Expand block scales to one value per unpacked element.
    n_blocks = block_scales.shape[-1]
    block_size = unpacked.shape[-1] // n_blocks
    block_sf = block_scales.float().unsqueeze(-1).expand(*block_scales.shape, block_size)
    block_sf = block_sf.reshape(*unpacked.shape)

    return (unpacked * block_sf * global_scale).to(torch.bfloat16)


def _to_cute_dynamic(tensor):
    """Wrap a torch tensor via DLPack as a CuTe tensor with a dynamic layout.

    The leading dim passed to ``mark_layout_dynamic`` is the one reported by
    ``cutlass_torch.get_leading_dim``.
    """
    cute_tensor = cutlass_torch.from_dlpack(tensor)
    return cute_tensor.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(tensor))


# ── Main Test ──────────────────────────────────────────────────────────

def main():
    """Quantize random BF16 data to NVFP4, run the CuTeDSL kernel, and compare.

    Returns:
        Process exit status: 0 if the cosine similarity between the kernel
        output and the BF16 reference exceeds 0.95, else 1 (including NaN).
    """
    torch.manual_seed(42)
    device = "cuda"

    # Problem sizes
    num_experts = 2
    tokens_per_expert = 64
    hidden = 256  # K dimension
    intermediate = 128  # N dimension
    sf_vec_size = 16  # elements sharing one FP8 block scale
    tokens_sum = num_experts * tokens_per_expert

    print(f"Test: {num_experts} experts, {tokens_per_expert} tokens each, K={hidden}, N={intermediate}")

    # ── Create BF16 reference data ──
    x_bf16 = torch.randn(tokens_sum, hidden, dtype=torch.bfloat16, device=device) * 2.0
    w_bf16 = torch.randn(num_experts, hidden, intermediate, dtype=torch.bfloat16, device=device) * 0.5

    # BF16 reference: for each expert, matmul its tokens with its weight
    ref_out = torch.zeros(tokens_sum, intermediate, dtype=torch.bfloat16, device=device)
    for e in range(num_experts):
        rows = slice(e * tokens_per_expert, (e + 1) * tokens_per_expert)
        ref_out[rows] = x_bf16[rows] @ w_bf16[e]
    print(f"BF16 ref: amax={ref_out.abs().max():.4f} mean={ref_out.float().mean():.6f}")

    # ── Quantize to NVFP4 ──
    x_fp4, x_sf, x_gs = quantize_bf16_to_nvfp4(x_bf16, block_size=sf_vec_size)

    # Weights must be packed and block-scaled along K (hidden). Quantizing the
    # transposed weight (intermediate, hidden) = (N, K) along its last dim does
    # exactly that and yields, per expert:
    #   w_fp4: (N, K // 2) float4_e2m1fn_x2 with K contiguous
    #   w_sf:  (N, K // sf_vec_size) float8_e4m3fn, i.e. (N, K_sf) as the
    #          reference expects for the 3D side
    #   w_gs:  0-dim float32 global scale for this expert
    w_fp4_list, w_sf_list, w_gs_list = [], [], []
    for e in range(num_experts):
        w_fp4, w_sf, w_gs = quantize_bf16_to_nvfp4(
            w_bf16[e].T.contiguous(), block_size=sf_vec_size
        )
        w_fp4_list.append(w_fp4)
        w_sf_list.append(w_sf)
        w_gs_list.append(w_gs)

    # Verify quantization roundtrip (slice off any padding added along K)
    x_deq = dequantize_nvfp4(x_fp4, x_sf, x_gs)[..., :hidden]
    cos_quant = torch.nn.functional.cosine_similarity(
        x_bf16.flatten().unsqueeze(0).float(),
        x_deq.flatten().unsqueeze(0).float(),
    ).item()
    print(f"Quantization roundtrip cosine: {cos_quant:.6f}")

    # ── Prepare CuTeDSL kernel inputs ──
    # The kernel expects:
    #   mat_a: (tokens_sum, K_packed) float4_e2m1fn_x2
    #   mat_b: (experts, K_packed, N) float4_e2m1fn_x2 — K-major
    #   scale_a: assembled 2D side (padded + swizzled)
    #   scale_b: assembled 3D side (padded + swizzled per expert)
    #   offs: (experts,) int32 cumulative token offsets
    #   global_scale_a: (experts,) float32
    #   global_scale_b: (experts,) float32

    # Expert offsets (cumulative sum of tokens per expert)
    offs = torch.tensor(
        [tokens_per_expert * (e + 1) for e in range(num_experts)],
        dtype=torch.int32,
        device=device,
    )

    # Assemble scale_a (2D side) from each expert's row group of x's block scales
    raw_scale_a = [
        x_sf[e * tokens_per_expert:(e + 1) * tokens_per_expert] for e in range(num_experts)
    ]
    scale_a = assemble_raw_scales_2d3d_2d_side(raw_scale_a)

    # Assemble scale_b (3D side) from one (N, K_sf) scale matrix per expert
    scale_b = assemble_raw_scales_2d3d_3d_side(w_sf_list)

    # Global scales: A shares one tensor-wide scale across experts; B has one per expert.
    global_scale_a = torch.stack([x_gs] * num_experts).to(torch.float32)
    global_scale_b = torch.stack(w_gs_list).to(torch.float32)

    # mat_a is (tokens_sum, K_packed) in float4_e2m1fn_x2, row-major (K-major)
    # This matches the reference: A shape=(128,128) stride=(128,1)
    mat_a = x_fp4

    # mat_b: (experts, K_packed, N) in float4_e2m1fn_x2, K-major.
    # Reference: B shape=(2,128,128) stride=(16384,1,128) — K is stride-1.
    # Stacking the per-expert (N, K_packed) tensors gives (E, N, K_packed) with
    # K_packed stride-1; swapping the last two dims yields the required view.
    # The stack is done on uint8 views so the copy does not depend on
    # float4_e2m1fn_x2 kernel support.
    mat_b = (
        torch.stack([w.view(torch.uint8) for w in w_fp4_list])
        .view(torch.float4_e2m1fn_x2)
        .permute(0, 2, 1)
    )

    print("\nKernel inputs:")
    print(f" mat_a: {mat_a.shape} {mat_a.dtype}")
    print(f" mat_b: {mat_b.shape} {mat_b.dtype}")
    print(f" scale_a: {scale_a.shape} {scale_a.dtype}")
    print(f" scale_b: {scale_b.shape} {scale_b.dtype}")
    print(f" offs: {offs.tolist()}")
    print(f" global_scale_a: {global_scale_a.tolist()}")
    print(f" global_scale_b: {[f'{v:.6e}' for v in global_scale_b.tolist()]}")

    # ── Run CuTeDSL kernel ──
    print("\nCompiling and running CuTeDSL kernel (first run takes ~1 min to compile)...")
    out = torch.zeros(tokens_sum, intermediate, dtype=torch.bfloat16, device=device)

    cluster_shape_mnk = (1, 1, 1)
    kernel = ScaledGroupedGemmKernel(
        scenario="2Dx3D",
        sf_vec_size=sf_vec_size,
        accumulate_on_output=False,
        separate_tensormap_init=True,
        consistent_token_padding=False,
        mma_tiler_mnk=(128, 128, 256),
        cluster_shape_mnk=cluster_shape_mnk,
    )

    # Convert to CuTe tensors
    a_cute = _to_cute_dynamic(mat_a)
    b_cute = _to_cute_dynamic(mat_b)
    sfa_cute = _to_cute_dynamic(scale_a)
    sfb_cute = _to_cute_dynamic(scale_b)
    c_cute = _to_cute_dynamic(out)
    offs_cute = _to_cute_dynamic(offs)

    workspace_size = kernel.get_workspace_size(num_experts)
    workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=device)
    ws_cute = _to_cute_dynamic(workspace)

    gsa_cute = _to_cute_dynamic(global_scale_a)
    gsb_cute = _to_cute_dynamic(global_scale_b)

    import cuda.bindings.driver as cuda

    # Derive the cluster size from the kernel's cluster shape so they cannot drift.
    cluster_size = math.prod(cluster_shape_mnk)
    max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    compiled = cute.compile(
        kernel,
        a_cute, b_cute, sfa_cute, sfb_cute, c_cute, offs_cute, ws_cute,
        max_active_clusters, stream,
        global_scale_a=gsa_cute, global_scale_b=gsb_cute,
    )
    compiled(
        a_cute, b_cute, sfa_cute, sfb_cute, c_cute, offs_cute, ws_cute,
        stream,
        global_scale_a=gsa_cute, global_scale_b=gsb_cute,
    )
    torch.cuda.synchronize()

    # ── Compare ──
    cosine = torch.nn.functional.cosine_similarity(
        out.flatten().unsqueeze(0).float(),
        ref_out.flatten().unsqueeze(0).float(),
    ).item()
    mse = (out.float() - ref_out.float()).pow(2).mean().item()

    print(f"\n{'='*70}")
    print(f" RESULT: cosine={cosine:.6f} MSE={mse:.6e}")
    print(f"{'='*70}")

    if cosine > 0.99:
        print(" ✅ PASS: CuTeDSL kernel matches BF16 reference")
        return 0
    if cosine > 0.95:
        print(" ⚠️ Close but not perfect — quantization loss?")
        return 0
    # Also reached when cosine is NaN (every comparison above is False).
    print(" ❌ FAIL: kernel output doesn't match reference")
    return 1


if __name__ == "__main__":
    sys.exit(main())