From 8dadd9a7238e6798eefe2a2f66904645947c35e7 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 17 May 2026 07:37:47 +0000 Subject: [PATCH] test: scale assembly debug --- tests/test_scale_debug.py | 72 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 tests/test_scale_debug.py diff --git a/tests/test_scale_debug.py b/tests/test_scale_debug.py new file mode 100644 index 00000000..05738d58 --- /dev/null +++ b/tests/test_scale_debug.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +"""Deep-dive: compare scale assembly byte-by-byte for expert 0.""" +import os, sys, torch +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, REPO_ROOT) +from cutedsl.bridge import quantize_to_nvfp4, assemble_scales_2d_side +from cutedsl.kernel.moe.torch_scaled_grouped_mm import pad_and_swizzle_single, ceil_div +from vllm.nvfp4_cutedsl import CuTeDSLMoERunner + +DEVICE = "cuda" +num_experts = 3 +hidden_size = 7168 + +runner = CuTeDSLMoERunner(num_experts, hidden_size, 3072, device=DEVICE) +def rand_fp4(*shape): + return torch.randint(0, 256, shape, dtype=torch.uint8, device=DEVICE).view(torch.float4_e2m1fn_x2) +def rand_sf(*shape): + return torch.rand(shape, dtype=torch.float16, device=DEVICE).to(torch.float8_e4m3fn) +runner.prepare_weights_direct( + [rand_fp4(3584, 3072*2) for _ in range(num_experts)], + [rand_sf(3584//16, 3072*2) for _ in range(num_experts)], + [0.1]*num_experts, + [rand_fp4(1536, hidden_size) for _ in range(num_experts)], + [rand_sf(1536//16, hidden_size) for _ in range(num_experts)], + [0.1]*num_experts, +) +runner._ensure_stacked() + +# 8 tokens, expert 0 gets 4, expert 1 gets 3, expert 2 gets 1 +tokens_per_expert = [4, 3, 1] +total = sum(tokens_per_expert) +x = torch.randn(total, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0 +x_fp4, x_sf, x_igs = quantize_to_nvfp4(x) + +# Reference: assemble_scales_2d_side +x_sf_parts = [x_sf[0:4], x_sf[4:7], x_sf[7:8]] +ref = assemble_scales_2d_side(x_sf_parts) + +# Cudagraph-safe +expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32, device=DEVICE) +expert_offsets[1:] = torch.tensor(tokens_per_expert, dtype=torch.int32).cumsum(0) +cg = runner._assemble_scales_cudagraph_safe(x_sf, expert_offsets) + +print(f"ref shape: {ref.shape}, cg shape: {cg.shape}") + +# Compare expert 0's block (first 128 rows) +ref_e0 = ref[:128].view(torch.uint8) +cg_e0 = cg[:128].view(torch.uint8) +diff_e0 = (ref_e0 != cg_e0).sum().item() +print(f"Expert 0: {diff_e0}/{ref_e0.numel()} bytes differ") + +# Where do they differ? +if diff_e0 > 0: + diff_idx = torch.where(ref_e0.flatten() != cg_e0.flatten())[0] + for i in diff_idx[:20]: + print(f" byte {i}: ref={ref_e0.flatten()[i].item()}, cg={cg_e0.flatten()[i].item()}") + +# Also test: does pad_and_swizzle_single on the SAME input give the same output? +buf = torch.zeros(128, 448, dtype=torch.float16, device=DEVICE).to(torch.float8_e4m3fn) +buf[:4, :x_sf.shape[1]] = x_sf[0:4] +swizzled_buf = pad_and_swizzle_single(buf) + +# Now compare swizzled_buf vs the reference path's expert 0 swizzled block +# Reference path swizzles x_sf[0:4] padded to 128 rows +buf2 = torch.zeros(4, 448, dtype=torch.float16, device=DEVICE).to(torch.float8_e4m3fn) +buf2[:4, :x_sf.shape[1]] = x_sf[0:4] +# Wait, assemble_scales_2d_side pads to 128 rows first +buf3 = torch.zeros(128, 448, dtype=torch.float16, device=DEVICE).to(torch.float8_e4m3fn) +buf3[:4, :x_sf.shape[1]] = x_sf[0:4] +swizzled_ref = pad_and_swizzle_single(buf3) + +print(f"\nSame-input swizzle comparison: {torch.equal(swizzled_buf.view(torch.uint8), swizzled_ref.view(torch.uint8))}")