diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index e415d33b..12aae590 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -142,11 +142,23 @@ def _prepack_weight_sf(weight_sf, N, K, tag): # M=128 — TODO: test with M=1, M=256 to confirm. M_for_layout = 128 - packed = [] - for e in range(E): - packed.append(prepack_sfb(weight_sf[e], M_for_layout, N, K)) + # Pre-allocate output tensor and fill in-place to avoid 2× peak memory + # (torch.stack would hold all expert tensors + the final stack = ~3.5 GiB) + packed0 = prepack_sfb(weight_sf[0], M_for_layout, N, K) + packed = torch.empty( + (E, *packed0.shape), + dtype=packed0.dtype, + device=packed0.device, + ) + packed[0].copy_(packed0) + del packed0 - packed = torch.stack(packed, dim=0).contiguous() + for e in range(1, E): + tmp = prepack_sfb(weight_sf[e], M_for_layout, N, K) + packed[e].copy_(tmp) + del tmp + + packed = packed.contiguous() _prepack_weight_sf._cache[cache_key] = packed _prepack_weight_sf._cache_order.append(cache_key)