fix: in-place prepack to avoid 2× peak memory
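Previously, torch.stack(packed) kept every per-expert packed tensor alive alongside the final stacked tensor (~3.5 GiB in total). The output tensor is now pre-allocated and filled in place, so only one temporary per-expert tensor plus the final tensor are in memory at any time.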
This commit is contained in:
@@ -142,11 +142,23 @@ def _prepack_weight_sf(weight_sf, N, K, tag):
|
||||
# M=128 — TODO: test with M=1, M=256 to confirm.
|
||||
M_for_layout = 128
|
||||
|
||||
packed = []
|
||||
for e in range(E):
|
||||
packed.append(prepack_sfb(weight_sf[e], M_for_layout, N, K))
|
||||
# Pre-allocate output tensor and fill in-place to avoid 2× peak memory
|
||||
# (torch.stack would hold all expert tensors + the final stack = ~3.5 GiB)
|
||||
packed0 = prepack_sfb(weight_sf[0], M_for_layout, N, K)
|
||||
packed = torch.empty(
|
||||
(E, *packed0.shape),
|
||||
dtype=packed0.dtype,
|
||||
device=packed0.device,
|
||||
)
|
||||
packed[0].copy_(packed0)
|
||||
del packed0
|
||||
|
||||
packed = torch.stack(packed, dim=0).contiguous()
|
||||
for e in range(1, E):
|
||||
tmp = prepack_sfb(weight_sf[e], M_for_layout, N, K)
|
||||
packed[e].copy_(tmp)
|
||||
del tmp
|
||||
|
||||
packed = packed.contiguous()
|
||||
_prepack_weight_sf._cache[cache_key] = packed
|
||||
_prepack_weight_sf._cache_order.append(cache_key)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user