#!/usr/bin/env python3
"""Deep-dive: compare scale assembly byte-by-byte for expert 0.

Quantizes a small random activation batch, assembles its scale factors with
the reference ``assemble_scales_2d_side`` path and with the runner's
``_assemble_scales_cudagraph_safe`` path, and reports any byte differences in
expert 0's block. Needs a CUDA device and the repo's ``cutedsl`` / ``vllm``
packages.
"""

import itertools
import os
import sys

import torch

# Make the repository root importable before pulling in project modules.
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, REPO_ROOT)

from cutedsl.bridge import quantize_to_nvfp4, assemble_scales_2d_side  # noqa: E402
from cutedsl.kernel.moe.torch_scaled_grouped_mm import pad_and_swizzle_single, ceil_div  # noqa: E402,F401
from vllm.nvfp4_cutedsl import CuTeDSLMoERunner  # noqa: E402

DEVICE = "cuda"
num_experts = 3
hidden_size = 7168
# Width of the zero-padded scale buffers built below (7168 // 16 == 448).
sf_cols = hidden_size // 16
# Number of rows treated as one expert's block in the comparisons below.
ROW_TILE = 128

# Fixed seed so byte-level diffs are reproducible across runs.
torch.manual_seed(0)

runner = CuTeDSLMoERunner(num_experts, hidden_size, 3072, device=DEVICE)


def rand_fp4(*shape):
    """Random bytes of the given shape reinterpreted as ``float4_e2m1fn_x2``."""
    return torch.randint(0, 256, shape, dtype=torch.uint8, device=DEVICE).view(torch.float4_e2m1fn_x2)


def rand_sf(*shape):
    """Uniform [0, 1) float16 values cast to ``float8_e4m3fn``."""
    return torch.rand(shape, dtype=torch.float16, device=DEVICE).to(torch.float8_e4m3fn)


def padded_scale_tile(sf_rows):
    """Copy ``sf_rows`` into the top-left corner of a zeroed ROW_TILE x sf_cols float8 buffer."""
    buf = torch.zeros(ROW_TILE, sf_cols, dtype=torch.float16, device=DEVICE).to(torch.float8_e4m3fn)
    buf[: sf_rows.shape[0], : sf_rows.shape[1]] = sf_rows
    return buf


runner.prepare_weights_direct(
    [rand_fp4(3584, 3072 * 2) for _ in range(num_experts)],
    [rand_sf(3584 // 16, 3072 * 2) for _ in range(num_experts)],
    [0.1] * num_experts,
    [rand_fp4(1536, hidden_size) for _ in range(num_experts)],
    [rand_sf(1536 // 16, hidden_size) for _ in range(num_experts)],
    [0.1] * num_experts,
)
runner._ensure_stacked()

# 8 tokens: expert 0 gets 4, expert 1 gets 3, expert 2 gets 1.
tokens_per_expert = [4, 3, 1]
if len(tokens_per_expert) != num_experts:
    raise ValueError(
        f"tokens_per_expert has {len(tokens_per_expert)} entries, expected {num_experts}"
    )
# Start/end row of each expert's tokens: [0, 4, 7, 8]. Both assembly paths are
# driven from this single list so their expert splits cannot drift apart.
offsets = [0, *itertools.accumulate(tokens_per_expert)]
total = offsets[-1]

x = torch.randn(total, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0
x_fp4, x_sf, x_igs = quantize_to_nvfp4(x)

# Reference: assemble_scales_2d_side on per-expert row slices of the scales.
x_sf_parts = [x_sf[start:end] for start, end in zip(offsets, offsets[1:])]
ref = assemble_scales_2d_side(x_sf_parts)

# Cudagraph-safe path: the same split, passed as an int32 offsets tensor.
expert_offsets = torch.tensor(offsets, dtype=torch.int32, device=DEVICE)
cg = runner._assemble_scales_cudagraph_safe(x_sf, expert_offsets)

print(f"ref shape: {ref.shape}, cg shape: {cg.shape}")

# Compare expert 0's block (first ROW_TILE rows) byte-by-byte.
ref_e0 = ref[:ROW_TILE].view(torch.uint8)
cg_e0 = cg[:ROW_TILE].view(torch.uint8)
if ref_e0.shape != cg_e0.shape:
    # An elementwise != on mismatched shapes would raise or broadcast misleadingly.
    print(
        f"Expert 0: block shapes differ (ref={tuple(ref_e0.shape)}, "
        f"cg={tuple(cg_e0.shape)}); skipping byte diff"
    )
else:
    # Move to host once so indexing below does not sync per element.
    ref_flat = ref_e0.flatten().cpu()
    cg_flat = cg_e0.flatten().cpu()
    mismatch = ref_flat != cg_flat
    diff_e0 = int(mismatch.sum())
    print(f"Expert 0: {diff_e0}/{ref_e0.numel()} bytes differ")
    # Where do they differ? Show the first 20 mismatching byte offsets as plain ints.
    for i in torch.nonzero(mismatch).flatten()[:20].tolist():
        print(f"  byte {i}: ref={ref_flat[i].item()}, cg={cg_flat[i].item()}")

# Standalone swizzle of expert 0's scales, zero-padded to one ROW_TILE-row tile.
swizzled_buf = pad_and_swizzle_single(padded_scale_tile(x_sf_parts[0]))

# Determinism check: swizzling an independently built but identical buffer
# should give identical bytes.
swizzled_again = pad_and_swizzle_single(padded_scale_tile(x_sf_parts[0]))
same_input_equal = torch.equal(swizzled_buf.view(torch.uint8), swizzled_again.view(torch.uint8))
print(f"\nSame-input swizzle comparison: {same_input_equal}")

# Compare the standalone swizzle against the reference path's expert-0 block,
# when both hold the same number of bytes.
swz_bytes = swizzled_buf.view(torch.uint8).flatten()
ref_bytes = ref[:ROW_TILE].view(torch.uint8).flatten()
if swz_bytes.numel() == ref_bytes.numel():
    print(f"Standalone swizzle vs ref expert 0: {torch.equal(swz_bytes, ref_bytes)}")
else:
    print(
        f"Standalone swizzle vs ref expert 0: size mismatch "
        f"({swz_bytes.numel()} vs {ref_bytes.numel()} bytes)"
    )