2026-05-17 07:33:20 +00:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
"""
|
|
|
|
|
Test B: Compare _assemble_scales_cudagraph_safe vs assemble_scales_2d_side.
|
|
|
|
|
|
|
|
|
|
Both should produce identical output given the same x_sf and expert_offsets.
|
|
|
|
|
If they differ, the cudagraph-safe path has a bug.
|
|
|
|
|
|
|
|
|
|
Runs on the B200 host:
|
|
|
|
|
source /root/nvfp4-megamoe-kernel/tests/.venv/bin/activate
|
|
|
|
|
python3 tests/test_scale_assembly.py
|
|
|
|
|
"""
|
|
|
|
|
import os, sys, torch
|
|
|
|
|
|
|
|
|
|
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
|
sys.path.insert(0, REPO_ROOT)
|
|
|
|
|
|
|
|
|
|
from cutedsl.bridge import quantize_to_nvfp4, assemble_scales_2d_side
|
|
|
|
|
from cutedsl.kernel.moe.torch_scaled_grouped_mm import pad_and_swizzle_single, ceil_div
|
|
|
|
|
from vllm.nvfp4_cutedsl import CuTeDSLMoERunner
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_scale_assembly():
    """Compare the two scale assembly methods with realistic data.

    For several token-per-expert distributions, a random activation is
    quantized with ``quantize_to_nvfp4`` and its scale factors are assembled
    twice: once with ``assemble_scales_2d_side`` on per-expert slices (the
    reference) and once with ``runner._assemble_scales_cudagraph_safe`` on the
    full tensor plus cumulative expert offsets. The two results must have the
    same shape and be exactly equal.

    A line is printed per case, followed by an overall summary banner.

    Returns:
        bool: True when every case matched, False if any case had a shape
        mismatch or differing contents. The result is returned rather than
        asserted, so callers must check it.
    """
    DEVICE = "cuda"
    num_experts = 3
    hidden_size = 7168
    intermediate_size = 3072

    # Create a runner just to use its _assemble_scales_cudagraph_safe
    runner = CuTeDSLMoERunner(num_experts, hidden_size, intermediate_size, device=DEVICE)

    # Trigger _ensure_stacked and buffer allocation with dummy weights
    def rand_fp4(*shape):
        # Random bytes reinterpreted as packed FP4 values.
        return torch.randint(0, 256, shape, dtype=torch.uint8, device=DEVICE).view(torch.float4_e2m1fn_x2)

    def rand_sf(*shape):
        # Random values in [0, 1) converted to FP8 (e4m3) scale factors.
        return torch.rand(shape, dtype=torch.float16, device=DEVICE).to(torch.float8_e4m3fn)

    # NOTE(review): 3584 and 1536 equal hidden_size // 2 and
    # intermediate_size // 2 — presumably the packed-FP4 dimensions; confirm
    # against prepare_weights_direct before deriving them from the sizes above.
    runner.prepare_weights_direct(
        [rand_fp4(3584, intermediate_size * 2) for _ in range(num_experts)],
        [rand_sf(3584 // 16, intermediate_size * 2) for _ in range(num_experts)],
        [0.1] * num_experts,
        [rand_fp4(1536, hidden_size) for _ in range(num_experts)],
        [rand_sf(1536 // 16, hidden_size) for _ in range(num_experts)],
        [0.1] * num_experts,
    )
    runner._ensure_stacked()

    # Test with different token distributions
    test_cases = [
        ("4 tokens, expert 0 gets 2, expert 1 gets 2, expert 2 gets 0", [2, 2, 0]),
        ("8 tokens, expert 0 gets 4, expert 1 gets 3, expert 2 gets 1", [4, 3, 1]),
        ("4 tokens, expert 0 gets 4, expert 1 gets 0, expert 2 gets 0", [4, 0, 0]),
        ("3 tokens, expert 0 gets 1, expert 1 gets 1, expert 2 gets 1", [1, 1, 1]),
    ]

    all_pass = True
    for desc, tokens_per_expert in test_cases:
        total_tokens = sum(tokens_per_expert)

        # Create input and quantize; only the scale factors are needed here.
        x = torch.randn(total_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0
        _, x_sf, _ = quantize_to_nvfp4(x)

        # Path 1: assemble_scales_2d_side (per-expert split)
        x_sf_parts = []
        offset = 0
        for tpe in tokens_per_expert:
            x_sf_parts.append(x_sf[offset:offset + tpe])
            offset += tpe
        scale_a_ref = assemble_scales_2d_side(x_sf_parts)

        # Path 2: _assemble_scales_cudagraph_safe (GPU-only)
        # expert_offsets[i] is the index of the first token of expert i;
        # the final entry is the total token count.
        expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32, device=DEVICE)
        expert_offsets[1:] = torch.tensor(tokens_per_expert, dtype=torch.int32).cumsum(0)

        scale_a_cudagraph = runner._assemble_scales_cudagraph_safe(
            x_sf, expert_offsets,
            runner._padded_x_sf_buf_l1, runner._per_expert_scale_bufs_l1
        )

        # Compare
        # Both paths are expected to produce identically shaped outputs; a
        # shape mismatch is reported as a failure and the contents are not
        # compared for that case.
        if scale_a_ref.shape != scale_a_cudagraph.shape:
            print(f" {desc}")
            print(f" Shape mismatch: ref={scale_a_ref.shape}, cg={scale_a_cudagraph.shape}")
            all_pass = False
            continue

        match = torch.equal(scale_a_ref, scale_a_cudagraph)
        if not match:
            # Check how many bytes differ. Both the count and the total are
            # taken from the uint8 views so the percentage is really in bytes,
            # whatever the element size of the scale dtype is.
            ref_bytes = scale_a_ref.view(torch.uint8)
            cg_bytes = scale_a_cudagraph.view(torch.uint8)
            diff = (ref_bytes != cg_bytes).sum().item()
            total = ref_bytes.numel()
            pct = diff / total * 100
            print(f" {desc}")
            print(f" MISMATCH: {diff}/{total} bytes differ ({pct:.1f}%)")
            # .item() so plain integers are printed instead of tensor reprs.
            print(f" ref range: [{ref_bytes.min().item()}, {ref_bytes.max().item()}]")
            print(f" cg range: [{cg_bytes.min().item()}, {cg_bytes.max().item()}]")
            all_pass = False
        else:
            print(f" {desc}: ✅ MATCH")

    print(f"\n{'=' * 70}")
    if all_pass:
        print(" ALL SCALE ASSEMBLY TESTS PASSED ✅")
    else:
        print(" SCALE ASSEMBLY TESTS FAILED ❌")
    print(f"{'=' * 70}")
    return all_pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate the outcome as the process exit status so shells and CI can
    # detect a failed comparison: 0 when every case matched, 1 otherwise.
    sys.exit(0 if test_scale_assembly() else 1)
|