Memory optimization for MoE weight processing:

Before (3-4 copies of weights in memory):
  1. Original checkpoint weights in layer.w13_weight (copy 1)
  2. Per-expert permuted copies (copy 2)
  3. torch.stack() in runner._ensure_stacked (copy 3)
  4. make_b_k_major re-stride (copy 4)
  5. Scales: permute, then assemble_scales_3d_side un-permutes (wasted)

After (1-2 copies):
  1. View checkpoint as fp4 (NO copy; byte-preserving view)
  2. Pass (E, N, K) stacked tensor directly to runner
  3. Runner permutes to (E, K, N) contiguous (copy 1), frees stacked ref
  4. make_b_k_major re-strides (copy 2), frees (E, K, N) ref
  5. Scales: already (N, K_sf) from checkpoint, call assembly directly
  6. Free layer.w13_weight etc. immediately after extracting views

Also: assemble_scales_3d_side transposes (K_sf, N)→(N, K_sf) internally,
but checkpoint scales are ALREADY (N, K_sf). Skip the double-transpose by
calling assemble_raw_scales_2d3d_3d_side directly.
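The two ideas behind the "After" path can be sketched without torch. This is a stdlib-only illustration, not the actual runner code: `reinterpret_bytes` and `transpose` are hypothetical helper names, a `bytearray` stands in for the packed checkpoint weights, and nested lists stand in for the scale tensors. It shows that a byte-preserving view shares storage with the original buffer, and that transposing (N, K_sf) to (K_sf, N) and back returns the original layout, which is why the pair of transposes can be skipped.

```python
# Stdlib-only illustration; the real code operates on torch tensors.
# All names here are hypothetical stand-ins, not functions from the codebase.


def reinterpret_bytes(buf: bytearray) -> memoryview:
    """Return a signed-byte view of buf that shares its storage (no copy)."""
    return memoryview(buf).cast("b")


def transpose(rows):
    """Return a transposed copy of a rectangular list of rows."""
    return [list(col) for col in zip(*rows)]


# 1. A byte-preserving view shares memory with the original buffer.
checkpoint = bytearray([0x12, 0x34, 0x56, 0x78])  # stand-in for packed bytes
view = reinterpret_bytes(checkpoint)
checkpoint[0] = 0x7F              # mutate the original buffer ...
shares_storage = view[0] == 0x7F  # ... and the view observes the change

# 2. Transposing (N, K_sf) -> (K_sf, N) -> (N, K_sf) is a round trip,
#    so data that is already (N, K_sf) needs neither transpose.
scales_n_ksf = [[1, 2, 3], [4, 5, 6]]  # N=2, K_sf=3
round_trip = transpose(transpose(scales_n_ksf))
```

In torch, the analogous no-copy step would presumably be a dtype view over the same storage, while each list-based transpose above corresponds to a materialized copy that the optimized path avoids.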