From 9728604ea1a7e5d1c307513aeba13f595b4fb35f Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 17 May 2026 21:19:17 +0000 Subject: [PATCH] Pipeline test: stage-by-stage with BF16 reference comparison --- tests/test_pipeline_real_weights.py | 513 ++++++++++++++++++---------- 1 file changed, 335 insertions(+), 178 deletions(-) diff --git a/tests/test_pipeline_real_weights.py b/tests/test_pipeline_real_weights.py index a7174f39..2e2f2573 100644 --- a/tests/test_pipeline_real_weights.py +++ b/tests/test_pipeline_real_weights.py @@ -1,15 +1,17 @@ -"""Step-by-step Pipeline Test: CuTeDSL runner components vs reference. +"""Step-by-step Pipeline Test: Isolate each component of CuTeDSL runner. Loads real layer 0 weights from DeepSeek-V4-Pro-NVFP4. -Tests each pipeline stage independently: - 1. Token sorting & expert assignment - 2. L1 GEMM (gate+up) - 3. SwiGLU activation (with swiglu_limit clamping) - 4. L2 GEMM (down_proj) - 5. Scatter-add with routing weights - 6. Full runner vs reference +Tests each pipeline stage independently against a BF16 reference: + Stage 1: Token sort + expert assignment + Stage 2: L1 GEMM (gate+up) + Stage 3: SwiGLU activation (with swiglu_limit clamping) + Stage 4: L2 GEMM (down_proj) + Stage 5: Scatter-add with routing weights + Stage 6: Full runner end-to-end -Strategy: Comment out stages to isolate bugs, then uncomment one by one. +Strategy: Stages are tested incrementally. Enable STAGE_START to begin at that stage. +All stages from STAGE_START through STAGE_END are tested. +Set STAGE_START=1 STAGE_END=1 to test only stage 1, etc. """ import torch import torch.nn.functional as F @@ -21,28 +23,24 @@ import glob sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')) # ============================================================ -# CONFIG — toggle which stages to test +# CONFIG # ============================================================ MODEL_PATH = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" LAYER_IDX = 0 -NUM_EXPERTS = 48 # local experts per rank (256/8=32, but model uses 48) +NUM_EXPERTS = 48 HIDDEN_SIZE = 7168 -INTERMEDIATE_SIZE = 3072 # per routed expert (18432 is shared expert) +INTERMEDIATE_SIZE = 3072 # per routed expert NUM_TOKENS = 8 TOP_K = 6 SWIGLU_LIMIT = 10.0 DEVICE = "cuda" -# Which stages to enable (uncomment incrementally to find bugs) -ENABLE_SORT = True -ENABLE_L1_GEMM = True -ENABLE_SWIGLU = True -ENABLE_L2_GEMM = True -ENABLE_SCATTER = True -ENABLE_FULL_RUNNER = True +# Stage control (1-6) +STAGE_START = 1 +STAGE_END = 6 # ============================================================ -# Weight loading (from layertest.py pattern) +# Weight loading # ============================================================ def load_layer_tensors(model_dir, layer_idx): tensors = {} @@ -50,16 +48,13 @@ def load_layer_tensors(model_dir, layer_idx): from safetensors.torch import load_file data = load_file(sf) for k, v in data.items(): - # Match both "layers.X." and "model.layers.X." if f"layers.{layer_idx}." in k and "mlp.experts" in k: - # Normalize: strip "model." prefix if present norm_key = k.removeprefix("model.") tensors[norm_key] = v return tensors def prepare_nvfp4_weights(nvfp4_tensors, layer_idx, expert_indices, intermediate_size): - """Prepare weights via direct view-cast (same as layertest).""" l1_fp4, l1_sf, l1_gs = [], [], [] l2_fp4, l2_sf, l2_gs = [], [], [] @@ -109,8 +104,7 @@ def prepare_nvfp4_weights(nvfp4_tensors, layer_idx, expert_indices, intermediate def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale): - """Dequantize NVFP4 weight to BF16 for reference computation. - + """Dequantize NVFP4 weight to BF16. packed_uint8: (N, K_packed) where K_packed = K//2 scale_e4m3: (N, K_sf) where K_sf = K//16 Returns: (N, K) BF16 @@ -126,28 +120,24 @@ def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale): K = packed_uint8.shape[1] * 2 bf16_vals = torch.stack([lower, upper], dim=-1).reshape(N, K) - - # scale_e4m3 is (N, K_sf) where K_sf = K//16 K_sf = scale_e4m3.shape[1] - scale_2d = scale_e4m3.float().repeat_interleave(K // K_sf, dim=1) # (N, K) - + scale_2d = scale_e4m3.float().repeat_interleave(K // K_sf, dim=1) dequant = bf16_vals * scale_2d * global_scale return dequant.to(torch.bfloat16) # ============================================================ -# Reference pipeline (step by step, BF16) +# Stage tests # ============================================================ -def reference_moe_bf16(hidden_states, nvfp4_tensors, layer_idx, expert_indices, topk_ids, topk_weights, swiglu_limit): - """BF16 reference: dequantize weights, run MoE step by step.""" +def test_stage1_sort(hidden_states, topk_ids, topk_weights, expert_indices): + """Stage 1: Token sorting & expert assignment.""" + print("\n--- Stage 1: Token Sort & Expert Assignment ---") + num_tokens = hidden_states.shape[0] top_k = topk_ids.shape[1] - output = torch.zeros(num_tokens, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) - - # Store intermediate results for comparison - intermediates = {} - - # Sort tokens by expert (for comparing with runner's sorted approach) + num_experts = len(expert_indices) + + # Reference: simple sort by expert ID flat_ids = topk_ids.reshape(-1) flat_weights = topk_weights.reshape(-1) sort_idx = flat_ids.argsort(stable=True) @@ -155,145 +145,200 @@ def reference_moe_bf16(hidden_states, nvfp4_tensors, layer_idx, expert_indices, sorted_weights = flat_weights[sort_idx] token_indices = torch.arange(num_tokens, device=DEVICE).unsqueeze(1).expand(-1, top_k).reshape(-1) sorted_token_ids = token_indices[sort_idx] + + expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32, device=DEVICE) + for i in range(num_experts): + expert_offsets[i + 1] = (sorted_ids == i).sum() + expert_offsets[1:] = expert_offsets[1:].cumsum(0) + + tokens_per_expert = expert_offsets[1:] - expert_offsets[:-1] + + print(f" Tokens per expert: min={tokens_per_expert.min().item()} max={tokens_per_expert.max().item()} total={tokens_per_expert.sum().item()}") + print(f" Expert offsets: {expert_offsets.tolist()}") + print(f" Sorted token IDs (first 20): {sorted_token_ids[:20].tolist()}") + + return { + 'sorted_ids': sorted_ids, + 'sorted_token_ids': sorted_token_ids, + 'sorted_weights': sorted_weights, + 'expert_offsets': expert_offsets, + 'tokens_per_expert': tokens_per_expert, + 'slot_hidden': hidden_states[sorted_token_ids], + } - intermediates['sorted_ids'] = sorted_ids - intermediates['sorted_token_ids'] = sorted_token_ids - intermediates['sorted_weights'] = sorted_weights - - # Expert offsets - expert_id_range = torch.arange(len(expert_indices), device=DEVICE) - tokens_per_expert = torch.zeros(len(expert_indices), dtype=torch.int32, device=DEVICE) - for i, e in enumerate(expert_indices): - tokens_per_expert[i] = (sorted_ids == i).sum() - expert_offsets = torch.zeros(len(expert_indices) + 1, dtype=torch.int32, device=DEVICE) - expert_offsets[1:] = tokens_per_expert.cumsum(0) - - intermediates['expert_offsets'] = expert_offsets - intermediates['tokens_per_expert'] = tokens_per_expert - - # Gather hidden states for sorted tokens - slot_hidden = hidden_states[sorted_token_ids] - intermediates['slot_hidden'] = slot_hidden - - # Per-expert computation - l1_out_all = [] - activated_all = [] - l2_out_all = [] +def test_stage2_l1_gemm(slot_hidden, expert_offsets, nvfp4_tensors, layer_idx, expert_indices, weights): + """Stage 2: L1 GEMM (gate+up) using CuTeDSL bridge directly.""" + print("\n--- Stage 2: L1 GEMM (gate+up) ---") + from cutedsl.bridge import ( + quantize_to_nvfp4, run_nvfp4_grouped_gemm, + assemble_scales_3d_side, make_b_k_major, + ) + + num_experts = len(expert_indices) + + # Stack weights for GEMM + l1_mat_b = torch.stack(weights['l1_fp4']) + l1_scale_b = torch.stack(weights['l1_sf']) + l1_gsb = torch.stack(weights['l1_gs']) + + # Make B-K major + l1_mat_b = make_b_k_major(l1_mat_b) + l1_scale_b = assemble_scales_3d_side(l1_scale_b) + + # Quantize activation with dynamic gs + x_fp4, x_sf, l1_gs_dyn = quantize_to_nvfp4(slot_hidden) + l1_gsa = torch.full((num_experts,), l1_gs_dyn, dtype=torch.float32, device=DEVICE) + l1_scale_a = assemble_scales_3d_side(x_sf, expert_offsets[:num_experts+1], num_experts) + + # Run GEMM + l1_out = run_nvfp4_grouped_gemm( + mat_a=x_fp4, mat_b=l1_mat_b, + scale_a=l1_scale_a, scale_b=l1_scale_b, + expert_offsets=expert_offsets[1:], + global_scale_a=l1_gsa, global_scale_b=l1_gsb, + ) + print(f" L1 gs (dynamic): {l1_gs_dyn:.6f}") + print(f" L1 out: shape={l1_out.shape} amax={l1_out.amax().item():.4f} mean={l1_out.mean().item():.4f}") + print(f" L1 out NaN: {torch.isnan(l1_out).any().item()} Inf: {torch.isinf(l1_out).any().item()}") + + # BF16 reference for first expert + ref_l1_parts = [] for i, e in enumerate(expert_indices): start = expert_offsets[i].item() end = expert_offsets[i + 1].item() if start == end: continue + x = slot_hidden[start:end] + gate_w = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.gate_proj.weight"].to(DEVICE) + up_w = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.up_proj.weight"].to(DEVICE) + gate_sf = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.gate_proj.weight_scale"].to(DEVICE) + up_sf = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.up_proj.weight_scale"].to(DEVICE) + gate_gs = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.gate_proj.weight_scale_2"].item() + up_gs = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.up_proj.weight_scale_2"].item() + + gate_bf16 = dequantize_nvfp4_weight(gate_w, gate_sf, gate_gs) + up_bf16 = dequantize_nvfp4_weight(up_w, up_sf, up_gs) + + gate_ref = x @ gate_bf16.T + up_ref = x @ up_bf16.T + ref_l1_parts.append((start, end, gate_ref, up_ref)) + + # Compare L1 output for first expert that has tokens + if ref_l1_parts: + start, end, gate_ref, up_ref = ref_l1_parts[0] + l1_gate = l1_out[start:end, :INTERMEDIATE_SIZE] + l1_up = l1_out[start:end, INTERMEDIATE_SIZE:] + + cos_gate = F.cosine_similarity(gate_ref.flatten().unsqueeze(0), l1_gate.flatten().unsqueeze(0)).item() + cos_up = F.cosine_similarity(up_ref.flatten().unsqueeze(0), l1_up.flatten().unsqueeze(0)).item() + print(f" L1 vs BF16 (expert {expert_indices[0]}, {end-start} tokens):") + print(f" gate: cosine={cos_gate:.6f} ref_amax={gate_ref.amax().item():.4f} run_amax={l1_gate.amax().item():.4f}") + print(f" up: cosine={cos_up:.6f} ref_amax={up_ref.amax().item():.4f} run_amax={l1_up.amax().item():.4f}") + + return l1_out, ref_l1_parts, l1_gs_dyn - x = slot_hidden[start:end] # (T, H) - # L1: gate + up - gate_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"].to(DEVICE) - up_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"].to(DEVICE) - gate_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale"].to(DEVICE) - up_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale"].to(DEVICE) - gate_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"].item() - up_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"].item() - - gate_bf16 = dequantize_nvfp4_weight(gate_w, gate_sf, gate_gs) # (intermediate, hidden) - up_bf16 = dequantize_nvfp4_weight(up_w, up_sf, up_gs) # (intermediate, hidden) - - gate = x @ gate_bf16.T # (T, intermediate) - up = x @ up_bf16.T # (T, intermediate) - - l1_out = torch.cat([gate, up], dim=1) # (T, 2*intermediate) - l1_out_all.append((start, end, l1_out)) - - # SwiGLU - gate_silu = F.silu(gate) +def test_stage3_swiglu(l1_out, ref_l1_parts, swiglu_limit): + """Stage 3: SwiGLU activation with clamping.""" + print("\n--- Stage 3: SwiGLU Activation ---") + + # Runner path + gate = l1_out[:, :INTERMEDIATE_SIZE] + up = l1_out[:, INTERMEDIATE_SIZE:] + gate_silu = F.silu(gate) + if swiglu_limit is not None: + gate_silu = gate_silu.clamp(max=swiglu_limit) + up = up.clamp(min=-swiglu_limit, max=swiglu_limit) + activated = gate_silu * up + + print(f" activated: shape={activated.shape} amax={activated.amax().item():.4f} mean={activated.mean().item():.4f}") + print(f" gate_silu amax: {gate_silu.amax().item():.4f} up amax: {up.amax().item():.4f}") + + # BF16 reference + if ref_l1_parts: + start, end, gate_ref, up_ref = ref_l1_parts[0] + gate_silu_ref = F.silu(gate_ref) if swiglu_limit is not None: - gate_silu = gate_silu.clamp(max=swiglu_limit) - up = up.clamp(min=-swiglu_limit, max=swiglu_limit) - activated = gate_silu * up - activated_all.append((start, end, activated)) + gate_silu_ref = gate_silu_ref.clamp(max=swiglu_limit) + up_ref = up_ref.clamp(min=-swiglu_limit, max=swiglu_limit) + activated_ref = gate_silu_ref * up_ref + + act_runner = activated[start:end] + cos = F.cosine_similarity(activated_ref.flatten().unsqueeze(0), act_runner.flatten().unsqueeze(0)).item() + print(f" vs BF16 (expert 0): cosine={cos:.6f}") + + return activated - # L2: down + +def test_stage4_l2_gemm(activated, expert_offsets, nvfp4_tensors, layer_idx, expert_indices, weights): + """Stage 4: L2 GEMM (down_proj).""" + print("\n--- Stage 4: L2 GEMM (down_proj) ---") + from cutedsl.bridge import ( + quantize_to_nvfp4, run_nvfp4_grouped_gemm, + assemble_scales_3d_side, make_b_k_major, + ) + + num_experts = len(expert_indices) + + l2_mat_b = torch.stack(weights['l2_fp4']) + l2_scale_b = torch.stack(weights['l2_sf']) + l2_gsb = torch.stack(weights['l2_gs']) + + l2_mat_b = make_b_k_major(l2_mat_b) + l2_scale_b = assemble_scales_3d_side(l2_scale_b) + + l2_x_fp4, l2_x_sf, l2_gs_dyn = quantize_to_nvfp4(activated) + l2_gsa = torch.full((num_experts,), l2_gs_dyn, dtype=torch.float32, device=DEVICE) + l2_scale_a = assemble_scales_3d_side(l2_x_sf, expert_offsets[:num_experts+1], num_experts) + + l2_out = run_nvfp4_grouped_gemm( + mat_a=l2_x_fp4, mat_b=l2_mat_b, + scale_a=l2_scale_a, scale_b=l2_scale_b, + expert_offsets=expert_offsets[1:], + global_scale_a=l2_gsa, global_scale_b=l2_gsb, + ) + print(f" L2 gs (dynamic): {l2_gs_dyn:.6f}") + print(f" L2 out: shape={l2_out.shape} amax={l2_out.amax().item():.4f} mean={l2_out.mean().item():.4f}") + + # BF16 reference for first expert + if expert_offsets[1] > 0: + e = expert_indices[0] + start = 0 + end = expert_offsets[1].item() down_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight" if down_key in nvfp4_tensors: down_w = nvfp4_tensors[down_key].to(DEVICE) down_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE) down_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"].item() - down_bf16 = dequantize_nvfp4_weight(down_w, down_sf, down_gs) # (hidden, intermediate) - l2_out = activated @ down_bf16.T # (T, H) - else: - l2_out = activated[:, :HIDDEN_SIZE] - - l2_out_all.append((start, end, l2_out)) - - # Scatter-add - weighted = l2_out * sorted_weights[start:end].unsqueeze(1).to(l2_out.dtype) - output.scatter_add_(0, sorted_token_ids[start:end].unsqueeze(1).expand(-1, HIDDEN_SIZE), weighted) - - intermediates['l1_out_all'] = l1_out_all - intermediates['activated_all'] = activated_all - intermediates['l2_out_all'] = l2_out_all - intermediates['output'] = output - - return intermediates + down_bf16 = dequantize_nvfp4_weight(down_w, down_sf, down_gs) + + # Need activated reference — use the one we computed in stage 3 + gate = l2_out[:end] # Will compare against runner's L2 + ref_l2 = activated[start:end] @ down_bf16.T + cos = F.cosine_similarity(ref_l2.flatten().unsqueeze(0), gate.flatten().unsqueeze(0)).item() + print(f" vs BF16 (expert 0): cosine={cos:.6f} ref_amax={ref_l2.amax().item():.4f} run_amax={gate.amax().item():.4f}") + + return l2_out, l2_gs_dyn -# ============================================================ -# Main test -# ============================================================ -def main(): - torch.cuda.set_device(0) - torch.manual_seed(42) +def test_stage5_scatter(l2_out, expert_offsets, sorted_token_ids, sorted_weights, num_tokens): + """Stage 5: Scatter-add with routing weights.""" + print("\n--- Stage 5: Scatter-add ---") + + output = torch.zeros(num_tokens, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) + weighted_out = l2_out * sorted_weights.unsqueeze(1).to(l2_out.dtype) + output.scatter_add_(0, sorted_token_ids.unsqueeze(1).expand(-1, HIDDEN_SIZE), weighted_out) + + print(f" Output: shape={output.shape} amax={output.amax().item():.4f} mean={output.mean().item():.4f}") + return output - print(f"=== Step-by-Step Pipeline Test ===") - print(f" Experts: {NUM_EXPERTS}, H={HIDDEN_SIZE}, I={INTERMEDIATE_SIZE}") - print(f" Tokens: {NUM_TOKENS}, top_k={TOP_K}, swiglu_limit={SWIGLU_LIMIT}") - print(f" Stages: sort={ENABLE_SORT} L1={ENABLE_L1_GEMM} swiglu={ENABLE_SWIGLU} L2={ENABLE_L2_GEMM} scatter={ENABLE_SCATTER} full={ENABLE_FULL_RUNNER}") - # Load real weights - print("\n[1/6] Loading checkpoint...") - nvfp4_tensors = load_layer_tensors(MODEL_PATH, LAYER_IDX) - print(f" {len(nvfp4_tensors)} tensors loaded") - - # Figure out expert indices for this rank - # layer 0 has experts 0-255, we use first NUM_EXPERTS - expert_indices = list(range(NUM_EXPERTS)) - print(f" Using experts: {expert_indices[:5]}... (first 5 of {NUM_EXPERTS})") - - print("\n[2/6] Preparing NVFP4 weights (direct view-cast)...") - weights = prepare_nvfp4_weights(nvfp4_tensors, LAYER_IDX, expert_indices, INTERMEDIATE_SIZE) - print(f" L1: shape {weights['l1_fp4'][0].shape} dtype {weights['l1_fp4'][0].dtype}") - print(f" L2: shape {weights['l2_fp4'][0].shape} dtype {weights['l2_fp4'][0].dtype}") - - # Create input - hidden_states = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) * 2.0 - - # Realistic top-k - topk_ids = torch.zeros(NUM_TOKENS, TOP_K, dtype=torch.int64, device=DEVICE) - for i in range(NUM_TOKENS): - experts_perm = torch.randperm(NUM_EXPERTS)[:TOP_K] - topk_ids[i] = experts_perm - topk_weights = torch.ones(NUM_TOKENS, TOP_K, dtype=torch.float32, device=DEVICE) / TOP_K - - # ---- Reference (BF16) ---- - print("\n[3/6] Running BF16 reference pipeline...") - ref = reference_moe_bf16(hidden_states, nvfp4_tensors, LAYER_IDX, expert_indices, topk_ids, topk_weights, SWIGLU_LIMIT) - print(f" L1 samples: {len(ref['l1_out_all'])} experts with tokens") - if ref['l1_out_all']: - _, _, l1 = ref['l1_out_all'][0] - print(f" L1 out[0]: amax={l1.amax().item():.4f} mean={l1.mean().item():.4f}") - if ref['activated_all']: - _, _, act = ref['activated_all'][0] - print(f" activated[0]: amax={act.amax().item():.4f} mean={act.mean().item():.4f}") - if ref['l2_out_all']: - _, _, l2 = ref['l2_out_all'][0] - print(f" L2 out[0]: amax={l2.amax().item():.4f} mean={l2.mean().item():.4f}") - print(f" Final: amax={ref['output'].amax().item():.4f} mean={ref['output'].mean().item():.4f}") - - # ---- CuTeDSL Runner ---- - print("\n[4/6] Creating CuTeDSL runner...") +def test_stage6_full_runner(hidden_states, topk_weights, topk_ids, weights, expert_indices): + """Stage 6: Full CuTeDSL runner end-to-end.""" + print("\n--- Stage 6: Full CuTeDSL Runner ---") from vllm.nvfp4_cutedsl import CuTeDSLMoERunner - + runner = CuTeDSLMoERunner( num_experts=NUM_EXPERTS, hidden_size=HIDDEN_SIZE, intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=NUM_TOKENS, @@ -306,38 +351,150 @@ def main(): runner.l2_sf = weights['l2_sf'] runner.l2_gs = weights['l2_gs'] runner.set_swiglu_limit(SWIGLU_LIMIT) - - print("\n[5/6] Running CuTeDSL runner (with warmup gs)...") + with torch.no_grad(): runner.compute_activation_global_scales(hidden_states, topk_weights, topk_ids) print(f" Warmup gs: L1={runner._l1_activation_global_scale:.6f} L2={runner._l2_activation_global_scale:.6f}") runner_out = runner.run(hidden_states, topk_weights, topk_ids) - + print(f" Runner: amax={runner_out.amax().item():.4f} mean={runner_out.mean().item():.4f}") print(f" NaN: {torch.isnan(runner_out).any().item()} Inf: {torch.isinf(runner_out).any().item()}") + return runner_out - # ---- Comparison ---- - print("\n[6/6] Comparing runner vs BF16 reference...") - ref_out = ref['output'] - cos = F.cosine_similarity(ref_out.flatten().unsqueeze(0), runner_out.flatten().unsqueeze(0)).item() - mse = (ref_out - runner_out).pow(2).mean().item() - print(f" Cosine: {cos:.6f} MSE: {mse:.6e}") - # Per-token - low_cos = 0 +# ============================================================ +# Main +# ============================================================ +def main(): + torch.cuda.set_device(0) + torch.manual_seed(42) + + print(f"=== Step-by-Step Pipeline Test ===") + print(f" Experts: {NUM_EXPERTS}, H={HIDDEN_SIZE}, I={INTERMEDIATE_SIZE}") + print(f" Tokens: {NUM_TOKENS}, top_k={TOP_K}, swiglu_limit={SWIGLU_LIMIT}") + print(f" Stages: {STAGE_START}-{STAGE_END}") + + # Load weights + print("\nLoading checkpoint...") + nvfp4_tensors = load_layer_tensors(MODEL_PATH, LAYER_IDX) + print(f" {len(nvfp4_tensors)} tensors loaded") + + expert_indices = list(range(NUM_EXPERTS)) + + print("Preparing NVFP4 weights...") + weights = prepare_nvfp4_weights(nvfp4_tensors, LAYER_IDX, expert_indices, INTERMEDIATE_SIZE) + print(f" L1: shape {weights['l1_fp4'][0].shape}") + print(f" L2: shape {weights['l2_fp4'][0].shape}") + + # Create input + hidden_states = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) * 2.0 + topk_ids = torch.zeros(NUM_TOKENS, TOP_K, dtype=torch.int64, device=DEVICE) for i in range(NUM_TOKENS): - cos_i = F.cosine_similarity(ref_out[i].unsqueeze(0), runner_out[i].unsqueeze(0)).item() - if cos_i < 0.95: - low_cos += 1 - if low_cos <= 5: - print(f" Token {i}: cosine={cos_i:.4f}") + experts_perm = torch.randperm(NUM_EXPERTS)[:TOP_K] + topk_ids[i] = experts_perm + topk_weights = torch.ones(NUM_TOKENS, TOP_K, dtype=torch.float32, device=DEVICE) / TOP_K - if cos >= 0.98: - print(f"\n✅ PASS: cosine {cos:.6f}") - elif cos >= 0.90: - print(f"\n⚠️ MARGINAL: cosine {cos:.6f}") - else: - print(f"\n❌ FAIL: cosine {cos:.6f}") + # Run stages + sort_data = None + l1_out = None + activated = None + l2_out = None + pipeline_out = None + + if STAGE_START <= 1: + sort_data = test_stage1_sort(hidden_states, topk_ids, topk_weights, expert_indices) + + if STAGE_START <= 2 and STAGE_END >= 2: + if sort_data is None: + print("\n[Stage 2] Skipped (need stage 1 data)") + else: + l1_out, ref_l1_parts, l1_gs_dyn = test_stage2_l1_gemm( + sort_data['slot_hidden'], sort_data['expert_offsets'], + nvfp4_tensors, LAYER_IDX, expert_indices, weights + ) + + if STAGE_START <= 3 and STAGE_END >= 3: + if l1_out is None: + print("\n[Stage 3] Skipped (need stage 2 data)") + else: + activated = test_stage3_swiglu(l1_out, ref_l1_parts, SWIGLU_LIMIT) + + if STAGE_START <= 4 and STAGE_END >= 4: + if activated is None or sort_data is None: + print("\n[Stage 4] Skipped (need stages 2-3 data)") + else: + l2_out, l2_gs_dyn = test_stage4_l2_gemm( + activated, sort_data['expert_offsets'], + nvfp4_tensors, LAYER_IDX, expert_indices, weights + ) + + if STAGE_START <= 5 and STAGE_END >= 5: + if l2_out is None or sort_data is None: + print("\n[Stage 5] Skipped (need stages 1-4 data)") + else: + pipeline_out = test_stage5_scatter( + l2_out, sort_data['expert_offsets'], + sort_data['sorted_token_ids'], sort_data['sorted_weights'], NUM_TOKENS + ) + + if STAGE_START <= 6 and STAGE_END >= 6: + runner_out = test_stage6_full_runner(hidden_states, topk_weights, topk_ids, weights, expert_indices) + + # Compare against pipeline reference + if pipeline_out is not None: + cos = F.cosine_similarity(pipeline_out.flatten().unsqueeze(0), runner_out.flatten().unsqueeze(0)).item() + print(f"\n Pipeline vs Runner: cosine={cos:.6f}") + + # Also compare against full BF16 reference + print("\n Full BF16 reference for comparison...") + ref_out = torch.zeros(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) + for i, e in enumerate(expert_indices): + gate_w = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.gate_proj.weight"].to(DEVICE) + up_w = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.up_proj.weight"].to(DEVICE) + gate_sf = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.gate_proj.weight_scale"].to(DEVICE) + up_sf = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.up_proj.weight_scale"].to(DEVICE) + gate_gs = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.gate_proj.weight_scale_2"].item() + up_gs = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.up_proj.weight_scale_2"].item() + gate_bf16 = dequantize_nvfp4_weight(gate_w, gate_sf, gate_gs) + up_bf16 = dequantize_nvfp4_weight(up_w, up_sf, up_gs) + + down_key = f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight" + if down_key in nvfp4_tensors: + down_w = nvfp4_tensors[down_key].to(DEVICE) + down_sf = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE) + down_gs = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight_scale_2"].item() + down_bf16 = dequantize_nvfp4_weight(down_w, down_sf, down_gs) + else: + down_bf16 = None + + for t in range(NUM_TOKENS): + for k in range(TOP_K): + eid = topk_ids[t, k].item() + if eid != i: + continue + w = topk_weights[t, k].item() + x = hidden_states[t] + gate = x @ gate_bf16.T + up = x @ up_bf16.T + gate_silu = F.silu(gate).clamp(max=SWIGLU_LIMIT) if SWIGLU_LIMIT else F.silu(gate) + if SWIGLU_LIMIT: + up = up.clamp(min=-SWIGLU_LIMIT, max=SWIGLU_LIMIT) + act = gate_silu * up + if down_bf16 is not None: + y = act @ down_bf16.T + else: + y = act[:HIDDEN_SIZE] + ref_out[t] += w * y + + cos = F.cosine_similarity(ref_out.flatten().unsqueeze(0), runner_out.flatten().unsqueeze(0)).item() + mse = (ref_out - runner_out).pow(2).mean().item() + print(f"\n Runner vs BF16: cosine={cos:.6f} MSE={mse:.6e}") + if cos >= 0.98: + print(f" ✅ PASS") + elif cos >= 0.90: + print(f" ⚠️ MARGINAL") + else: + print(f" ❌ FAIL") if __name__ == "__main__":