single_shot: memory-efficient MoE loading, lazy Nvfp4Linear init
- MoE expert weights are loaded to the GPU one expert at a time (no huge CPU tensors)
- Nvfp4Linear finalize_weights is deferred (runs lazily on the first forward)
- Shared expert weights are loaded directly to the GPU
- Added GPU cache cleanup at startup
- Fixed shared expert finalize_weights (it is now lazy as well)
This commit is contained in:
@@ -516,89 +516,69 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
# =====================================================================
|
||||
def _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg):
    """Load one layer's routed-expert NVFP4 weights into stacked tensors on ``dev``.

    Every expert is fetched with ``get_nvfp4_weight`` and copied straight into
    its own row of a per-layer buffer that lives on ``dev``. gate_proj and
    up_proj share the "L1" buffer (gate rows first, then up rows); down_proj
    fills the "L2" buffer. The buffers and the per-expert global scales are
    then handed to ``moe.prepare_weights_from_stacked``.

    Writing rows in place (instead of collecting per-expert device tensors in
    a list and calling ``torch.stack``) means only one device copy of the
    layer's expert weights exists at any time, and row ``i`` always belongs to
    expert ``i``: an expert whose weights are absent keeps a zero row and the
    default global scale rather than shifting later experts up.

    Args:
        all_w: Weight lookup passed through to ``get_nvfp4_weight``.
        li: Layer index; used only in log messages.
        pfx: Key prefix of the MoE module. Experts are looked up under
            ``{pfx}.experts.{eid}``.
        dev: Device on which the stacked tensors are allocated.
        moe: Object exposing ``prepare_weights_from_stacked``.
        cfg: Model config; ``cfg["n_routed_experts"]`` is the expert count.

    Returns:
        None. When no expert has both gate and up weights a warning is logged
        and ``moe`` is left untouched.
    """
    n_e = cfg["n_routed_experts"]
    # Fallback global scale for a projection that ships no input_scale.
    default_gs = 1.0 / (6.0 * 448.0)

    l1_stacked = l1_sf_stacked = None
    l2_stacked = l2_sf_stacked = None
    l1_gs_list = [default_gs] * n_e
    l2_gs_list = [default_gs] * n_e
    n_l1_found = 0

    for eid in range(n_e):
        ep = f"{pfx}.experts.{eid}"

        # L1: gate + up, concatenated along dim 0.
        gw, gws, _, gisc = get_nvfp4_weight(all_w, ep, 'gate_proj')
        uw, uws, _, _ = get_nvfp4_weight(all_w, ep, 'up_proj')
        if gw is not None and uw is not None:
            if l1_stacked is None:
                # Buffers are sized from the first expert that is present;
                # a later expert with another shape makes copy_ raise.
                l1_stacked = _alloc_expert_stack(n_e, dev, gw, uw)
            _copy_into_expert_row(l1_stacked[eid], gw, uw)
            if gws is not None and uws is not None:
                if l1_sf_stacked is None:
                    l1_sf_stacked = _alloc_expert_stack(n_e, dev, gws, uws)
                _copy_into_expert_row(l1_sf_stacked[eid], gws, uws)
            if gisc is not None:
                l1_gs_list[eid] = gisc.float().item()
            n_l1_found += 1

        # L2: down.
        dw, dws, _, disc = get_nvfp4_weight(all_w, ep, 'down_proj')
        if dw is not None:
            if l2_stacked is None:
                l2_stacked = _alloc_expert_stack(n_e, dev, dw)
            _copy_into_expert_row(l2_stacked[eid], dw)
            if dws is not None:
                if l2_sf_stacked is None:
                    l2_sf_stacked = _alloc_expert_stack(n_e, dev, dws)
                _copy_into_expert_row(l2_sf_stacked[eid], dws)
            if disc is not None:
                l2_gs_list[eid] = disc.float().item()

    if l1_stacked is None:
        log.warning(f"L{li}: No expert weights found")
        return
    if n_l1_found < n_e:
        log.warning(
            f"L{li}: only {n_l1_found}/{n_e} experts have gate/up weights; "
            f"missing experts keep zero rows"
        )

    moe.prepare_weights_from_stacked(
        l1_stacked, l1_sf_stacked, l1_gs_list,
        l2_stacked, l2_sf_stacked, l2_gs_list,
    )


def _alloc_expert_stack(n_e, dev, *parts):
    """Allocate a zeroed ``(n_e, total_rows, ...)`` buffer on ``dev`` sized to hold ``parts`` stacked along dim 0."""
    total_rows = sum(p.shape[0] for p in parts)
    return torch.zeros(
        n_e, total_rows, *parts[0].shape[1:],
        dtype=parts[0].dtype, device=dev,
    )


def _copy_into_expert_row(row, *parts):
    """Copy ``parts`` back to back along dim 0 of ``row`` (a view into a stacked buffer)."""
    start = 0
    for part in parts:
        stop = start + part.shape[0]
        row[start:stop].copy_(part)
        start = stop
|
||||
|
||||
|
||||
def _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg):
    """Attach the shared expert's NVFP4 weights to ``se``, placed on ``dev``.

    gate_proj and up_proj are concatenated along dim 0 into the single "L1"
    weight (gate rows first); down_proj becomes the "L2" weight. Each
    attribute set on ``se`` is a one-element list.

    ``se.finalize_weights()`` is intentionally NOT called here: finalization
    is deferred and, per the note left by the original author, is triggered
    lazily by ``Nvfp4SharedExpert._ensure_initialized()``.

    Args:
        all_w: Weight lookup passed through to ``get_nvfp4_weight``.
        li: Layer index; used only in log messages.
        pfx: Key prefix of the MoE module. The shared expert is looked up
            under ``{pfx}.shared_experts``.
        dev: Device the tensors are moved to.
        se: Shared-expert object receiving ``l1_fp4``/``l1_sf``/``l1_gs`` and
            ``l2_fp4``/``l2_sf``/``l2_gs``.
        cfg: Model config. Currently unused; kept for interface compatibility.

    Returns:
        None. Projections whose weights are missing are skipped with a
        warning and the corresponding ``se`` attributes are left unchanged.
    """
    # Fallback global scale for a projection that ships no input_scale.
    default_gs = 1.0 / (6.0 * 448.0)
    se_pfx = f"{pfx}.shared_experts"

    # Shared expert: gate_proj + up_proj -> L1, down_proj -> L2.
    gw, gws, _, gisc = get_nvfp4_weight(all_w, se_pfx, 'gate_proj')
    uw, uws, _, _ = get_nvfp4_weight(all_w, se_pfx, 'up_proj')
    dw, dws, _, disc = get_nvfp4_weight(all_w, se_pfx, 'down_proj')

    if gw is not None and uw is not None:
        se.l1_fp4 = [torch.cat([gw, uw], dim=0).to(dev)]
        if gws is not None and uws is not None:
            se.l1_sf = [torch.cat([gws, uws], dim=0).to(dev)]
        else:
            # One-element placeholder when no block scales are stored.
            # NOTE(review): confirm downstream code tolerates this shape.
            se.l1_sf = [torch.zeros(1, device=dev, dtype=torch.float8_e4m3fn)]
        se.l1_gs = [gisc.float().item() if gisc is not None else default_gs]
    else:
        log.warning(f"L{li}: shared expert gate/up weights not found")

    if dw is not None:
        se.l2_fp4 = [dw.to(dev)]
        if dws is not None:
            se.l2_sf = [dws.to(dev)]
        else:
            se.l2_sf = [torch.zeros(1, device=dev, dtype=torch.float8_e4m3fn)]
        se.l2_gs = [disc.float().item() if disc is not None else default_gs]
    else:
        log.warning(f"L{li}: shared expert down weights not found")
|
||||
|
||||
|
||||
def _cache_layer_weights_no_experts(all_w, n_layers, devices):
|
||||
@@ -664,6 +644,12 @@ def main():
|
||||
from dsv4.layers.shared_expert import Nvfp4SharedExpert
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
|
||||
# Release cached, unused GPU memory on every device before loading (this does not terminate other processes)
|
||||
for g in range(NUM_GPUS):
|
||||
torch.cuda.set_device(g)
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
# mHC + norms
|
||||
attn_mhcs, ffn_mhcs, attn_norms, ffn_norms = {}, {}, {}, {}
|
||||
for li in range(n_layers):
|
||||
@@ -716,7 +702,9 @@ def main():
|
||||
wt, ws, ws2, isc = get_nvfp4_weight(all_w, pfx, proj)
|
||||
if wt is not None and ws is not None:
|
||||
lin = make_nvfp4_linear(in_f, out_f, dev, wt, ws, ws2, isc)
|
||||
lin.finalize_weights()
|
||||
# Don't finalize yet — defer JIT compilation to first forward call
|
||||
# This avoids allocating GPU workspace for all 61*4=244 projections upfront
|
||||
# lin.finalize_weights() # called lazily by Nvfp4Linear.forward()
|
||||
plin[proj] = lin
|
||||
if plin:
|
||||
prod_lins[li] = plin
|
||||
|
||||
Reference in New Issue
Block a user