Eliminate weight copies: pass stacked checkpoint tensors directly

Memory optimization for MoE weight processing:

Before (3-4 copies of weights in memory):
1. Original checkpoint weights in layer.w13_weight (copy 1)
2. Per-expert permuted copies (copy 2)
3. torch.stack() in runner._ensure_stacked (copy 3)
4. make_b_k_major re-stride (copy 4)
5. Scales: permute then assemble_scales_3d_side un-permutes (wasted)

After (1-2 copies):
1. View checkpoint as fp4 (NO copy — byte-preserving view)
2. Pass (E, N, K) stacked tensor directly to runner
3. Runner permutes to (E, K, N) contiguous (copy 1), frees stacked ref
4. make_b_k_major re-strides (copy 2), frees (E, K, N) ref
5. Scales: already (N, K_sf) from checkpoint, call assembly directly
6. Free layer.w13_weight etc. immediately after extracting views

Also: assemble_scales_3d_side transposes (K_sf, N)→(N, K_sf) internally,
but checkpoint scales are ALREADY (N, K_sf). Skip the double-transpose
by calling assemble_raw_scales_2d3d_3d_side directly.
This commit is contained in:
2026-05-19 02:16:43 +00:00
parent 0612c1ab54
commit 48fa64dfda
2 changed files with 104 additions and 46 deletions

View File

@@ -190,18 +190,55 @@ class CuTeDSLMoERunner:
if self._l1_mat_b is not None:
return
# Stack and prepare weight tensors FIRST (triggers CuTeDSL JIT compilation)
self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4))
self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
self._l1_scale_b = assemble_scales_3d_side(self.l1_sf)
self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
# Convert weights to kernel format
if hasattr(self, 'l1_fp4_stacked') and self.l1_fp4_stacked is not None:
# Fast path: pre-stacked 3D tensors in checkpoint format (E, N, K)
# Permute to (E, K, N) then make K-major
l1_fp4_ekn = self.l1_fp4_stacked.permute(0, 2, 1).contiguous()
l2_fp4_ekn = self.l2_fp4_stacked.permute(0, 2, 1).contiguous()
# Free stacked checkpoints before make_b_k_major (saves one copy)
self.l1_fp4_stacked = None
self.l2_fp4_stacked = None
torch.cuda.empty_cache()
self._l1_mat_b = make_b_k_major(l1_fp4_ekn)
self._l2_mat_b = make_b_k_major(l2_fp4_ekn)
del l1_fp4_ekn, l2_fp4_ekn
torch.cuda.empty_cache()
# Scales: checkpoint is (E, N, K_sf) — the kernel expects (N, K_sf)
# per expert for swizzle. Split into views (no copy), then assemble.
l1_sf_list = [self.l1_sf_stacked[i] for i in range(self.num_experts)]
l2_sf_list = [self.l2_sf_stacked[i] for i in range(self.num_experts)]
self.l1_sf_stacked = None
self.l2_sf_stacked = None
torch.cuda.empty_cache()
# assemble_scales_3d_side expects (K_sf, N) per expert and transposes
# to (N, K_sf) internally. But our scales are already (N, K_sf) from
# the checkpoint! Skip the transpose by calling the assembly directly.
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
assemble_raw_scales_2d3d_3d_side,
)
self._l1_scale_b = assemble_raw_scales_2d3d_3d_side(l1_sf_list)
self._l2_scale_b = assemble_raw_scales_2d3d_3d_side(l2_sf_list)
del l1_sf_list, l2_sf_list
else:
# Legacy path: per-expert lists
self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4))
self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
self._l1_scale_b = assemble_scales_3d_side(self.l1_sf)
self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
self.l1_fp4 = None
self.l1_sf = None
self.l1_gs = None
self.l2_fp4 = None
self.l2_sf = None
self.l2_gs = None
self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)
self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device)
self.l1_fp4 = None
self.l1_sf = None
self.l1_gs = None
self.l2_fp4 = None
self.l2_sf = None
self.l2_gs = None
# Allocate buffers AFTER JIT compilation
@@ -236,6 +273,26 @@ class CuTeDSLMoERunner:
self.l2_gs = l2_gs
self._l1_mat_b = None
def prepare_weights_from_stacked(self, l1_fp4_stacked, l1_sf_stacked,
l1_gs, l2_fp4_stacked, l2_sf_stacked,
l2_gs):
"""Prepare weights from pre-stacked 3D tensors (checkpoint format).
Takes (E, N, K_packed) fp4 and (E, N, K_sf) scale tensors directly
from the checkpoint, avoiding the per-expert list→stack round-trip.
The conversion to K-major and swizzled layout happens in _ensure_stacked.
This just stores the tensors for deferred processing.
"""
# Store in checkpoint format (E, N, K) — _ensure_stacked will convert
self.l1_fp4_stacked = l1_fp4_stacked
self.l1_sf_stacked = l1_sf_stacked
self.l1_gs = l1_gs
self.l2_fp4_stacked = l2_fp4_stacked
self.l2_sf_stacked = l2_sf_stacked
self.l2_gs = l2_gs
self._l1_mat_b = None
def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16):
self.l1_fp4, self.l1_sf, self.l1_gs = [], [], []
self.l2_fp4, self.l2_sf, self.l2_gs = [], [], []