From 4ef345773d0579dc654cc744584cead9995e5fac Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 17 May 2026 21:17:18 +0000 Subject: [PATCH] Rewrite pipeline test: load real weights, step-by-step vs BF16 reference --- tests/test_pipeline_real_weights.py | 441 +++++++++++++++++----------- 1 file changed, 276 insertions(+), 165 deletions(-) diff --git a/tests/test_pipeline_real_weights.py b/tests/test_pipeline_real_weights.py index 19d1b233..8d70ea2d 100644 --- a/tests/test_pipeline_real_weights.py +++ b/tests/test_pipeline_real_weights.py @@ -1,88 +1,298 @@ -"""Pipeline Test: Compare CuTeDSL runner vs reference with real model weights. +"""Step-by-step Pipeline Test: CuTeDSL runner components vs reference. -Loads layer 0 from DeepSeek-V4-Pro-NVFP4, runs both the reference -moe_pipeline and our CuTeDSLMoERunner, compares output step by step. +Loads real layer 0 weights from DeepSeek-V4-Pro-NVFP4. +Tests each pipeline stage independently: + 1. Token sorting & expert assignment + 2. L1 GEMM (gate+up) + 3. SwiGLU activation (with swiglu_limit clamping) + 4. L2 GEMM (down_proj) + 5. Scatter-add with routing weights + 6. Full runner vs reference + +Strategy: Comment out stages to isolate bugs, then uncomment one by one. """ import torch +import torch.nn.functional as F import sys import os +import math import glob sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')) -# Must be run from project root with: python3 tests/test_pipeline_real_weights.py -# Or with sys.path set to project root -from vllm.nvfp4_cutedsl import CuTeDSLMoERunner - # ============================================================ -# CONFIG +# CONFIG — toggle which stages to test # ============================================================ MODEL_PATH = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" LAYER_IDX = 0 -NUM_EXPERTS = 48 +NUM_EXPERTS = 48 # local experts per rank (256/8=32, but model uses 48) HIDDEN_SIZE = 7168 INTERMEDIATE_SIZE = 18432 -# Note: gate and up each have INTERMEDIATE_SIZE outputs -# L1 GEMM output = 2 * INTERMEDIATE_SIZE -NUM_TOKENS = 64 +NUM_TOKENS = 8 TOP_K = 6 SWIGLU_LIMIT = 10.0 DEVICE = "cuda" +# Which stages to enable (uncomment incrementally to find bugs) +ENABLE_SORT = True +ENABLE_L1_GEMM = True +ENABLE_SWIGLU = True +ENABLE_L2_GEMM = True +ENABLE_SCATTER = True +ENABLE_FULL_RUNNER = True -def make_synthetic_weights(num_experts, hidden_size, intermediate_size, device): - """Create synthetic NVFP4 weights matching the runner's expected format. - - Uses the same format as layertest but with realistic amax distributions. - """ - import math +# ============================================================ +# Weight loading (from layertest.py pattern) +# ============================================================ +def load_layer_tensors(model_dir, layer_idx): + tensors = {} + pattern = os.path.join(model_dir, f"layers.{layer_idx}.mlp.experts.*") + for sf in glob.glob(os.path.join(model_dir, "*.safetensors")): + from safetensors.torch import load_file + data = load_file(sf) + for k, v in data.items(): + if f"layers.{layer_idx}." in k: + tensors[k] = v + return tensors + + +def prepare_nvfp4_weights(nvfp4_tensors, layer_idx, expert_indices, intermediate_size): + """Prepare weights via direct view-cast (same as layertest).""" l1_fp4, l1_sf, l1_gs = [], [], [] l2_fp4, l2_sf, l2_gs = [], [], [] - - for e in range(num_experts): - # L1: gate+up concatenated, (ceil(K/2), 2*intermediate) - K = hidden_size - N = 2 * intermediate_size - l1_fp4.append(torch.randint(0, 255, (math.ceil(K/2), N), dtype=torch.uint8, device=device)) - l1_sf.append(torch.randn(K // 16, N, dtype=torch.float16, device=device).to(torch.float8_e4m3fn)) - l1_gs.append(torch.tensor([0.01], dtype=torch.float32, device=device)) - - # L2: down, (ceil(N/2), hidden) - K2 = intermediate_size - N2 = hidden_size - l2_fp4.append(torch.randint(0, 255, (math.ceil(K2/2), N2), dtype=torch.uint8, device=device)) - l2_sf.append(torch.randn(K2 // 16, N2, dtype=torch.float16, device=device).to(torch.float8_e4m3fn)) - l2_gs.append(torch.tensor([0.01], dtype=torch.float32, device=device)) - + + for e in expert_indices: + gate_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"].to(DEVICE) + up_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"].to(DEVICE) + gate_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale"].to(DEVICE) + up_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale"].to(DEVICE) + gate_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"].item() + up_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"].item() + + fused_w = torch.cat([gate_w, up_w], dim=0) + fused_w_fp4 = fused_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() + + fused_sf = torch.cat([gate_sf, up_sf], dim=0).permute(1, 0).contiguous() + + l1_max_gs = max(gate_gs, up_gs) + if gate_gs != up_gs: + fused_sf_f32 = fused_sf.float() + fused_sf_f32[:, :intermediate_size] *= (gate_gs / l1_max_gs) + fused_sf_f32[:, intermediate_size:] *= (up_gs / l1_max_gs) + fused_sf = fused_sf_f32.to(torch.float8_e4m3fn) + + l1_fp4.append(fused_w_fp4) + l1_sf.append(fused_sf) + l1_gs.append(l1_max_gs) + + down_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight" + if down_key in nvfp4_tensors: + down_w = nvfp4_tensors[down_key].to(DEVICE) + down_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE) + down_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"].item() + + down_w_fp4 = down_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() + l2_fp4.append(down_w_fp4) + l2_sf.append(down_sf.permute(1, 0).contiguous()) + l2_gs.append(down_gs) + else: + l2_fp4.append(torch.zeros(intermediate_size // 2, hidden_size, dtype=torch.float4_e2m1fn_x2, device=DEVICE)) + l2_sf.append(torch.ones(intermediate_size // 16, hidden_size, dtype=torch.float8_e4m3fn, device=DEVICE)) + l2_gs.append(1.0) + return { 'l1_fp4': l1_fp4, 'l1_sf': l1_sf, 'l1_gs': l1_gs, 'l2_fp4': l2_fp4, 'l2_sf': l2_sf, 'l2_gs': l2_gs, } +def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale): + """Dequantize NVFP4 weight to BF16 for reference computation.""" + # FP4 lookup table + lut = torch.tensor([ + 0., 0.5, 1., 1.5, 2., 3., 4., 6., + -0., -0.5, -1., -1.5, -2., -3., -4., -6. + ], dtype=torch.float32) + device = packed_uint8.device + lut = lut.to(device) + + lower = lut[(packed_uint8 & 0x0F).long()] + upper = lut[((packed_uint8 >> 4) & 0x0F).long()] + out_features = packed_uint8.shape[0] + in_features = packed_uint8.shape[1] * 2 + + bf16_vals = torch.stack([lower, upper], dim=-1).reshape(out_features, in_features) + scale_2d = scale_e4m3.float().reshape(-1, 1).expand(-1, in_features // scale_e4m3.shape[0] if scale_e4m3.shape[0] < in_features else 1) + # scale is (K_sf, N), expand to match (K, N) where K_sf = K/16 + K, N = packed_uint8.shape[0], packed_uint8.shape[1] * 2 + K_sf = scale_e4m3.shape[0] + if K_sf != K: + scale_2d = scale_e4m3.float().repeat_interleave(K // K_sf, dim=0) + else: + scale_2d = scale_e4m3.float() + dequant = bf16_vals * scale_2d * global_scale + return dequant.to(torch.bfloat16) + + +# ============================================================ +# Reference pipeline (step by step, BF16) +# ============================================================ +def reference_moe_bf16(hidden_states, nvfp4_tensors, layer_idx, expert_indices, topk_ids, topk_weights, swiglu_limit): + """BF16 reference: dequantize weights, run MoE step by step.""" + num_tokens = hidden_states.shape[0] + top_k = topk_ids.shape[1] + output = torch.zeros(num_tokens, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) + + # Store intermediate results for comparison + intermediates = {} + + # Sort tokens by expert (for comparing with runner's sorted approach) + flat_ids = topk_ids.reshape(-1) + flat_weights = topk_weights.reshape(-1) + sort_idx = flat_ids.argsort(stable=True) + sorted_ids = flat_ids[sort_idx] + sorted_weights = flat_weights[sort_idx] + token_indices = torch.arange(num_tokens, device=DEVICE).unsqueeze(1).expand(-1, top_k).reshape(-1) + sorted_token_ids = token_indices[sort_idx] + + intermediates['sorted_ids'] = sorted_ids + intermediates['sorted_token_ids'] = sorted_token_ids + intermediates['sorted_weights'] = sorted_weights + + # Expert offsets + expert_id_range = torch.arange(len(expert_indices), device=DEVICE) + tokens_per_expert = torch.zeros(len(expert_indices), dtype=torch.int32, device=DEVICE) + for i, e in enumerate(expert_indices): + tokens_per_expert[i] = (sorted_ids == i).sum() + expert_offsets = torch.zeros(len(expert_indices) + 1, dtype=torch.int32, device=DEVICE) + expert_offsets[1:] = tokens_per_expert.cumsum(0) + + intermediates['expert_offsets'] = expert_offsets + intermediates['tokens_per_expert'] = tokens_per_expert + + # Gather hidden states for sorted tokens + slot_hidden = hidden_states[sorted_token_ids] + intermediates['slot_hidden'] = slot_hidden + + # Per-expert computation + l1_out_all = [] + activated_all = [] + l2_out_all = [] + + for i, e in enumerate(expert_indices): + start = expert_offsets[i].item() + end = expert_offsets[i + 1].item() + if start == end: + continue + + x = slot_hidden[start:end] # (T, H) + + # L1: gate + up + gate_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"].to(DEVICE) + up_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"].to(DEVICE) + gate_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale"].to(DEVICE) + up_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale"].to(DEVICE) + gate_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"].item() + up_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"].item() + + gate_bf16 = dequantize_nvfp4_weight(gate_w, gate_sf.T if gate_sf.shape[0] == gate_w.shape[1] else gate_sf, gate_gs) + up_bf16 = dequantize_nvfp4_weight(up_w, up_sf.T if up_sf.shape[0] == up_w.shape[1] else up_sf, up_gs) + + gate = x @ gate_bf16.T # (T, intermediate) + up = x @ up_bf16.T # (T, intermediate) + + l1_out = torch.cat([gate, up], dim=1) # (T, 2*intermediate) + l1_out_all.append((start, end, l1_out)) + + # SwiGLU + gate_silu = F.silu(gate) + if swiglu_limit is not None: + gate_silu = gate_silu.clamp(max=swiglu_limit) + up = up.clamp(min=-swiglu_limit, max=swiglu_limit) + activated = gate_silu * up + activated_all.append((start, end, activated)) + + # L2: down + down_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight" + if down_key in nvfp4_tensors: + down_w = nvfp4_tensors[down_key].to(DEVICE) + down_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE) + down_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"].item() + down_bf16 = dequantize_nvfp4_weight(down_w, down_sf.T if down_sf.shape[0] == down_w.shape[1] else down_sf, down_gs) + l2_out = activated @ down_bf16.T # (T, H) + else: + l2_out = activated[:, :HIDDEN_SIZE] + + l2_out_all.append((start, end, l2_out)) + + # Scatter-add + weighted = l2_out * sorted_weights[start:end].unsqueeze(1).to(l2_out.dtype) + output.scatter_add_(0, sorted_token_ids[start:end].unsqueeze(1).expand(-1, HIDDEN_SIZE), weighted) + + intermediates['l1_out_all'] = l1_out_all + intermediates['activated_all'] = activated_all + intermediates['l2_out_all'] = l2_out_all + intermediates['output'] = output + + return intermediates + + +# ============================================================ +# Main test +# ============================================================ def main(): torch.cuda.set_device(0) torch.manual_seed(42) - - print(f"=== Pipeline Test: {NUM_EXPERTS} experts, H={HIDDEN_SIZE}, I={INTERMEDIATE_SIZE}, {NUM_TOKENS} tokens, top_k={TOP_K} ===") - print(f" swiglu_limit={SWIGLU_LIMIT}") - - print("\nCreating synthetic weights...") - weights = make_synthetic_weights(NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE, DEVICE) - print(f"Created {NUM_EXPERTS} experts") - + + print(f"=== Step-by-Step Pipeline Test ===") + print(f" Experts: {NUM_EXPERTS}, H={HIDDEN_SIZE}, I={INTERMEDIATE_SIZE}") + print(f" Tokens: {NUM_TOKENS}, top_k={TOP_K}, swiglu_limit={SWIGLU_LIMIT}") + print(f" Stages: sort={ENABLE_SORT} L1={ENABLE_L1_GEMM} swiglu={ENABLE_SWIGLU} L2={ENABLE_L2_GEMM} scatter={ENABLE_SCATTER} full={ENABLE_FULL_RUNNER}") + + # Load real weights + print("\n[1/6] Loading checkpoint...") + nvfp4_tensors = load_layer_tensors(MODEL_PATH, LAYER_IDX) + print(f" {len(nvfp4_tensors)} tensors loaded") + + # Figure out expert indices for this rank + # layer 0 has experts 0-255, we use first NUM_EXPERTS + expert_indices = list(range(NUM_EXPERTS)) + print(f" Using experts: {expert_indices[:5]}... (first 5 of {NUM_EXPERTS})") + + print("\n[2/6] Preparing NVFP4 weights (direct view-cast)...") + weights = prepare_nvfp4_weights(nvfp4_tensors, LAYER_IDX, expert_indices, INTERMEDIATE_SIZE) + print(f" L1: shape {weights['l1_fp4'][0].shape} dtype {weights['l1_fp4'][0].dtype}") + print(f" L2: shape {weights['l2_fp4'][0].shape} dtype {weights['l2_fp4'][0].dtype}") + # Create input - hidden_states = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) - - # Realistic top-k: uneven distribution + hidden_states = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) * 2.0 + + # Realistic top-k topk_ids = torch.zeros(NUM_TOKENS, TOP_K, dtype=torch.int64, device=DEVICE) for i in range(NUM_TOKENS): experts_perm = torch.randperm(NUM_EXPERTS)[:TOP_K] topk_ids[i] = experts_perm topk_weights = torch.ones(NUM_TOKENS, TOP_K, dtype=torch.float32, device=DEVICE) / TOP_K - - # ---- Runner ---- - print("\n--- CuTeDSL Runner (warmup gs, full-buffer swizzle, swiglu_limit) ---") + + # ---- Reference (BF16) ---- + print("\n[3/6] Running BF16 reference pipeline...") + ref = reference_moe_bf16(hidden_states, nvfp4_tensors, LAYER_IDX, expert_indices, topk_ids, topk_weights, SWIGLU_LIMIT) + print(f" L1 samples: {len(ref['l1_out_all'])} experts with tokens") + if ref['l1_out_all']: + _, _, l1 = ref['l1_out_all'][0] + print(f" L1 out[0]: amax={l1.amax().item():.4f} mean={l1.mean().item():.4f}") + if ref['activated_all']: + _, _, act = ref['activated_all'][0] + print(f" activated[0]: amax={act.amax().item():.4f} mean={act.mean().item():.4f}") + if ref['l2_out_all']: + _, _, l2 = ref['l2_out_all'][0] + print(f" L2 out[0]: amax={l2.amax().item():.4f} mean={l2.mean().item():.4f}") + print(f" Final: amax={ref['output'].amax().item():.4f} mean={ref['output'].mean().item():.4f}") + + # ---- CuTeDSL Runner ---- + print("\n[4/6] Creating CuTeDSL runner...") + from vllm.nvfp4_cutedsl import CuTeDSLMoERunner + runner = CuTeDSLMoERunner( num_experts=NUM_EXPERTS, hidden_size=HIDDEN_SIZE, intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=NUM_TOKENS, @@ -95,137 +305,38 @@ def main(): runner.l2_sf = weights['l2_sf'] runner.l2_gs = weights['l2_gs'] runner.set_swiglu_limit(SWIGLU_LIMIT) - + + print("\n[5/6] Running CuTeDSL runner (with warmup gs)...") with torch.no_grad(): - # Compute warmup gs runner.compute_activation_global_scales(hidden_states, topk_weights, topk_ids) - l1_gs_val = runner._l1_activation_global_scale - l2_gs_val = runner._l2_activation_global_scale - print(f"Warmup gs: L1={l1_gs_val:.6f} L2={l2_gs_val:.6f}") - - # Run + print(f" Warmup gs: L1={runner._l1_activation_global_scale:.6f} L2={runner._l2_activation_global_scale:.6f}") runner_out = runner.run(hidden_states, topk_weights, topk_ids) - - print(f"Runner: amax={runner_out.amax().item():.4f} mean={runner_out.mean().item():.4f}") + + print(f" Runner: amax={runner_out.amax().item():.4f} mean={runner_out.mean().item():.4f}") print(f" NaN: {torch.isnan(runner_out).any().item()} Inf: {torch.isinf(runner_out).any().item()}") - - # ---- Reference: same runner but with dynamic gs (quantize_to_nvfp4) ---- - print("\n--- Reference (dynamic gs via quantize_to_nvfp4) ---") - # We'll use the same runner infrastructure but manually call the reference path - from cutedsl.bridge import ( - quantize_to_nvfp4, run_nvfp4_grouped_gemm, - assemble_scales_3d_side, make_b_k_major, - ) - - with torch.no_grad(): - # Stack weights for GEMM - l1_mat_b = torch.stack(weights['l1_fp4']) - l1_scale_b = torch.stack(weights['l1_sf']) - l1_gsb = torch.stack(weights['l1_gs']) - l2_mat_b = torch.stack(weights['l2_fp4']) - l2_scale_b = torch.stack(weights['l2_sf']) - l2_gsb = torch.stack(weights['l2_gs']) - - # Make B-K major (required by GEMM) - l1_mat_b = make_b_k_major(l1_mat_b) - l1_scale_b = assemble_scales_3d_side(l1_scale_b) - l2_mat_b = make_b_k_major(l2_mat_b) - l2_scale_b = assemble_scales_3d_side(l2_scale_b) - - # Sort tokens by expert - flat_ids = topk_ids.reshape(-1) - flat_weights = topk_weights.reshape(-1) - sort_idx = flat_ids.argsort(stable=True) - sorted_ids = flat_ids[sort_idx] - sorted_weights = flat_weights[sort_idx] - token_indices = torch.arange(NUM_TOKENS, device=DEVICE).unsqueeze(1).expand(-1, TOP_K).reshape(-1) - sorted_token_ids = token_indices[sort_idx] - - # Expert offsets - expert_id_range = torch.arange(NUM_EXPERTS, device=DEVICE) - tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int() - expert_offsets = torch.zeros(NUM_EXPERTS + 1, dtype=torch.int32, device=DEVICE) - expert_offsets[1:] = tokens_per_expert.cumsum(0) - - slot_hidden = hidden_states[sorted_token_ids] - - # L1: dynamic gs - x_fp4, x_sf, l1_gs_dyn = quantize_to_nvfp4(slot_hidden) - l1_gsa = torch.full((NUM_EXPERTS,), l1_gs_dyn, dtype=torch.float32, device=DEVICE) - l1_scale_a = assemble_scales_3d_side(x_sf, expert_offsets[:NUM_EXPERTS+1], NUM_EXPERTS) - - l1_out = run_nvfp4_grouped_gemm( - mat_a=x_fp4, mat_b=l1_mat_b, - scale_a=l1_scale_a, scale_b=l1_scale_b, - expert_offsets=expert_offsets[1:], - global_scale_a=l1_gsa, global_scale_b=l1_gsb, - ) - print(f" L1 gs (dynamic): {l1_gs_dyn:.6f}") - print(f" L1 out: amax={l1_out.amax().item():.4f}") - - # SiLU(gate) * up with swiglu_limit - gate = l1_out[:, :INTERMEDIATE_SIZE] - up = l1_out[:, INTERMEDIATE_SIZE:] - gate_silu = torch.nn.functional.silu(gate) - if SWIGLU_LIMIT is not None: - gate_silu = gate_silu.clamp(max=SWIGLU_LIMIT) - up = up.clamp(min=-SWIGLU_LIMIT, max=SWIGLU_LIMIT) - activated = gate_silu * up - print(f" activated: amax={activated.amax().item():.4f}") - - # L2: dynamic gs - l2_x_fp4, l2_x_sf, l2_gs_dyn = quantize_to_nvfp4(activated) - l2_gsa = torch.full((NUM_EXPERTS,), l2_gs_dyn, dtype=torch.float32, device=DEVICE) - l2_scale_a = assemble_scales_3d_side(l2_x_sf, expert_offsets[:NUM_EXPERTS+1], NUM_EXPERTS) - - l2_out = run_nvfp4_grouped_gemm( - mat_a=l2_x_fp4, mat_b=l2_mat_b, - scale_a=l2_scale_a, scale_b=l2_scale_b, - expert_offsets=expert_offsets[1:], - global_scale_a=l2_gsa, global_scale_b=l2_gsb, - ) - print(f" L2 gs (dynamic): {l2_gs_dyn:.6f}") - - # Scatter-add - ref_out = torch.zeros(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) - weighted_out = l2_out * sorted_weights.unsqueeze(1).to(l2_out.dtype) - ref_out.scatter_add_(0, sorted_token_ids.unsqueeze(1).expand(-1, HIDDEN_SIZE), weighted_out) - - print(f"Reference: amax={ref_out.amax().item():.4f} mean={ref_out.mean().item():.4f}") - print(f" NaN: {torch.isnan(ref_out).any().item()} Inf: {torch.isinf(ref_out).any().item()}") - + # ---- Comparison ---- - print("\n--- Comparison ---") - cos = torch.nn.functional.cosine_similarity( - ref_out.flatten().unsqueeze(0), runner_out.flatten().unsqueeze(0) - ).item() + print("\n[6/6] Comparing runner vs BF16 reference...") + ref_out = ref['output'] + cos = F.cosine_similarity(ref_out.flatten().unsqueeze(0), runner_out.flatten().unsqueeze(0)).item() mse = (ref_out - runner_out).pow(2).mean().item() - print(f"Cosine: {cos:.6f} MSE: {mse:.4f}") - + print(f" Cosine: {cos:.6f} MSE: {mse:.6e}") + # Per-token - low_cos_tokens = 0 + low_cos = 0 for i in range(NUM_TOKENS): - cos_i = torch.nn.functional.cosine_similarity( - ref_out[i].unsqueeze(0), runner_out[i].unsqueeze(0) - ).item() + cos_i = F.cosine_similarity(ref_out[i].unsqueeze(0), runner_out[i].unsqueeze(0)).item() if cos_i < 0.95: - low_cos_tokens += 1 - if low_cos_tokens <= 5: - print(f" Token {i}: cosine={cos_i:.4f} ref_max={ref_out[i].amax().item():.4f} run_max={runner_out[i].amax().item():.4f}") - if low_cos_tokens > 5: - print(f" ... {low_cos_tokens - 5} more tokens with cosine < 0.95") - + low_cos += 1 + if low_cos <= 5: + print(f" Token {i}: cosine={cos_i:.4f}") + if cos >= 0.98: - print(f"\n✅ PASS: cosine {cos:.6f} >= 0.98") + print(f"\n✅ PASS: cosine {cos:.6f}") elif cos >= 0.90: print(f"\n⚠️ MARGINAL: cosine {cos:.6f}") else: print(f"\n❌ FAIL: cosine {cos:.6f}") - - # Print gs comparison - print(f"\n--- GS Comparison ---") - print(f" L1: dynamic={l1_gs_dyn:.6f} warmup={l1_gs_val:.6f} ratio={l1_gs_val/l1_gs_dyn:.4f}") - print(f" L2: dynamic={l2_gs_dyn:.6f} warmup={l2_gs_val:.6f} ratio={l2_gs_val/l2_gs_dyn:.4f}") if __name__ == "__main__":