From 4d0b6d889da0e02b78020d59a1dbba449cbe2c2d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 17 May 2026 21:22:50 +0000 Subject: [PATCH] Set runner weights before _ensure_stacked --- tests/test_pipeline_real_weights.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_pipeline_real_weights.py b/tests/test_pipeline_real_weights.py index 71c2f76f..f0de728b 100644 --- a/tests/test_pipeline_real_weights.py +++ b/tests/test_pipeline_real_weights.py @@ -172,6 +172,22 @@ def main(): intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=NUM_TOKENS, top_k=TOP_K, device=DEVICE, ) + runner.l1_fp4 = l1_fp4; runner.l1_sf = l1_sf; runner.l1_gs = l1_gs_list + # Set L2 weights too (needed for _ensure_stacked) + l2_fp4, l2_sf, l2_gs_list = [], [], [] + for e in expert_indices: + dk = f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight" + if dk in nvfp4_tensors: + dw = nvfp4_tensors[dk].to(DEVICE) + dsf = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE) + dgs = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight_scale_2"].item() + l2_fp4.append(dw.view(torch.float4_e2m1fn_x2).permute(1,0).contiguous()) + l2_sf.append(dsf.permute(1,0).contiguous()); l2_gs_list.append(dgs) + else: + l2_fp4.append(torch.zeros(INTERMEDIATE_SIZE//2, HIDDEN_SIZE, dtype=torch.float4_e2m1fn_x2, device=DEVICE)) + l2_sf.append(torch.ones(INTERMEDIATE_SIZE//16, HIDDEN_SIZE, dtype=torch.float8_e4m3fn, device=DEVICE)) + l2_gs_list.append(1.0) + runner.l2_fp4 = l2_fp4; runner.l2_sf = l2_sf; runner.l2_gs = l2_gs_list runner._ensure_stacked() # Just use the runner's scale assembly l1_gsa = torch.full((NUM_EXPERTS,), l1_gs, dtype=torch.float32, device=DEVICE)