Files
nvfp4-megamoe-kernel/tests/test_scale_assembly.py

112 lines
4.5 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""
Test B: Compare _assemble_scales_cudagraph_safe vs assemble_scales_2d_side.
Both should produce identical output given the same x_sf and expert_offsets.
If they differ, the cudagraph-safe path has a bug.
To run on the B200 host:
source /root/nvfp4-megamoe-kernel/tests/.venv/bin/activate
python3 tests/test_scale_assembly.py
"""
import os, sys, torch
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, REPO_ROOT)
from cutedsl.bridge import quantize_to_nvfp4, assemble_scales_2d_side
from cutedsl.kernel.moe.torch_scaled_grouped_mm import pad_and_swizzle_single, ceil_div
from vllm.nvfp4_cutedsl import CuTeDSLMoERunner
def _make_scale_test_runner(num_experts, hidden_size, intermediate_size, device):
    """Build a CuTeDSLMoERunner loaded with random dummy weights.

    The runner is only needed for its ``_assemble_scales_cudagraph_safe``
    method and the L1 scale buffers it owns; the weight values are random.
    """
    runner = CuTeDSLMoERunner(num_experts, hidden_size, intermediate_size, device=device)

    def rand_fp4(*shape):
        # Random bytes reinterpreted as float4_e2m1fn_x2 elements.
        return torch.randint(0, 256, shape, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2)

    def rand_sf(*shape):
        # Random float16 values in [0, 1) cast to float8_e4m3fn.
        return torch.rand(shape, dtype=torch.float16, device=device).to(torch.float8_e4m3fn)

    # NOTE(review): for the sizes used in this test, 3584 == hidden_size // 2 and
    # 1536 == intermediate_size // 2 (presumably packed-FP4 K dimensions).
    # The literals are kept as-is; confirm their meaning before changing sizes.
    w13_rows = 3584
    w2_rows = 1536
    runner.prepare_weights_direct(
        [rand_fp4(w13_rows, intermediate_size * 2) for _ in range(num_experts)],
        [rand_sf(w13_rows // 16, intermediate_size * 2) for _ in range(num_experts)],
        [0.1] * num_experts,
        [rand_fp4(w2_rows, hidden_size) for _ in range(num_experts)],
        [rand_sf(w2_rows // 16, hidden_size) for _ in range(num_experts)],
        [0.1] * num_experts,
    )
    # Trigger _ensure_stacked and buffer allocation with the dummy weights.
    runner._ensure_stacked()
    return runner


def _expert_row_bounds(tokens_per_expert):
    """Return cumulative row offsets ``[0, n0, n0 + n1, ...]`` for the token counts."""
    bounds = [0]
    for count in tokens_per_expert:
        bounds.append(bounds[-1] + count)
    return bounds


def _compare_scale_tensors(desc, scale_a_ref, scale_a_cudagraph):
    """Print a report for one case; return True iff both tensors are byte-identical.

    Shapes must match exactly: a shape mismatch is reported and counted as a
    failure rather than comparing only a common prefix.
    """
    if scale_a_ref.shape != scale_a_cudagraph.shape:
        print(f" {desc}")
        print(f" Shape mismatch: ref={scale_a_ref.shape}, cg={scale_a_cudagraph.shape}")
        return False
    # Compare raw bytes: "identical output" means identical bit patterns. This
    # avoids value-equality on float8 dtypes, where NaN != NaN and some PyTorch
    # versions lack comparison kernels.
    # Assumes 1-byte scale dtypes (FP8/uint8), as the original diff path did.
    ref_bytes = scale_a_ref.view(torch.uint8)
    cg_bytes = scale_a_cudagraph.view(torch.uint8)
    if torch.equal(ref_bytes, cg_bytes):
        print(f" {desc}: ✅ MATCH")
        return True
    # Report how many bytes differ and the value ranges, to aid debugging.
    diff = (ref_bytes != cg_bytes).sum().item()
    total = ref_bytes.numel()
    pct = diff / total * 100
    print(f" {desc}")
    print(f" MISMATCH: {diff}/{total} bytes differ ({pct:.1f}%)")
    print(f" ref range: [{ref_bytes.min().item()}, {ref_bytes.max().item()}]")
    print(f" cg range: [{cg_bytes.min().item()}, {cg_bytes.max().item()}]")
    return False


def test_scale_assembly():
    """Compare the two scale assembly methods with realistic data.

    For several per-expert token distributions, quantize random activations
    with ``quantize_to_nvfp4`` and assemble the activation scales two ways:

    * ``assemble_scales_2d_side`` on per-expert slices of ``x_sf``
      (reference path);
    * ``runner._assemble_scales_cudagraph_safe`` on the whole ``x_sf`` plus a
      device-side ``expert_offsets`` tensor (CUDA-graph-safe path).

    Returns:
        True when every case matches. The bool is kept for script callers
        that use it as a status flag.

    Raises:
        AssertionError: if any case mismatches, so pytest reports a failure
            instead of passing on a falsy return value.
    """
    DEVICE = "cuda"
    num_experts = 3
    hidden_size = 7168
    intermediate_size = 3072
    # Fixed seed so any mismatch is reproducible.
    torch.manual_seed(0)
    runner = _make_scale_test_runner(num_experts, hidden_size, intermediate_size, DEVICE)

    # Test with different token distributions.
    test_cases = [
        ("4 tokens, expert 0 gets 2, expert 1 gets 2, expert 2 gets 0", [2, 2, 0]),
        ("8 tokens, expert 0 gets 4, expert 1 gets 3, expert 2 gets 1", [4, 3, 1]),
        ("4 tokens, expert 0 gets 4, expert 1 gets 0, expert 2 gets 0", [4, 0, 0]),
        ("3 tokens, expert 0 gets 1, expert 1 gets 1, expert 2 gets 1", [1, 1, 1]),
    ]
    failed = []
    for desc, tokens_per_expert in test_cases:
        total_tokens = sum(tokens_per_expert)
        x = torch.randn(total_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0
        _, x_sf, _ = quantize_to_nvfp4(x)
        # One set of row boundaries drives both paths, so they see the same split.
        bounds = _expert_row_bounds(tokens_per_expert)

        # Path 1: assemble_scales_2d_side (per-expert split).
        x_sf_parts = [x_sf[lo:hi] for lo, hi in zip(bounds, bounds[1:])]
        scale_a_ref = assemble_scales_2d_side(x_sf_parts)

        # Path 2: _assemble_scales_cudagraph_safe (GPU-only offsets).
        expert_offsets = torch.tensor(bounds, dtype=torch.int32, device=DEVICE)
        scale_a_cudagraph = runner._assemble_scales_cudagraph_safe(
            x_sf, expert_offsets,
            runner._padded_x_sf_buf_l1, runner._per_expert_scale_bufs_l1,
        )

        if not _compare_scale_tensors(desc, scale_a_ref, scale_a_cudagraph):
            failed.append(desc)

    all_pass = not failed
    print(f"\n{'=' * 70}")
    if all_pass:
        print(" ALL SCALE ASSEMBLY TESTS PASSED ✅")
    else:
        print(" SCALE ASSEMBLY TESTS FAILED ❌")
    print(f"{'=' * 70}")
    assert all_pass, f"scale assembly mismatch in {len(failed)} case(s): {failed}"
    return all_pass
if __name__ == "__main__":
    # Use the check's result as the process exit status, so shell scripts and CI
    # see a non-zero code on failure. A bare call would exit 0 even when
    # test_scale_assembly() reports failure through a falsy return value.
    sys.exit(0 if test_scale_assembly() else 1)