fix: in-place prepack to avoid 2× peak memory

torch.stack(packed) held all expert tensors + final stack (~3.5 GiB).
Now pre-allocate output and fill in-place — only 1 expert tmp + final
tensor in memory at any time.
This commit is contained in:
2026-05-15 10:38:44 +00:00
parent 5dc18df494
commit 7adfaef113

View File

@@ -142,11 +142,23 @@ def _prepack_weight_sf(weight_sf, N, K, tag):
# M=128 — TODO: test with M=1, M=256 to confirm.
M_for_layout = 128
packed = []
for e in range(E):
packed.append(prepack_sfb(weight_sf[e], M_for_layout, N, K))
# Pre-allocate output tensor and fill in-place to avoid 2× peak memory
# (torch.stack would hold all expert tensors + the final stack = ~3.5 GiB)
packed0 = prepack_sfb(weight_sf[0], M_for_layout, N, K)
packed = torch.empty(
(E, *packed0.shape),
dtype=packed0.dtype,
device=packed0.device,
)
packed[0].copy_(packed0)
del packed0
packed = torch.stack(packed, dim=0).contiguous()
for e in range(1, E):
tmp = prepack_sfb(weight_sf[e], M_for_layout, N, K)
packed[e].copy_(tmp)
del tmp
packed = packed.contiguous()
_prepack_weight_sf._cache[cache_key] = packed
_prepack_weight_sf._cache_order.append(cache_key)