From 0b35c36d23ca476dc24006e7c3b8ca7a7923d5b8 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 31 May 2026 23:16:45 +0000 Subject: [PATCH] single_shot: memory-efficient MoE loading, lazy Nvfp4Linear init - MoE expert weights loaded per-expert to GPU (no huge CPU tensors) - Nvfp4Linear finalize_weights deferred (lazy on first forward) - Shared expert weights loaded directly to GPU - Added GPU cache cleanup at start - Fixed shared expert finalize_weights (now lazy) --- single_shot_inference.py | 142 ++++++++++++++++++--------------------- 1 file changed, 65 insertions(+), 77 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index f9c4443b..1a61c82e 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -516,89 +516,69 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, # ===================================================================== def _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg): n_e = cfg["n_routed_experts"] - w0 = all_w.get(f"{pfx}.experts.0.gate_proj.weight") - if w0 is None: + # Load expert weights directly to GPU, one at a time + # to avoid creating huge CPU tensors + l1_fp4_list, l1_sf_list, l1_gs_list = [], [], [] + l2_fp4_list, l2_sf_list, l2_gs_list = [], [], [] + for eid in range(n_e): + ep = f"{pfx}.experts.{eid}" + # L1: gate + up + gw, gws, _, gisc = get_nvfp4_weight(all_w, ep, 'gate_proj') + uw, uws, _, uisc = get_nvfp4_weight(all_w, ep, 'up_proj') + if gw is not None and uw is not None: + # Stack gate and up along dim 0 → (2*N, K) uint8 + l1_fp4_list.append(torch.cat([gw, uw], dim=0).to(dev)) + if gws is not None and uws is not None: + l1_sf_list.append(torch.cat([gws, uws], dim=0).to(dev)) + gs = gisc.float().item() if gisc is not None else 1.0 / (6.0 * 448.0) + l1_gs_list.append(gs) + # L2: down + dw, dws, _, disc = get_nvfp4_weight(all_w, ep, 'down_proj') + if dw is not None: + l2_fp4_list.append(dw.to(dev)) + if dws is not None: + l2_sf_list.append(dws.to(dev)) + gs2 = disc.float().item() if disc is not None else 1.0 / (6.0 * 448.0) + l2_gs_list.append(gs2) + + if not l1_fp4_list: log.warning(f"L{li}: No expert weights found") return - gate_N, gate_K = w0.shape - l1_stacked = torch.zeros(n_e, 2 * gate_N, gate_K, dtype=w0.dtype) - l1_sf_stacked = None - l2_stacked = None - l2_sf_stacked = None - l1_gs = [] - l2_gs = [] + # Stack into (E, N, K) tensors on GPU + l1_stacked = torch.stack(l1_fp4_list).to(dev) + l1_sf_stacked = torch.stack(l1_sf_list).to(dev) if l1_sf_list else None + l2_stacked = torch.stack(l2_fp4_list).to(dev) if l2_fp4_list else None + l2_sf_stacked = torch.stack(l2_sf_list).to(dev) if l2_sf_list else None + del l1_fp4_list, l1_sf_list, l2_fp4_list, l2_sf_list - ws0 = all_w.get(f"{pfx}.experts.0.gate_proj.weight_scale") - if ws0 is not None: - sf_N, sf_K = ws0.shape - l1_sf_stacked = torch.zeros(n_e, 2 * sf_N, sf_K, dtype=ws0.dtype) - - dw0 = all_w.get(f"{pfx}.experts.0.down_proj.weight") - if dw0 is not None: - down_N, down_K = dw0.shape - l2_stacked = torch.zeros(n_e, down_N, down_K, dtype=dw0.dtype) - dws0 = all_w.get(f"{pfx}.experts.0.down_proj.weight_scale") - if dws0 is not None: - l2_sf_stacked = torch.zeros(n_e, dws0.shape[0], dws0.shape[1], dtype=dws0.dtype) - - for eid in range(n_e): - gw = all_w.get(f"{pfx}.experts.{eid}.gate_proj.weight") - gws = all_w.get(f"{pfx}.experts.{eid}.gate_proj.weight_scale") - gisc = all_w.get(f"{pfx}.experts.{eid}.gate_proj.input_scale") - uw = all_w.get(f"{pfx}.experts.{eid}.up_proj.weight") - uws = all_w.get(f"{pfx}.experts.{eid}.up_proj.weight_scale") - if gw is not None and uw is not None: - l1_stacked[eid, :gate_N] = gw - l1_stacked[eid, gate_N:] = uw - if gws is not None and uws is not None and l1_sf_stacked is not None: - l1_sf_stacked[eid, :sf_N] = gws - l1_sf_stacked[eid, sf_N:] = uws - l1_gs.append(gisc.float().item() if gisc is not None else 1.0 / (6.0 * 448.0)) - dw = all_w.get(f"{pfx}.experts.{eid}.down_proj.weight") - dws = all_w.get(f"{pfx}.experts.{eid}.down_proj.weight_scale") - disc = all_w.get(f"{pfx}.experts.{eid}.down_proj.input_scale") - if dw is not None: - l2_stacked[eid] = dw - if dws is not None and l2_sf_stacked is not None: - l2_sf_stacked[eid] = dws - l2_gs.append(disc.float().item() if disc is not None else 1.0 / (6.0 * 448.0)) - - l1_stacked = l1_stacked.to(dev) - l1_sf_stacked = l1_sf_stacked.to(dev) if l1_sf_stacked is not None else None - l2_stacked = l2_stacked.to(dev) if l2_stacked is not None else None - l2_sf_stacked = l2_sf_stacked.to(dev) if l2_sf_stacked is not None else None - l1_gs = l1_gs if l1_gs else [1.0 / (6.0 * 448.0)] * n_e - l2_gs = l2_gs if l2_gs else [1.0 / (6.0 * 448.0)] * n_e - moe.prepare_weights_from_stacked(l1_stacked, l1_sf_stacked, l1_gs, - l2_stacked, l2_sf_stacked, l2_gs) + moe.prepare_weights_from_stacked( + l1_stacked, l1_sf_stacked, l1_gs_list, + l2_stacked, l2_sf_stacked, l2_gs_list, + ) def _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg): - l1_gate_fp4, l1_gate_sf, l1_gate_gs = [], [], [] - l1_up_fp4, l1_up_sf = [], [] - l2_fp4, l2_sf, l2_gs = [], [], [] - for proj, fp4_l, sf_l, gs_l in [ - ('gate_proj', l1_gate_fp4, l1_gate_sf, l1_gate_gs), - ('up_proj', l1_up_fp4, l1_up_sf, None), - ('down_proj', l2_fp4, l2_sf, l2_gs), - ]: - w, ws, isc = all_w.get(f"{pfx}.shared_experts.{proj}.weight"), \ - all_w.get(f"{pfx}.shared_experts.{proj}.weight_scale"), \ - all_w.get(f"{pfx}.shared_experts.{proj}.input_scale") - if w is not None and ws is not None: - fp4_l.append(w.to(dev)) - sf_l.append(ws.to(dev)) - if gs_l is not None: - gs_l.append(isc.float().item() if isc is not None else 1.0 / (6.0 * 448.0)) - if l1_gate_fp4 and l1_up_fp4: - se.l1_fp4 = [torch.cat([l1_gate_fp4[0], l1_up_fp4[0]], dim=0)] - se.l1_sf = [torch.cat([l1_gate_sf[0], l1_up_sf[0]], dim=0)] - se.l1_gs = l1_gate_gs if l1_gate_gs else [1.0 / (6.0 * 448.0)] - if l2_fp4: - se.l2_fp4 = l2_fp4; se.l2_sf = l2_sf - se.l2_gs = l2_gs if l2_gs else [1.0 / (6.0 * 448.0)] - se.finalize_weights() + moe_inter = cfg.get('moe_intermediate_size', 3072) + # Shared expert: gate_proj + up_proj → L1, down_proj → L2 + gw, gws, _, gisc = get_nvfp4_weight(all_w, f"{pfx}.shared_experts", 'gate_proj') + uw, uws, _, uisc = get_nvfp4_weight(all_w, f"{pfx}.shared_experts", 'up_proj') + dw, dws, _, disc = get_nvfp4_weight(all_w, f"{pfx}.shared_experts", 'down_proj') + + if gw is not None and uw is not None: + se.l1_fp4 = [torch.cat([gw, uw], dim=0).to(dev)] + if gws is not None and uws is not None: + se.l1_sf = [torch.cat([gws, uws], dim=0).to(dev)] + else: + se.l1_sf = [torch.zeros(1, device=dev, dtype=torch.float8_e4m3fn)] + gs = gisc.float().item() if gisc is not None else 1.0 / (6.0 * 448.0) + se.l1_gs = [gs] + if dw is not None: + se.l2_fp4 = [dw.to(dev)] + se.l2_sf = [dws.to(dev)] if dws is not None else [torch.zeros(1, device=dev, dtype=torch.float8_e4m3fn)] + gs2 = disc.float().item() if disc is not None else 1.0 / (6.0 * 448.0) + se.l2_gs = [gs2] + # finalize_weights called lazily by Nvfp4SharedExpert._ensure_initialized() def _cache_layer_weights_no_experts(all_w, n_layers, devices): @@ -664,6 +644,12 @@ def main(): from dsv4.layers.shared_expert import Nvfp4SharedExpert from dsv4.layers.linear import Nvfp4Linear + # Kill stale GPU processes (safety) + for g in range(NUM_GPUS): + torch.cuda.set_device(g) + torch.cuda.empty_cache() + torch.cuda.set_device(0) + # mHC + norms attn_mhcs, ffn_mhcs, attn_norms, ffn_norms = {}, {}, {}, {} for li in range(n_layers): @@ -716,7 +702,9 @@ def main(): wt, ws, ws2, isc = get_nvfp4_weight(all_w, pfx, proj) if wt is not None and ws is not None: lin = make_nvfp4_linear(in_f, out_f, dev, wt, ws, ws2, isc) - lin.finalize_weights() + # Don't finalize yet — defer JIT compilation to first forward call + # This avoids allocating GPU workspace for all 61*4=244 projections upfront + # lin.finalize_weights() # called lazily by Nvfp4Linear.forward() plin[proj] = lin if plin: prod_lins[li] = plin