Set runner L1 and L2 weights before calling _ensure_stacked()
This commit is contained in:
@@ -172,6 +172,22 @@ def main():
|
||||
intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=NUM_TOKENS,
|
||||
top_k=TOP_K, device=DEVICE,
|
||||
)
|
||||
runner.l1_fp4 = l1_fp4; runner.l1_sf = l1_sf; runner.l1_gs = l1_gs_list
|
||||
# Also set the L2 (down_proj) weights on the runner; they are needed by _ensure_stacked()
|
||||
l2_fp4, l2_sf, l2_gs_list = [], [], []
|
||||
for e in expert_indices:
|
||||
dk = f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight"
|
||||
if dk in nvfp4_tensors:
|
||||
dw = nvfp4_tensors[dk].to(DEVICE)
|
||||
dsf = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE)
|
||||
dgs = nvfp4_tensors[f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight_scale_2"].item()
|
||||
l2_fp4.append(dw.view(torch.float4_e2m1fn_x2).permute(1,0).contiguous())
|
||||
l2_sf.append(dsf.permute(1,0).contiguous()); l2_gs_list.append(dgs)
|
||||
else:
|
||||
l2_fp4.append(torch.zeros(INTERMEDIATE_SIZE//2, HIDDEN_SIZE, dtype=torch.float4_e2m1fn_x2, device=DEVICE))
|
||||
l2_sf.append(torch.ones(INTERMEDIATE_SIZE//16, HIDDEN_SIZE, dtype=torch.float8_e4m3fn, device=DEVICE))
|
||||
l2_gs_list.append(1.0)
|
||||
runner.l2_fp4 = l2_fp4; runner.l2_sf = l2_sf; runner.l2_gs = l2_gs_list
|
||||
runner._ensure_stacked()
|
||||
# Just use the runner's scale assembly
|
||||
l1_gsa = torch.full((NUM_EXPERTS,), l1_gs, dtype=torch.float32, device=DEVICE)
|
||||
|
||||
Reference in New Issue
Block a user