Rewrite pipeline test: compare runner vs reference with real weights, step-by-step

This commit is contained in:
2026-05-17 18:08:33 +00:00
parent e38d60a6e8
commit e51eafe288

View File

@@ -1,26 +1,17 @@
"""Test #2: End-to-end single-layer test with real model weights.
"""Pipeline Test: Compare CuTeDSL runner vs reference with real model weights.
Loads layer 0 from the DeepSeek-V4-Pro-NVFP4 checkpoint, runs one MoE layer
through our CuTeDSL runner, and compares against the reference moe_pipeline
(which uses the same NVFP4 weights but with dynamic gs).
This catches issues that the small layertest (3 experts, 8 tokens) misses:
- Scale assembly with 48 experts × 8 chunks
- Uneven expert assignment
- Real activation magnitudes
- swiglu_limit clamping
- Variable padded expert offsets at scale
Loads layer 0 from DeepSeek-V4-Pro-NVFP4, runs both the reference
moe_pipeline and our CuTeDSLMoERunner, compares output step by step.
"""
import torch
import sys
import os
import glob
import math
# Add paths
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)) + '/../cutedsl')
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)) + '/..')
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)) + '/../vllm')
from cutedsl.moe_pipeline import moe_pipeline
from vllm.nvfp4_cutedsl import CuTeDSLMoERunner
# ============================================================
@@ -28,7 +19,7 @@ from vllm.nvfp4_cutedsl import CuTeDSLMoERunner
# ============================================================
MODEL_PATH = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
LAYER_IDX = 0
NUM_EXPERTS = 48 # local experts per EP rank
NUM_EXPERTS = 48
HIDDEN_SIZE = 7168
INTERMEDIATE_SIZE = 18432
NUM_TOKENS = 64
@@ -40,10 +31,6 @@ def load_expert_weights(layer_idx, num_experts):
"""Load NVFP4 weights for one layer from the checkpoint."""
from safetensors import safe_open
# Find the layer shard file
shard_dir = os.path.join(MODEL_PATH, f"model-0000{layer_idx+1:02d}-of-00010.safetensors")
# Try to find the right shard
import glob
shards = sorted(glob.glob(os.path.join(MODEL_PATH, "*.safetensors")))
l1_fp4 = []
@@ -56,19 +43,14 @@ def load_expert_weights(layer_idx, num_experts):
for shard_path in shards:
with safe_open(shard_path, framework="pt", device="cpu") as f:
for e in range(num_experts):
global_e = e # For rank 0, local = global
w13_key = f"model.layers.{layer_idx}.mlp.experts.{e}.w13_weight"
sf13_key = f"model.layers.{layer_idx}.mlp.experts.{e}.w13_weight_scale"
gs13_key = f"model.layers.{layer_idx}.mlp.experts.{e}.w13_weight_scale_2"
w2_key = f"model.layers.{layer_idx}.mlp.experts.{e}.w2_weight"
sf2_key = f"model.layers.{layer_idx}.mlp.experts.{e}.w2_weight_scale"
gs2_key = f"model.layers.{layer_idx}.mlp.experts.{e}.w2_weight_scale_2"
# L1 (gate+up)
w13_key = f"model.layers.{layer_idx}.mlp.experts.{global_e}.w13_weight"
sf13_key = f"model.layers.{layer_idx}.mlp.experts.{global_e}.w13_weight_scale"
gs13_key = f"model.layers.{layer_idx}.mlp.experts.{global_e}.w13_weight_scale_2"
# L2 (down)
w2_key = f"model.layers.{layer_idx}.mlp.experts.{global_e}.w2_weight"
sf2_key = f"model.layers.{layer_idx}.mlp.experts.{global_e}.w2_weight_scale"
gs2_key = f"model.layers.{layer_idx}.mlp.experts.{global_e}.w2_weight_scale_2"
if w13_key in f.keys():
if w13_key in f.keys() and len(l1_fp4) <= e:
l1_fp4.append(f.get_tensor(w13_key).to(DEVICE))
l1_sf.append(f.get_tensor(sf13_key).to(DEVICE))
l1_gs.append(f.get_tensor(gs13_key).to(DEVICE))
@@ -76,11 +58,11 @@ def load_expert_weights(layer_idx, num_experts):
l2_sf.append(f.get_tensor(sf2_key).to(DEVICE))
l2_gs.append(f.get_tensor(gs2_key).to(DEVICE))
if len(l1_fp4) == num_experts:
if len(l1_fp4) >= num_experts:
break
if len(l1_fp4) != num_experts:
raise RuntimeError(f"Only loaded {len(l1_fp4)}/{num_experts} experts from checkpoint")
raise RuntimeError(f"Only loaded {len(l1_fp4)}/{num_experts} experts")
return {
'l1_fp4': l1_fp4, 'l1_sf': l1_sf, 'l1_gs': l1_gs,
@@ -88,28 +70,133 @@ def load_expert_weights(layer_idx, num_experts):
}
def run_reference(hidden_states, topk_weights, topk_ids, weights, swiglu_limit=None):
"""Reference MoE: per-expert processing with dynamic gs (quantize_to_nvfp4)."""
from cutedsl.quantize import quantize_to_nvfp4
from cutedsl.gemm import run_nvfp4_grouped_gemm
num_tokens = hidden_states.shape[0]
top_k = topk_ids.shape[1]
num_experts = len(weights['l1_fp4'])
# Sort tokens by expert
flat_ids = topk_ids.reshape(-1).cpu().numpy()
flat_weights = topk_weights.reshape(-1)
token_indices = torch.arange(num_tokens).unsqueeze(1).expand(-1, top_k).reshape(-1)
sort_idx = torch.argsort(topk_ids.reshape(-1), stable=True)
sorted_ids = topk_ids.reshape(-1)[sort_idx]
sorted_weights = topk_weights.reshape(-1)[sort_idx]
sorted_token_ids = token_indices[sort_idx]
# Compute expert offsets
expert_id_range = torch.arange(num_experts)
tokens_per_expert = (sorted_ids.unsqueeze(1).cpu() == expert_id_range.unsqueeze(0)).sum(dim=0)
expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32)
for e in range(num_experts):
expert_offsets[e + 1] = expert_offsets[e] + tokens_per_expert[e].item()
num_slots = num_tokens * top_k
slot_hidden = hidden_states[sorted_token_ids]
# Stack weights for GEMM
l1_mat_b, l1_scale_b, l1_gsb = _stack_weights(weights['l1_fp4'], weights['l1_sf'], weights['l1_gs'])
l2_mat_b, l2_scale_b, l2_gsb = _stack_weights(weights['l2_fp4'], weights['l2_sf'], weights['l2_gs'])
# L1 with dynamic gs
x_fp4, x_sf, gs_val = quantize_to_nvfp4(slot_hidden)
l1_gsa = torch.full((num_experts,), gs_val, dtype=torch.float32, device=DEVICE)
l1_scale_a = _assemble_scales_ref(x_sf, expert_offsets, num_experts)
l1_out = run_nvfp4_grouped_gemm(
mat_a=x_fp4, mat_b=l1_mat_b,
scale_a=l1_scale_a, scale_b=l1_scale_b,
expert_offsets=expert_offsets[1:],
global_scale_a=l1_gsa, global_scale_b=l1_gsb,
)
# SiLU(gate) * up
gate = l1_out[:, :INTERMEDIATE_SIZE]
up = l1_out[:, INTERMEDIATE_SIZE:]
gate_silu = torch.nn.functional.silu(gate)
if swiglu_limit is not None:
gate_silu = gate_silu.clamp(max=swiglu_limit)
up = up.clamp(min=-swiglu_limit, max=swiglu_limit)
activated = gate_silu * up
# L2 with dynamic gs
l2_x_fp4, l2_x_sf, l2_gs_val = quantize_to_nvfp4(activated)
l2_gsa = torch.full((num_experts,), l2_gs_val, dtype=torch.float32, device=DEVICE)
l2_scale_a = _assemble_scales_ref(l2_x_sf, expert_offsets, num_experts)
l2_out = run_nvfp4_grouped_gemm(
mat_a=l2_x_fp4, mat_b=l2_mat_b,
scale_a=l2_scale_a, scale_b=l2_scale_b,
expert_offsets=expert_offsets[1:],
global_scale_a=l2_gsa, global_scale_b=l2_gsb,
)
# Scatter-add
y = torch.zeros(num_tokens, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE)
weighted_out = l2_out * sorted_weights.unsqueeze(1).to(l2_out.dtype)
y.scatter_add_(0, sorted_token_ids.unsqueeze(1).expand(-1, HIDDEN_SIZE), weighted_out)
return y
def _stack_weights(fp4_list, sf_list, gs_list):
"""Stack expert weights into GEMM format."""
mat_b = torch.stack(fp4_list)
scale_b = torch.stack(sf_list)
gsb = torch.stack(gs_list)
return mat_b, scale_b, gsb
def _assemble_scales_ref(x_sf, expert_offsets, num_experts):
"""Reference scale assembly using assemble_scales_3d_side from bridge."""
from cutedsl.bridge import assemble_scales_3d_side
return assemble_scales_3d_side(x_sf, expert_offsets, num_experts)
def main():
torch.cuda.set_device(0)
torch.manual_seed(42)
print(f"=== Pipeline Test: Layer {LAYER_IDX}, {NUM_EXPERTS} experts, {NUM_TOKENS} tokens, top_k={TOP_K} ===")
print(f" swiglu_limit={SWIGLU_LIMIT}")
# Load real weights
print("Loading weights from checkpoint...")
print("\nLoading weights from checkpoint...")
weights = load_expert_weights(LAYER_IDX, NUM_EXPERTS)
print(f"Loaded {NUM_EXPERTS} experts")
for e in range(min(3, NUM_EXPERTS)):
print(f" Expert {e}: l1_fp4={weights['l1_fp4'][e].shape} l1_gs={weights['l1_gs'][e].item():.6f} "
f"l2_fp4={weights['l2_fp4'][e].shape} l2_gs={weights['l2_gs'][e].item():.6f}")
# Create runner
# Create input
hidden_states = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE)
# Realistic top-k: uneven distribution
topk_ids = torch.zeros(NUM_TOKENS, TOP_K, dtype=torch.int64, device=DEVICE)
for i in range(NUM_TOKENS):
experts = torch.randperm(NUM_EXPERTS)[:TOP_K]
topk_ids[i] = experts
topk_weights = torch.ones(NUM_TOKENS, TOP_K, dtype=torch.float32, device=DEVICE) / TOP_K
# ---- Reference ----
print("\n--- Reference (dynamic gs, per-expert scale assembly) ---")
with torch.no_grad():
ref_out = run_reference(hidden_states, topk_weights, topk_ids, weights, swiglu_limit=SWIGLU_LIMIT)
print(f"Reference: amax={ref_out.amax().item():.4f} mean={ref_out.mean().item():.4f}")
print(f" NaN: {torch.isnan(ref_out).any().item()} Inf: {torch.isinf(ref_out).any().item()}")
# ---- Runner ----
print("\n--- CuTeDSL Runner (warmup gs, full-buffer swizzle) ---")
runner = CuTeDSLMoERunner(
num_experts=NUM_EXPERTS,
hidden_size=HIDDEN_SIZE,
intermediate_size=INTERMEDIATE_SIZE,
max_num_tokens=NUM_TOKENS,
top_k=TOP_K,
device=DEVICE,
num_experts=NUM_EXPERTS, hidden_size=HIDDEN_SIZE,
intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=NUM_TOKENS,
top_k=TOP_K, device=DEVICE,
)
# Set weights
runner.l1_fp4 = weights['l1_fp4']
runner.l1_sf = weights['l1_sf']
runner.l1_gs = weights['l1_gs']
@@ -118,73 +205,40 @@ def main():
runner.l2_gs = weights['l2_gs']
runner.set_swiglu_limit(SWIGLU_LIMIT)
# Create input
hidden_states = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE)
# Create top-k assignments (realistic: uneven distribution)
topk_ids = torch.zeros(NUM_TOKENS, TOP_K, dtype=torch.int64, device=DEVICE)
for i in range(NUM_TOKENS):
# Each token picks TOP_K random experts
experts = torch.randperm(NUM_EXPERTS)[:TOP_K]
topk_ids[i] = experts
topk_weights = torch.ones(NUM_TOKENS, TOP_K, dtype=torch.float32, device=DEVICE) / TOP_K
# ---- Stage 1: Reference pipeline (dynamic gs) ----
print("\n--- Reference pipeline (dynamic gs) ---")
with torch.no_grad():
ref_out = moe_pipeline(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
l1_fp4=weights['l1_fp4'],
l1_sf=weights['l1_sf'],
l1_gs=weights['l1_gs'],
l2_fp4=weights['l2_fp4'],
l2_sf=weights['l2_sf'],
l2_gs=weights['l2_gs'],
num_experts=NUM_EXPERTS,
hidden_size=HIDDEN_SIZE,
intermediate_size=INTERMEDIATE_SIZE,
swiglu_limit=SWIGLU_LIMIT,
)
print(f"Reference: shape={ref_out.shape} amax={ref_out.amax().item():.4f} mean={ref_out.mean().item():.4f}")
print(f" NaN: {torch.isnan(ref_out).any().item()} Inf: {torch.isinf(ref_out).any().item()}")
# ---- Stage 2: Runner with warmup gs ----
print("\n--- Runner (warmup gs) ---")
with torch.no_grad():
# Compute warmup gs
runner.compute_activation_global_scales(hidden_states, topk_weights, topk_ids)
print(f"Warmup gs: L1={runner._l1_activation_global_scale:.6f} L2={runner._l2_activation_global_scale:.6f}")
# Run
runner_out = runner.run(hidden_states, topk_weights, topk_ids)
print(f"Runner: shape={runner_out.shape} amax={runner_out.amax().item():.4f} mean={runner_out.mean().item():.4f}")
print(f"Runner: amax={runner_out.amax().item():.4f} mean={runner_out.mean().item():.4f}")
print(f" NaN: {torch.isnan(runner_out).any().item()} Inf: {torch.isinf(runner_out).any().item()}")
# ---- Comparison ----
print("\n--- Comparison ---")
# Overall cosine
cos = torch.nn.functional.cosine_similarity(
ref_out.flatten().unsqueeze(0), runner_out.flatten().unsqueeze(0)
).item()
mse = (ref_out - runner_out).pow(2).mean().item()
print(f"Cosine: {cos:.6f} MSE: {mse:.4f}")
if cos < 0.90:
print("\n⚠️ LOW COSINE — investigating per-token differences...")
for i in range(min(NUM_TOKENS, 8)):
cos_i = torch.nn.functional.cosine_similarity(
ref_out[i].unsqueeze(0), runner_out[i].unsqueeze(0)
).item()
print(f" Token {i}: cosine={cos_i:.4f} ref_max={ref_out[i].amax().item():.4f} run_max={runner_out[i].amax().item():.4f}")
# Per-token
low_cos_tokens = 0
for i in range(NUM_TOKENS):
cos_i = torch.nn.functional.cosine_similarity(
ref_out[i].unsqueeze(0), runner_out[i].unsqueeze(0)
).item()
if cos_i < 0.95:
low_cos_tokens += 1
if low_cos_tokens <= 5:
print(f" Token {i}: cosine={cos_i:.4f} ref_max={ref_out[i].amax().item():.4f} run_max={runner_out[i].amax().item():.4f}")
if low_cos_tokens > 5:
print(f" ... {low_cos_tokens - 5} more tokens with cosine < 0.95")
if cos >= 0.98:
print(f"\n✅ PASS: cosine {cos:.6f} >= 0.98")
elif cos >= 0.90:
print(f"\n⚠️ MARGINAL: cosine {cos:.6f} — close but degraded")
print(f"\n⚠️ MARGINAL: cosine {cos:.6f}")
else:
print(f"\n❌ FAIL: cosine {cos:.6f} < 0.90 — significant quality loss")
print(f"\n❌ FAIL: cosine {cos:.6f}")
if __name__ == "__main__":