From e51eafe2887ba8d95ded789dd43ae256cf6014b7 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 17 May 2026 18:08:33 +0000 Subject: [PATCH] Rewrite pipeline test: compare runner vs reference with real weights, step-by-step --- tests/test_pipeline_real_weights.py | 238 +++++++++++++++++----------- 1 file changed, 146 insertions(+), 92 deletions(-) diff --git a/tests/test_pipeline_real_weights.py b/tests/test_pipeline_real_weights.py index 0f44fd09..784f8f72 100644 --- a/tests/test_pipeline_real_weights.py +++ b/tests/test_pipeline_real_weights.py @@ -1,26 +1,17 @@ -"""Test #2: End-to-end single-layer test with real model weights. +"""Pipeline Test: Compare CuTeDSL runner vs reference with real model weights. -Loads layer 0 from the DeepSeek-V4-Pro-NVFP4 checkpoint, runs one MoE layer -through our CuTeDSL runner, and compares against the reference moe_pipeline -(which uses the same NVFP4 weights but with dynamic gs). - -This catches issues that the small layertest (3 experts, 8 tokens) misses: -- Scale assembly with 48 experts × 8 chunks -- Uneven expert assignment -- Real activation magnitudes -- swiglu_limit clamping -- Variable padded expert offsets at scale +Loads layer 0 from DeepSeek-V4-Pro-NVFP4, runs both the reference +moe_pipeline and our CuTeDSLMoERunner, compares output step by step. """ import torch import sys import os +import glob import math -# Add paths -sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)) + '/../cutedsl') +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)) + '/..') sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)) + '/../vllm') -from cutedsl.moe_pipeline import moe_pipeline from vllm.nvfp4_cutedsl import CuTeDSLMoERunner # ============================================================ @@ -28,7 +19,7 @@ from vllm.nvfp4_cutedsl import CuTeDSLMoERunner # ============================================================ MODEL_PATH = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" LAYER_IDX = 0 -NUM_EXPERTS = 48 # local experts per EP rank +NUM_EXPERTS = 48 HIDDEN_SIZE = 7168 INTERMEDIATE_SIZE = 18432 NUM_TOKENS = 64 @@ -40,10 +31,6 @@ def load_expert_weights(layer_idx, num_experts): """Load NVFP4 weights for one layer from the checkpoint.""" from safetensors import safe_open - # Find the layer shard file - shard_dir = os.path.join(MODEL_PATH, f"model-0000{layer_idx+1:02d}-of-00010.safetensors") - # Try to find the right shard - import glob shards = sorted(glob.glob(os.path.join(MODEL_PATH, "*.safetensors"))) l1_fp4 = [] @@ -56,19 +43,14 @@ def load_expert_weights(layer_idx, num_experts): for shard_path in shards: with safe_open(shard_path, framework="pt", device="cpu") as f: for e in range(num_experts): - global_e = e # For rank 0, local = global + w13_key = f"model.layers.{layer_idx}.mlp.experts.{e}.w13_weight" + sf13_key = f"model.layers.{layer_idx}.mlp.experts.{e}.w13_weight_scale" + gs13_key = f"model.layers.{layer_idx}.mlp.experts.{e}.w13_weight_scale_2" + w2_key = f"model.layers.{layer_idx}.mlp.experts.{e}.w2_weight" + sf2_key = f"model.layers.{layer_idx}.mlp.experts.{e}.w2_weight_scale" + gs2_key = f"model.layers.{layer_idx}.mlp.experts.{e}.w2_weight_scale_2" - # L1 (gate+up) - w13_key = f"model.layers.{layer_idx}.mlp.experts.{global_e}.w13_weight" - sf13_key = f"model.layers.{layer_idx}.mlp.experts.{global_e}.w13_weight_scale" - gs13_key = f"model.layers.{layer_idx}.mlp.experts.{global_e}.w13_weight_scale_2" - - # L2 (down) - w2_key = f"model.layers.{layer_idx}.mlp.experts.{global_e}.w2_weight" - sf2_key = f"model.layers.{layer_idx}.mlp.experts.{global_e}.w2_weight_scale" - gs2_key = f"model.layers.{layer_idx}.mlp.experts.{global_e}.w2_weight_scale_2" - - if w13_key in f.keys(): + if w13_key in f.keys() and len(l1_fp4) <= e: l1_fp4.append(f.get_tensor(w13_key).to(DEVICE)) l1_sf.append(f.get_tensor(sf13_key).to(DEVICE)) l1_gs.append(f.get_tensor(gs13_key).to(DEVICE)) @@ -76,11 +58,11 @@ def load_expert_weights(layer_idx, num_experts): l2_sf.append(f.get_tensor(sf2_key).to(DEVICE)) l2_gs.append(f.get_tensor(gs2_key).to(DEVICE)) - if len(l1_fp4) == num_experts: + if len(l1_fp4) >= num_experts: break if len(l1_fp4) != num_experts: - raise RuntimeError(f"Only loaded {len(l1_fp4)}/{num_experts} experts from checkpoint") + raise RuntimeError(f"Only loaded {len(l1_fp4)}/{num_experts} experts") return { 'l1_fp4': l1_fp4, 'l1_sf': l1_sf, 'l1_gs': l1_gs, @@ -88,28 +70,133 @@ def load_expert_weights(layer_idx, num_experts): } +def run_reference(hidden_states, topk_weights, topk_ids, weights, swiglu_limit=None): + """Reference MoE: per-expert processing with dynamic gs (quantize_to_nvfp4).""" + from cutedsl.quantize import quantize_to_nvfp4 + from cutedsl.gemm import run_nvfp4_grouped_gemm + + num_tokens = hidden_states.shape[0] + top_k = topk_ids.shape[1] + num_experts = len(weights['l1_fp4']) + + # Sort tokens by expert + flat_ids = topk_ids.reshape(-1).cpu().numpy() + flat_weights = topk_weights.reshape(-1) + token_indices = torch.arange(num_tokens).unsqueeze(1).expand(-1, top_k).reshape(-1) + + sort_idx = torch.argsort(topk_ids.reshape(-1), stable=True) + sorted_ids = topk_ids.reshape(-1)[sort_idx] + sorted_weights = topk_weights.reshape(-1)[sort_idx] + sorted_token_ids = token_indices[sort_idx] + + # Compute expert offsets + expert_id_range = torch.arange(num_experts) + tokens_per_expert = (sorted_ids.unsqueeze(1).cpu() == expert_id_range.unsqueeze(0)).sum(dim=0) + expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32) + for e in range(num_experts): + expert_offsets[e + 1] = expert_offsets[e] + tokens_per_expert[e].item() + + num_slots = num_tokens * top_k + slot_hidden = hidden_states[sorted_token_ids] + + # Stack weights for GEMM + l1_mat_b, l1_scale_b, l1_gsb = _stack_weights(weights['l1_fp4'], weights['l1_sf'], weights['l1_gs']) + l2_mat_b, l2_scale_b, l2_gsb = _stack_weights(weights['l2_fp4'], weights['l2_sf'], weights['l2_gs']) + + # L1 with dynamic gs + x_fp4, x_sf, gs_val = quantize_to_nvfp4(slot_hidden) + l1_gsa = torch.full((num_experts,), gs_val, dtype=torch.float32, device=DEVICE) + + l1_scale_a = _assemble_scales_ref(x_sf, expert_offsets, num_experts) + + l1_out = run_nvfp4_grouped_gemm( + mat_a=x_fp4, mat_b=l1_mat_b, + scale_a=l1_scale_a, scale_b=l1_scale_b, + expert_offsets=expert_offsets[1:], + global_scale_a=l1_gsa, global_scale_b=l1_gsb, + ) + + # SiLU(gate) * up + gate = l1_out[:, :INTERMEDIATE_SIZE] + up = l1_out[:, INTERMEDIATE_SIZE:] + gate_silu = torch.nn.functional.silu(gate) + if swiglu_limit is not None: + gate_silu = gate_silu.clamp(max=swiglu_limit) + up = up.clamp(min=-swiglu_limit, max=swiglu_limit) + activated = gate_silu * up + + # L2 with dynamic gs + l2_x_fp4, l2_x_sf, l2_gs_val = quantize_to_nvfp4(activated) + l2_gsa = torch.full((num_experts,), l2_gs_val, dtype=torch.float32, device=DEVICE) + l2_scale_a = _assemble_scales_ref(l2_x_sf, expert_offsets, num_experts) + + l2_out = run_nvfp4_grouped_gemm( + mat_a=l2_x_fp4, mat_b=l2_mat_b, + scale_a=l2_scale_a, scale_b=l2_scale_b, + expert_offsets=expert_offsets[1:], + global_scale_a=l2_gsa, global_scale_b=l2_gsb, + ) + + # Scatter-add + y = torch.zeros(num_tokens, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) + weighted_out = l2_out * sorted_weights.unsqueeze(1).to(l2_out.dtype) + y.scatter_add_(0, sorted_token_ids.unsqueeze(1).expand(-1, HIDDEN_SIZE), weighted_out) + + return y + + +def _stack_weights(fp4_list, sf_list, gs_list): + """Stack expert weights into GEMM format.""" + mat_b = torch.stack(fp4_list) + scale_b = torch.stack(sf_list) + gsb = torch.stack(gs_list) + return mat_b, scale_b, gsb + + +def _assemble_scales_ref(x_sf, expert_offsets, num_experts): + """Reference scale assembly using assemble_scales_3d_side from bridge.""" + from cutedsl.bridge import assemble_scales_3d_side + return assemble_scales_3d_side(x_sf, expert_offsets, num_experts) + + def main(): torch.cuda.set_device(0) torch.manual_seed(42) print(f"=== Pipeline Test: Layer {LAYER_IDX}, {NUM_EXPERTS} experts, {NUM_TOKENS} tokens, top_k={TOP_K} ===") + print(f" swiglu_limit={SWIGLU_LIMIT}") - # Load real weights - print("Loading weights from checkpoint...") + print("\nLoading weights from checkpoint...") weights = load_expert_weights(LAYER_IDX, NUM_EXPERTS) print(f"Loaded {NUM_EXPERTS} experts") + for e in range(min(3, NUM_EXPERTS)): + print(f" Expert {e}: l1_fp4={weights['l1_fp4'][e].shape} l1_gs={weights['l1_gs'][e].item():.6f} " + f"l2_fp4={weights['l2_fp4'][e].shape} l2_gs={weights['l2_gs'][e].item():.6f}") - # Create runner + # Create input + hidden_states = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) + + # Realistic top-k: uneven distribution + topk_ids = torch.zeros(NUM_TOKENS, TOP_K, dtype=torch.int64, device=DEVICE) + for i in range(NUM_TOKENS): + experts = torch.randperm(NUM_EXPERTS)[:TOP_K] + topk_ids[i] = experts + topk_weights = torch.ones(NUM_TOKENS, TOP_K, dtype=torch.float32, device=DEVICE) / TOP_K + + # ---- Reference ---- + print("\n--- Reference (dynamic gs, per-expert scale assembly) ---") + with torch.no_grad(): + ref_out = run_reference(hidden_states, topk_weights, topk_ids, weights, swiglu_limit=SWIGLU_LIMIT) + print(f"Reference: amax={ref_out.amax().item():.4f} mean={ref_out.mean().item():.4f}") + print(f" NaN: {torch.isnan(ref_out).any().item()} Inf: {torch.isinf(ref_out).any().item()}") + + # ---- Runner ---- + print("\n--- CuTeDSL Runner (warmup gs, full-buffer swizzle) ---") runner = CuTeDSLMoERunner( - num_experts=NUM_EXPERTS, - hidden_size=HIDDEN_SIZE, - intermediate_size=INTERMEDIATE_SIZE, - max_num_tokens=NUM_TOKENS, - top_k=TOP_K, - device=DEVICE, + num_experts=NUM_EXPERTS, hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=NUM_TOKENS, + top_k=TOP_K, device=DEVICE, ) - - # Set weights runner.l1_fp4 = weights['l1_fp4'] runner.l1_sf = weights['l1_sf'] runner.l1_gs = weights['l1_gs'] @@ -118,73 +205,40 @@ def main(): runner.l2_gs = weights['l2_gs'] runner.set_swiglu_limit(SWIGLU_LIMIT) - # Create input - hidden_states = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) - - # Create top-k assignments (realistic: uneven distribution) - topk_ids = torch.zeros(NUM_TOKENS, TOP_K, dtype=torch.int64, device=DEVICE) - for i in range(NUM_TOKENS): - # Each token picks TOP_K random experts - experts = torch.randperm(NUM_EXPERTS)[:TOP_K] - topk_ids[i] = experts - topk_weights = torch.ones(NUM_TOKENS, TOP_K, dtype=torch.float32, device=DEVICE) / TOP_K - - # ---- Stage 1: Reference pipeline (dynamic gs) ---- - print("\n--- Reference pipeline (dynamic gs) ---") with torch.no_grad(): - ref_out = moe_pipeline( - hidden_states=hidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - l1_fp4=weights['l1_fp4'], - l1_sf=weights['l1_sf'], - l1_gs=weights['l1_gs'], - l2_fp4=weights['l2_fp4'], - l2_sf=weights['l2_sf'], - l2_gs=weights['l2_gs'], - num_experts=NUM_EXPERTS, - hidden_size=HIDDEN_SIZE, - intermediate_size=INTERMEDIATE_SIZE, - swiglu_limit=SWIGLU_LIMIT, - ) - print(f"Reference: shape={ref_out.shape} amax={ref_out.amax().item():.4f} mean={ref_out.mean().item():.4f}") - print(f" NaN: {torch.isnan(ref_out).any().item()} Inf: {torch.isinf(ref_out).any().item()}") - - # ---- Stage 2: Runner with warmup gs ---- - print("\n--- Runner (warmup gs) ---") - with torch.no_grad(): - # Compute warmup gs runner.compute_activation_global_scales(hidden_states, topk_weights, topk_ids) print(f"Warmup gs: L1={runner._l1_activation_global_scale:.6f} L2={runner._l2_activation_global_scale:.6f}") - - # Run runner_out = runner.run(hidden_states, topk_weights, topk_ids) - print(f"Runner: shape={runner_out.shape} amax={runner_out.amax().item():.4f} mean={runner_out.mean().item():.4f}") + print(f"Runner: amax={runner_out.amax().item():.4f} mean={runner_out.mean().item():.4f}") print(f" NaN: {torch.isnan(runner_out).any().item()} Inf: {torch.isinf(runner_out).any().item()}") # ---- Comparison ---- print("\n--- Comparison ---") - # Overall cosine cos = torch.nn.functional.cosine_similarity( ref_out.flatten().unsqueeze(0), runner_out.flatten().unsqueeze(0) ).item() mse = (ref_out - runner_out).pow(2).mean().item() print(f"Cosine: {cos:.6f} MSE: {mse:.4f}") - if cos < 0.90: - print("\n⚠️ LOW COSINE — investigating per-token differences...") - for i in range(min(NUM_TOKENS, 8)): - cos_i = torch.nn.functional.cosine_similarity( - ref_out[i].unsqueeze(0), runner_out[i].unsqueeze(0) - ).item() - print(f" Token {i}: cosine={cos_i:.4f} ref_max={ref_out[i].amax().item():.4f} run_max={runner_out[i].amax().item():.4f}") + # Per-token + low_cos_tokens = 0 + for i in range(NUM_TOKENS): + cos_i = torch.nn.functional.cosine_similarity( + ref_out[i].unsqueeze(0), runner_out[i].unsqueeze(0) + ).item() + if cos_i < 0.95: + low_cos_tokens += 1 + if low_cos_tokens <= 5: + print(f" Token {i}: cosine={cos_i:.4f} ref_max={ref_out[i].amax().item():.4f} run_max={runner_out[i].amax().item():.4f}") + if low_cos_tokens > 5: + print(f" ... {low_cos_tokens - 5} more tokens with cosine < 0.95") if cos >= 0.98: print(f"\n✅ PASS: cosine {cos:.6f} >= 0.98") elif cos >= 0.90: - print(f"\n⚠️ MARGINAL: cosine {cos:.6f} — close but degraded") + print(f"\n⚠️ MARGINAL: cosine {cos:.6f}") else: - print(f"\n❌ FAIL: cosine {cos:.6f} < 0.90 — significant quality loss") + print(f"\n❌ FAIL: cosine {cos:.6f}") if __name__ == "__main__":