diff --git a/tests/test_pipeline_real_weights.py b/tests/test_pipeline_real_weights.py index aaba97d7..30687b8d 100644 --- a/tests/test_pipeline_real_weights.py +++ b/tests/test_pipeline_real_weights.py @@ -1,6 +1,4 @@ -"""Debug test: Replicate runner logic step by step in Python. -Compare against BF16 reference to isolate where tokens get dropped. -""" +"""Full pipeline test: Fixed runner vs BF16 reference.""" import torch import torch.nn.functional as F import sys, os, glob @@ -45,8 +43,9 @@ def main(): torch.cuda.set_device(0) torch.manual_seed(42) - print("=== Runner Logic Debug ===") + print("=== Full Pipeline Test (Fixed Runner) ===") nvfp4_tensors = load_layer_tensors(MODEL_PATH, LAYER_IDX) + expert_indices = list(range(NUM_EXPERTS)) hidden_states = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) * 2.0 topk_ids = torch.zeros(NUM_TOKENS, TOP_K, dtype=torch.int64, device=DEVICE) @@ -54,89 +53,48 @@ def main(): topk_ids[i] = torch.randperm(NUM_EXPERTS)[:TOP_K] topk_weights = torch.ones(NUM_TOKENS, TOP_K, dtype=torch.float32, device=DEVICE) / TOP_K - # Step 1: Global→local remap (same as runner) - experts_start_idx = 0 - local_ids = topk_ids - experts_start_idx - local_mask = (local_ids >= 0) & (local_ids < NUM_EXPERTS) - safe_ids = local_ids.clamp(0, NUM_EXPERTS - 1) - safe_weights = topk_weights * local_mask.float() + # BF16 reference + ref_out = torch.zeros(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) + for i, e in enumerate(expert_indices): + dk = f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight" + gk = f"layers.{LAYER_IDX}.mlp.experts.{e}.gate_proj.weight" + uk = f"layers.{LAYER_IDX}.mlp.experts.{e}.up_proj.weight" + if dk not in nvfp4_tensors: + continue + gate_bf16 = dequantize_nvfp4_weight( + nvfp4_tensors[gk].to(DEVICE), + nvfp4_tensors[gk.replace('.weight', '.weight_scale')].to(DEVICE), + nvfp4_tensors[gk.replace('.weight', '.weight_scale_2')].item()) + up_bf16 = dequantize_nvfp4_weight( + nvfp4_tensors[uk].to(DEVICE), + nvfp4_tensors[uk.replace('.weight', '.weight_scale')].to(DEVICE), + nvfp4_tensors[uk.replace('.weight', '.weight_scale_2')].item()) + down_bf16 = dequantize_nvfp4_weight( + nvfp4_tensors[dk].to(DEVICE), + nvfp4_tensors[dk.replace('.weight', '.weight_scale')].to(DEVICE), + nvfp4_tensors[dk.replace('.weight', '.weight_scale_2')].item()) + + for t in range(NUM_TOKENS): + for k in range(TOP_K): + if topk_ids[t, k].item() != i: + continue + w = topk_weights[t, k].item() + x = hidden_states[t] + gate = x @ gate_bf16.T + up = x @ up_bf16.T + gate_silu = F.silu(gate).clamp(max=SWIGLU_LIMIT) + up = up.clamp(min=-SWIGLU_LIMIT, max=SWIGLU_LIMIT) + act = gate_silu * up + ref_out[t] += w * (act @ down_bf16.T) - print(f"topk_ids:\n{topk_ids}") - print(f"local_ids:\n{local_ids}") - print(f"local_mask:\n{local_mask}") - print(f"safe_weights (should all be 0.1667):\n{safe_weights}") + print(f"BF16 ref: amax={ref_out.amax().item():.4f}") - # Step 2: Sort by expert - flat_ids = safe_ids.reshape(-1) - flat_weights = safe_weights.reshape(-1) - num_slots = NUM_TOKENS * TOP_K - token_indices = torch.arange(NUM_TOKENS, device=DEVICE).unsqueeze(1).expand(-1, TOP_K).reshape(-1) + # CuTeDSL runner + from vllm.nvfp4_cutedsl import CuTeDSLMoERunner + from cutedsl.bridge import assemble_scales_3d_side, make_b_k_major - sort_idx = flat_ids.argsort(stable=True) - sorted_ids = flat_ids[sort_idx] - sorted_weights = flat_weights[sort_idx] - sorted_token_ids = token_indices[sort_idx] - - print(f"\nsorted_ids: {sorted_ids.tolist()}") - print(f"sorted_token_ids: {sorted_token_ids.tolist()}") - print(f"sorted_weights: {sorted_weights.tolist()}") - - # Step 3: Expert offsets - expert_id_range = torch.arange(NUM_EXPERTS, device=DEVICE) - tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int() - expert_offsets = torch.zeros(NUM_EXPERTS + 1, dtype=torch.int32, device=DEVICE) - expert_offsets[1:] = tokens_per_expert.cumsum(0) - - print(f"\ntokens_per_expert: {tokens_per_expert.tolist()}") - print(f"expert_offsets: {expert_offsets.tolist()}") - - # Step 4: Padded offsets - padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128 - padded_expert_offsets = torch.zeros(NUM_EXPERTS + 1, dtype=torch.int32, device=DEVICE) - padded_expert_offsets[1:] = padded_tokens_per_expert.cumsum(0) - total_padded = padded_expert_offsets[NUM_EXPERTS].item() - - print(f"padded_tokens_per_expert: {padded_tokens_per_expert.tolist()}") - print(f"padded_expert_offsets: {padded_expert_offsets.tolist()}") - print(f"total_padded: {total_padded}") - - # Step 5: Scatter into padded layout (runner's searchsorted approach) - row_indices = torch.arange(num_slots, device=DEVICE) - expert_assign = torch.searchsorted(expert_offsets[1:], row_indices, right=True).clamp(max=NUM_EXPERTS - 1) - local_row = row_indices - expert_offsets[expert_assign] - padded_dst = padded_expert_offsets[expert_assign] + local_row - - print(f"\nexpert_assign: {expert_assign.tolist()}") - print(f"local_row: {local_row.tolist()}") - print(f"padded_dst: {padded_dst.tolist()}") - - # Verify: expert_assign should match sorted_ids - match = (expert_assign == sorted_ids).all().item() - print(f"expert_assign == sorted_ids: {match}") - if not match: - mismatches = (expert_assign != sorted_ids).nonzero().squeeze() - print(f" Mismatch at rows: {mismatches.tolist()}") - print(f" expert_assign[mismatch]: {expert_assign[mismatches].tolist()}") - print(f" sorted_ids[mismatch]: {sorted_ids[mismatches].tolist()}") - - # Step 6: Scatter hidden states - slot_hidden = hidden_states[sorted_token_ids] - padded_hidden = torch.zeros(total_padded, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) - padded_hidden[padded_dst] = slot_hidden - - # Verify: padded_hidden[padded_dst] should match slot_hidden - verify = (padded_hidden[padded_dst] == slot_hidden).all().item() - print(f"\npadded_hidden scatter correct: {verify}") - - # Step 7: Now run L1 GEMM using bridge (direct call, not runner) - from cutedsl.bridge import ( - quantize_to_nvfp4, run_nvfp4_grouped_gemm, - assemble_scales_3d_side, make_b_k_major, - ) - - # Prepare weights (same as runner's _ensure_stacked) - expert_indices = list(range(NUM_EXPERTS)) - l1_fp4, l1_sf, l1_gs_list = [], [], [] + l1_fp4, l1_sf, l1_gs = [], [], [] + l2_fp4, l2_sf, l2_gs = [], [], [] for e in expert_indices: gw = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.gate_proj.weight"].to(DEVICE) uw = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.up_proj.weight"].to(DEVICE) @@ -148,106 +106,54 @@ def main(): fsf = torch.cat([gsf, usf], dim=0).permute(1,0).contiguous() mgs = max(ggs, ugs) if ggs != ugs: - fsf32 = fsf.float() - fsf32[:, :INTERMEDIATE_SIZE] *= (ggs / mgs) - fsf32[:, INTERMEDIATE_SIZE:] *= (ugs / mgs) - fsf = fsf32.to(torch.float8_e4m3fn) - l1_fp4.append(fw); l1_sf.append(fsf); l1_gs_list.append(mgs) - - l1_mat_b = torch.stack(l1_fp4) - l1_mat_b = make_b_k_major(l1_mat_b) - l1_scale_b = assemble_scales_3d_side(l1_sf) - l1_gsb = torch.tensor(l1_gs_list, dtype=torch.float32, device=DEVICE) - - # Quantize activation (dynamic gs, not warmup) - # KEY FIX: quantize slot_hidden (sorted tokens), NOT padded_hidden. - # padded_hidden has zeros in padding rows; quantizing it gives wrong x_sf layout. - print("\n--- L1 GEMM (dynamic gs) ---") - slot_x_fp4, slot_x_sf, l1_gs = quantize_to_nvfp4(slot_hidden) - print(f" L1 gs (dynamic): {l1_gs:.6f}") - - # Scatter x_fp4 into padded layout (use uint8 for scatter, then view as float4) - padded_x_fp4_uint8 = torch.zeros(total_padded, HIDDEN_SIZE // 2, dtype=torch.uint8, device=DEVICE) - padded_x_fp4_uint8[padded_dst] = slot_x_fp4.view(torch.uint8) - padded_x_fp4 = padded_x_fp4_uint8.view(torch.float4_e2m1fn_x2) - - # For scale_a, we need to use the runner's assembly approach. - # Use the same _assemble_scales_cudagraph_safe function - from vllm.nvfp4_cutedsl import CuTeDSLMoERunner - runner = CuTeDSLMoERunner( - num_experts=NUM_EXPERTS, hidden_size=HIDDEN_SIZE, - intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=NUM_TOKENS, - top_k=TOP_K, device=DEVICE, - ) - runner.l1_fp4 = l1_fp4; runner.l1_sf = l1_sf; runner.l1_gs = l1_gs_list - # Set L2 weights too (needed for _ensure_stacked) - l2_fp4, l2_sf, l2_gs_list = [], [], [] - for e in expert_indices: + sf32 = fsf.float() + sf32[:, :INTERMEDIATE_SIZE] *= (ggs / mgs) + sf32[:, INTERMEDIATE_SIZE:] *= (ugs / mgs) + fsf = sf32.to(torch.float8_e4m3fn) + l1_fp4.append(fw); l1_sf.append(fsf); l1_gs.append(mgs) + dk = f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight" if dk in nvfp4_tensors: dw = nvfp4_tensors[dk].to(DEVICE) dsf = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE) dgs = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight_scale_2"].item() l2_fp4.append(dw.view(torch.float4_e2m1fn_x2).permute(1,0).contiguous()) - l2_sf.append(dsf.permute(1,0).contiguous()); l2_gs_list.append(dgs) + l2_sf.append(dsf.permute(1,0).contiguous()); l2_gs.append(dgs) else: l2_fp4.append(torch.zeros(INTERMEDIATE_SIZE//2, HIDDEN_SIZE, dtype=torch.float4_e2m1fn_x2, device=DEVICE)) l2_sf.append(torch.ones(INTERMEDIATE_SIZE//16, HIDDEN_SIZE, dtype=torch.float8_e4m3fn, device=DEVICE)) - l2_gs_list.append(1.0) - runner.l2_fp4 = l2_fp4; runner.l2_sf = l2_sf; runner.l2_gs = l2_gs_list - runner._ensure_stacked() - # Just use the runner's scale assembly - l1_gsa = torch.full((NUM_EXPERTS,), l1_gs, dtype=torch.float32, device=DEVICE) - l1_scale_a = runner._assemble_scales_cudagraph_safe( - slot_x_sf, expert_offsets[:NUM_EXPERTS+1], - padded_expert_offsets, - runner._padded_x_sf_buf_l1, runner._per_expert_scale_bufs_l1 + l2_gs.append(1.0) + + runner = CuTeDSLMoERunner( + num_experts=NUM_EXPERTS, hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=NUM_TOKENS, + top_k=TOP_K, device=DEVICE, ) + runner.l1_fp4 = l1_fp4; runner.l1_sf = l1_sf; runner.l1_gs = l1_gs + runner.l2_fp4 = l2_fp4; runner.l2_sf = l2_sf; runner.l2_gs = l2_gs + runner.set_swiglu_limit(SWIGLU_LIMIT) - l1_out = run_nvfp4_grouped_gemm( - mat_a=padded_x_fp4, mat_b=l1_mat_b, - scale_a=l1_scale_a, scale_b=l1_scale_b, - expert_offsets=padded_expert_offsets[1:NUM_EXPERTS+1], - global_scale_a=l1_gsa, global_scale_b=l1_gsb, - ) - print(f" L1 out: shape={l1_out.shape} amax={l1_out.amax().item():.4f}") - print(f" L1 out NaN: {torch.isnan(l1_out).any().item()}") + with torch.no_grad(): + runner.compute_activation_global_scales(hidden_states, topk_weights, topk_ids) + runner_out = runner.run(hidden_states, topk_weights, topk_ids) - # Extract real tokens - l1_out_real = l1_out[padded_dst] - print(f" L1 real: amax={l1_out_real.amax().item():.4f}") + print(f"Runner: amax={runner_out.amax().item():.4f}") + print(f"NaN: {torch.isnan(runner_out).any().item()}") - # BF16 reference L1 - ref_l1 = torch.zeros(num_slots, 2*INTERMEDIATE_SIZE, dtype=torch.bfloat16, device=DEVICE) - for i, e in enumerate(expert_indices): - start = expert_offsets[i].item() - end = expert_offsets[i+1].item() - if start == end: - continue - x = slot_hidden[start:end] - gw = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.gate_proj.weight"].to(DEVICE) - uw = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.up_proj.weight"].to(DEVICE) - gsf = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.gate_proj.weight_scale"].to(DEVICE) - usf = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.up_proj.weight_scale"].to(DEVICE) - ggs = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.gate_proj.weight_scale_2"].item() - ugs = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.up_proj.weight_scale_2"].item() - gate = x @ dequantize_nvfp4_weight(gw, gsf, ggs).T - up = x @ dequantize_nvfp4_weight(uw, usf, ugs).T - ref_l1[start:end] = torch.cat([gate, up], dim=1) + cos = F.cosine_similarity(ref_out.flatten().unsqueeze(0), runner_out.flatten().unsqueeze(0)).item() + mse = (ref_out - runner_out).pow(2).mean().item() + print(f"\nCosine: {cos:.6f} MSE: {mse:.6e}") - # Compare L1 - cos_l1 = F.cosine_similarity(ref_l1.flatten().unsqueeze(0), l1_out_real.flatten().unsqueeze(0)).item() - print(f"\n L1 cosine vs BF16: {cos_l1:.6f}") + for t in range(NUM_TOKENS): + ct = F.cosine_similarity(ref_out[t].unsqueeze(0), runner_out[t].unsqueeze(0)).item() + print(f" Token {t}: cosine={ct:.4f}") - # Per-expert L1 comparison - for i in list(range(NUM_EXPERTS))[:5]: - start = expert_offsets[i].item() - end = expert_offsets[i+1].item() - if start == end: - continue - c = F.cosine_similarity(ref_l1[start:end].flatten().unsqueeze(0), - l1_out_real[start:end].flatten().unsqueeze(0)).item() - print(f" Expert {i} L1: cosine={c:.6f} ref_amax={ref_l1[start:end].amax().item():.4f} run_amax={l1_out_real[start:end].amax().item():.4f}") + if cos >= 0.98: + print(f"\n✅ PASS") + elif cos >= 0.90: + print(f"\n⚠️ MARGINAL") + else: + print(f"\n❌ FAIL") if __name__ == "__main__":