#!/usr/bin/env python3
"""
Test B: Compare _assemble_scales_cudagraph_safe vs assemble_scales_2d_side.

Both should produce identical output given the same x_sf and expert_offsets.
If they differ, the cudagraph-safe path has a bug.

Runs on the B200 host:
    source /root/nvfp4-megamoe-kernel/tests/.venv/bin/activate
    python3 tests/test_scale_assembly.py

When run as a script the process exits with status 1 if any case fails; under
pytest, ``test_scale_assembly`` raises AssertionError on failure.
"""
import os
import sys

import torch

REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, REPO_ROOT)

# These imports must come after the sys.path tweak above.
from cutedsl.bridge import quantize_to_nvfp4, assemble_scales_2d_side  # noqa: E402
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (  # noqa: E402
    pad_and_swizzle_single,
    ceil_div,
)
from vllm.nvfp4_cutedsl import CuTeDSLMoERunner  # noqa: E402

DEVICE = "cuda"
NUM_EXPERTS = 3
HIDDEN_SIZE = 7168
INTERMEDIATE_SIZE = 3072
SEED = 0

# Per-expert token counts routed through both assembly paths. Each entry must
# have exactly NUM_EXPERTS counts. Covers empty experts at the tail, an empty
# expert in the middle, all tokens on one expert, and one token per expert.
TOKEN_DISTRIBUTIONS = [
    [2, 2, 0],
    [4, 3, 1],
    [4, 0, 0],
    [1, 1, 1],
]


def _rand_fp4(*shape):
    """Return random packed FP4 data (``torch.float4_e2m1fn_x2``) of ``shape``."""
    return torch.randint(0, 256, shape, dtype=torch.uint8, device=DEVICE).view(
        torch.float4_e2m1fn_x2
    )


def _rand_sf(*shape):
    """Return random FP8 (e4m3) scale factors drawn from [0, 1) of ``shape``."""
    return torch.rand(shape, dtype=torch.float16, device=DEVICE).to(torch.float8_e4m3fn)


def _make_runner():
    """Build a CuTeDSLMoERunner loaded with random dummy weights.

    The weights exist only so that ``_ensure_stacked`` sets up the runner state
    (including the L1 scale buffers passed to ``_assemble_scales_cudagraph_safe``
    below). The comparison itself only feeds activation scales to both paths.
    """
    runner = CuTeDSLMoERunner(NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE, device=DEVICE)

    # Row counts of the packed FP4 weight tensors (presumably K // 2 bytes,
    # two FP4 values per byte): 3584 for GEMM1, 1536 for GEMM2.
    w1_rows = HIDDEN_SIZE // 2
    w2_rows = INTERMEDIATE_SIZE // 2
    # NOTE(review): NVFP4 uses one scale per 16 FP4 *elements*. If these rows
    # are packed bytes, K = 2 * rows elements would need K // 16 scale rows,
    # not rows // 16. Shapes are kept as in the original; confirm against the
    # layout prepare_weights_direct expects.
    runner.prepare_weights_direct(
        [_rand_fp4(w1_rows, INTERMEDIATE_SIZE * 2) for _ in range(NUM_EXPERTS)],
        [_rand_sf(w1_rows // 16, INTERMEDIATE_SIZE * 2) for _ in range(NUM_EXPERTS)],
        [0.1] * NUM_EXPERTS,
        [_rand_fp4(w2_rows, HIDDEN_SIZE) for _ in range(NUM_EXPERTS)],
        [_rand_sf(w2_rows // 16, HIDDEN_SIZE) for _ in range(NUM_EXPERTS)],
        [0.1] * NUM_EXPERTS,
    )
    runner._ensure_stacked()
    return runner


def _describe(tokens_per_expert):
    """Return a human-readable label such as '4 tokens, expert 0 gets 2, ...'."""
    per_expert = ", ".join(
        f"expert {idx} gets {count}" for idx, count in enumerate(tokens_per_expert)
    )
    return f"{sum(tokens_per_expert)} tokens, {per_expert}"


def _split_per_expert(x_sf, tokens_per_expert):
    """Slice consecutive row ranges of ``x_sf``, one slice per expert."""
    parts = []
    offset = 0
    for count in tokens_per_expert:
        parts.append(x_sf[offset:offset + count])
        offset += count
    return parts


def _check_case(runner, tokens_per_expert):
    """Run one token distribution through both assembly paths.

    Args:
        runner: A runner prepared by ``_make_runner``.
        tokens_per_expert: Token count for each of the NUM_EXPERTS experts.

    Returns:
        True if both paths produce byte-identical output of the same shape.

    Raises:
        ValueError: If ``tokens_per_expert`` does not have NUM_EXPERTS entries.
    """
    if len(tokens_per_expert) != NUM_EXPERTS:
        raise ValueError(
            f"expected {NUM_EXPERTS} per-expert counts, got {len(tokens_per_expert)}"
        )
    desc = _describe(tokens_per_expert)
    total_tokens = sum(tokens_per_expert)

    # Create input and quantize; only the activation scales are compared.
    x = torch.randn(total_tokens, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) * 2.0
    _x_fp4, x_sf, _x_igs = quantize_to_nvfp4(x)

    # Path 1: assemble_scales_2d_side on explicit per-expert slices.
    scale_a_ref = assemble_scales_2d_side(_split_per_expert(x_sf, tokens_per_expert))

    # Path 2: _assemble_scales_cudagraph_safe driven by an offsets tensor
    # [0, c0, c0+c1, ...]. cumsum promotes integers to int64 by default, so
    # request int32 explicitly to match the buffer.
    expert_offsets = torch.zeros(NUM_EXPERTS + 1, dtype=torch.int32, device=DEVICE)
    expert_offsets[1:] = torch.tensor(tokens_per_expert, dtype=torch.int32).cumsum(
        0, dtype=torch.int32
    )
    scale_a_cg = runner._assemble_scales_cudagraph_safe(
        x_sf, expert_offsets, runner._padded_x_sf_buf_l1, runner._per_expert_scale_bufs_l1
    )

    # Both paths are expected to produce the same padded layout, so any shape
    # difference is reported as a failure.
    if scale_a_ref.shape != scale_a_cg.shape:
        print(f"  {desc}")
        print(f"    Shape mismatch: ref={scale_a_ref.shape}, cg={scale_a_cg.shape}")
        return False

    # Compare raw bytes rather than FP8 values. The contract is bitwise
    # identity: NaN encodings compare unequal as floats, and +0/-0 compare
    # equal as floats, yet both would differ bitwise.
    ref_bytes = scale_a_ref.view(torch.uint8)
    cg_bytes = scale_a_cg.view(torch.uint8)
    if torch.equal(ref_bytes, cg_bytes):
        print(f"  {desc}: ✅ MATCH")
        return True

    diff = (ref_bytes != cg_bytes).sum().item()
    total = ref_bytes.numel()
    pct = diff / total * 100
    print(f"  {desc}")
    print(f"    MISMATCH: {diff}/{total} bytes differ ({pct:.1f}%)")
    print(f"    ref range: [{ref_bytes.min().item()}, {ref_bytes.max().item()}]")
    print(f"    cg  range: [{cg_bytes.min().item()}, {cg_bytes.max().item()}]")
    return False


def run_scale_assembly_checks():
    """Run every distribution in TOKEN_DISTRIBUTIONS and print a summary.

    Returns:
        True if every case matched, False otherwise. All cases run even after
        a failure, so the report lists every mismatch.
    """
    # Seed so that a failing case can be reproduced exactly.
    torch.manual_seed(SEED)
    runner = _make_runner()
    results = [_check_case(runner, tpe) for tpe in TOKEN_DISTRIBUTIONS]
    all_pass = all(results)

    print(f"\n{'=' * 70}")
    if all_pass:
        print("  ALL SCALE ASSEMBLY TESTS PASSED ✅")
    else:
        print("  SCALE ASSEMBLY TESTS FAILED ❌")
    print(f"{'=' * 70}")
    return all_pass


def test_scale_assembly():
    """Compare the two scale assembly methods with realistic data.

    Raises AssertionError on any mismatch, so pytest reports the failure
    instead of silently passing on a ``False`` return value.
    """
    assert run_scale_assembly_checks(), "scale assembly mismatch (see report above)"


if __name__ == "__main__":
    # Propagate the result as the exit status so shell/CI callers see failures.
    sys.exit(0 if run_scale_assembly_checks() else 1)