single_shot: memory-efficient MoE loading, lazy Nvfp4Linear init

- MoE expert weights loaded per-expert to GPU (no huge CPU tensors)
- Nvfp4Linear finalize_weights deferred (lazy on first forward)
- Shared expert weights loaded directly to GPU
- Added GPU cache cleanup at start
- Fixed shared expert finalize_weights (now lazy)
This commit is contained in:
2026-05-31 23:16:45 +00:00
parent 050b5ee449
commit 0b35c36d23

View File

@@ -516,89 +516,69 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
# =====================================================================
def _copy_into_expert_slot(stacked, eid, n_e, parts, dev):
    """Write one expert's tensors into slot ``eid`` of a stacked buffer.

    The tensors in ``parts`` are laid out one after another along dim 0 of
    the slot (the same layout ``torch.cat(parts, dim=0)`` would produce).
    The buffer is created zero-filled on ``dev`` the first time it is
    needed, taking its shape and dtype from ``parts``.

    Returns the (possibly newly allocated) stacked buffer.
    """
    rows = sum(p.shape[0] for p in parts)
    if stacked is None:
        stacked = torch.zeros(
            (n_e, rows) + tuple(parts[0].shape[1:]),
            dtype=parts[0].dtype,
            device=dev,
        )
    start = 0
    for p in parts:
        stop = start + p.shape[0]
        # copy_ moves the data straight into the device buffer, so no
        # per-expert device tensor outlives this call.
        stacked[eid, start:stop].copy_(p)
        start = stop
    return stacked


def _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg):
    """Load the routed-expert weights of one layer and hand them to ``moe``.

    For every expert index in ``range(cfg["n_routed_experts"])`` the gate,
    up and down projections are fetched with ``get_nvfp4_weight`` and copied
    into per-layer stacked buffers that live on ``dev``:

    * L1 = gate_proj followed by up_proj along dim 0
    * L2 = down_proj

    The buffers are allocated once on the device and filled expert by
    expert, so no layer-sized tensor is built on the CPU and the device
    never holds both a list of per-expert tensors and their stacked copy.
    An expert whose tensors are missing keeps a zero-filled slot, so the
    remaining experts stay at their own index.

    Args:
        all_w: mapping of weight names to tensors, passed through to
            ``get_nvfp4_weight``.
        li: layer index, used only in log messages.
        pfx: name prefix of the layer's MoE module; expert ``i`` is looked
            up under ``f"{pfx}.experts.{i}"``.
        dev: device the stacked buffers are created on.
        moe: object exposing ``prepare_weights_from_stacked``.
        cfg: config mapping; ``"n_routed_experts"`` is read.

    Returns:
        None. Logs a warning and returns without calling ``moe`` when no
        expert provides both gate and up weights.
    """
    n_e = cfg["n_routed_experts"]
    # Fallback used whenever a projection has no input_scale tensor.
    default_gs = 1.0 / (6.0 * 448.0)

    l1_stacked = l1_sf_stacked = None
    l2_stacked = l2_sf_stacked = None
    l1_gs_list, l2_gs_list = [], []
    missing = []

    for eid in range(n_e):
        ep = f"{pfx}.experts.{eid}"

        # L1: gate + up, stacked along dim 0
        gw, gws, _, gisc = get_nvfp4_weight(all_w, ep, 'gate_proj')
        uw, uws, _, _ = get_nvfp4_weight(all_w, ep, 'up_proj')
        if gw is not None and uw is not None:
            l1_stacked = _copy_into_expert_slot(
                l1_stacked, eid, n_e, (gw, uw), dev)
        else:
            missing.append(eid)
        if gws is not None and uws is not None:
            l1_sf_stacked = _copy_into_expert_slot(
                l1_sf_stacked, eid, n_e, (gws, uws), dev)
        # One scale per expert, always, so the list length equals n_e.
        l1_gs_list.append(
            gisc.float().item() if gisc is not None else default_gs)

        # L2: down
        dw, dws, _, disc = get_nvfp4_weight(all_w, ep, 'down_proj')
        if dw is not None:
            l2_stacked = _copy_into_expert_slot(
                l2_stacked, eid, n_e, (dw,), dev)
        if dws is not None:
            l2_sf_stacked = _copy_into_expert_slot(
                l2_sf_stacked, eid, n_e, (dws,), dev)
        l2_gs_list.append(
            disc.float().item() if disc is not None else default_gs)

    if l1_stacked is None:
        log.warning(f"L{li}: No expert weights found")
        return
    if missing:
        log.warning(
            f"L{li}: {len(missing)}/{n_e} experts have no gate/up weights; "
            f"their slots are left zero-filled")

    moe.prepare_weights_from_stacked(
        l1_stacked, l1_sf_stacked, l1_gs_list,
        l2_stacked, l2_sf_stacked, l2_gs_list,
    )
def _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg):
    """Attach the shared-expert weights of one layer to ``se`` on ``dev``.

    The gate, up and down projections are fetched from
    ``f"{pfx}.shared_experts"`` with ``get_nvfp4_weight``. Gate and up are
    concatenated along dim 0 into the single-element list ``se.l1_fp4``;
    down goes to ``se.l2_fp4``. Matching scale-factor tensors go to
    ``se.l1_sf`` / ``se.l2_sf``, and the per-projection input scale (or a
    fixed fallback when absent) goes to ``se.l1_gs`` / ``se.l2_gs``.

    ``se.finalize_weights()`` is deliberately not called here; per the
    original comment it is invoked lazily by
    ``Nvfp4SharedExpert._ensure_initialized()``.

    Args:
        all_w: mapping of weight names to tensors, passed through to
            ``get_nvfp4_weight``.
        li: layer index (unused; kept for a uniform loader signature).
        pfx: name prefix of the layer's MoE module.
        dev: device the tensors are moved to.
        se: shared-expert object whose ``l1_*`` / ``l2_*`` attributes are set.
        cfg: config mapping (unused; kept for a uniform loader signature).
    """
    # Fallback used whenever a projection has no input_scale tensor.
    default_gs = 1.0 / (6.0 * 448.0)
    sp = f"{pfx}.shared_experts"

    # Shared expert: gate_proj + up_proj → L1, down_proj → L2
    gw, gws, _, gisc = get_nvfp4_weight(all_w, sp, 'gate_proj')
    uw, uws, _, _ = get_nvfp4_weight(all_w, sp, 'up_proj')
    dw, dws, _, disc = get_nvfp4_weight(all_w, sp, 'down_proj')

    if gw is not None and uw is not None:
        se.l1_fp4 = [torch.cat([gw, uw], dim=0).to(dev)]
        if gws is not None and uws is not None:
            se.l1_sf = [torch.cat([gws, uws], dim=0).to(dev)]
        else:
            # NOTE(review): one-element zero placeholder when scale factors
            # are missing — confirm the consumer tolerates this shape.
            se.l1_sf = [torch.zeros(1, device=dev, dtype=torch.float8_e4m3fn)]
        se.l1_gs = [gisc.float().item() if gisc is not None else default_gs]

    if dw is not None:
        se.l2_fp4 = [dw.to(dev)]
        if dws is not None:
            se.l2_sf = [dws.to(dev)]
        else:
            se.l2_sf = [torch.zeros(1, device=dev, dtype=torch.float8_e4m3fn)]
        se.l2_gs = [disc.float().item() if disc is not None else default_gs]
    # finalize_weights called lazily by Nvfp4SharedExpert._ensure_initialized()
def _cache_layer_weights_no_experts(all_w, n_layers, devices):
@@ -664,6 +644,12 @@ def main():
from dsv4.layers.shared_expert import Nvfp4SharedExpert
from dsv4.layers.linear import Nvfp4Linear
# Release cached CUDA allocator memory on every GPU (does not terminate other processes)
for g in range(NUM_GPUS):
torch.cuda.set_device(g)
torch.cuda.empty_cache()
torch.cuda.set_device(0)
# mHC + norms
attn_mhcs, ffn_mhcs, attn_norms, ffn_norms = {}, {}, {}, {}
for li in range(n_layers):
@@ -716,7 +702,9 @@ def main():
wt, ws, ws2, isc = get_nvfp4_weight(all_w, pfx, proj)
if wt is not None and ws is not None:
lin = make_nvfp4_linear(in_f, out_f, dev, wt, ws, ws2, isc)
lin.finalize_weights()
# Don't finalize yet — defer JIT compilation to first forward call
# This avoids allocating GPU workspace for all 61*4=244 projections upfront
# lin.finalize_weights() # called lazily by Nvfp4Linear.forward()
plin[proj] = lin
if plin:
prod_lins[li] = plin