diff --git a/tests/test_pipeline_real_weights.py b/tests/test_pipeline_real_weights.py index 71c2f76f..f0de728b 100644 --- a/tests/test_pipeline_real_weights.py +++ b/tests/test_pipeline_real_weights.py @@ -172,6 +172,22 @@ def main(): intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=NUM_TOKENS, top_k=TOP_K, device=DEVICE, ) + runner.l1_fp4 = l1_fp4; runner.l1_sf = l1_sf; runner.l1_gs = l1_gs_list + # Set L2 weights too (needed for _ensure_stacked) + l2_fp4, l2_sf, l2_gs_list = [], [], [] + for e in expert_indices: + dk = f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight" + if dk in nvfp4_tensors: + dw = nvfp4_tensors[dk].to(DEVICE) + dsf = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE) + dgs = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight_scale_2"].item() + l2_fp4.append(dw.view(torch.float4_e2m1fn_x2).permute(1,0).contiguous()) + l2_sf.append(dsf.permute(1,0).contiguous()); l2_gs_list.append(dgs) + else: + l2_fp4.append(torch.zeros(INTERMEDIATE_SIZE//2, HIDDEN_SIZE, dtype=torch.float4_e2m1fn_x2, device=DEVICE)) + l2_sf.append(torch.ones(INTERMEDIATE_SIZE//16, HIDDEN_SIZE, dtype=torch.float8_e4m3fn, device=DEVICE)) + l2_gs_list.append(1.0) + runner.l2_fp4 = l2_fp4; runner.l2_sf = l2_sf; runner.l2_gs = l2_gs_list runner._ensure_stacked() # Just use the runner's scale assembly l1_gsa = torch.full((NUM_EXPERTS,), l1_gs, dtype=torch.float32, device=DEVICE)