Set runner weights before _ensure_stacked

This commit is contained in:
2026-05-17 21:22:50 +00:00
parent b7acac5e4e
commit 4d0b6d889d

View File

@@ -172,6 +172,22 @@ def main():
intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=NUM_TOKENS,
top_k=TOP_K, device=DEVICE,
)
runner.l1_fp4 = l1_fp4; runner.l1_sf = l1_sf; runner.l1_gs = l1_gs_list
# Set L2 weights too (needed for _ensure_stacked)
l2_fp4, l2_sf, l2_gs_list = [], [], []
for e in expert_indices:
dk = f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight"
if dk in nvfp4_tensors:
dw = nvfp4_tensors[dk].to(DEVICE)
dsf = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE)
dgs = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight_scale_2"].item()
l2_fp4.append(dw.view(torch.float4_e2m1fn_x2).permute(1,0).contiguous())
l2_sf.append(dsf.permute(1,0).contiguous()); l2_gs_list.append(dgs)
else:
l2_fp4.append(torch.zeros(INTERMEDIATE_SIZE//2, HIDDEN_SIZE, dtype=torch.float4_e2m1fn_x2, device=DEVICE))
l2_sf.append(torch.ones(INTERMEDIATE_SIZE//16, HIDDEN_SIZE, dtype=torch.float8_e4m3fn, device=DEVICE))
l2_gs_list.append(1.0)
runner.l2_fp4 = l2_fp4; runner.l2_sf = l2_sf; runner.l2_gs = l2_gs_list
runner._ensure_stacked()
# Just use the runner's scale assembly
l1_gsa = torch.full((NUM_EXPERTS,), l1_gs, dtype=torch.float32, device=DEVICE)