Eliminate weight copies: pass stacked checkpoint tensors directly

Memory optimization for MoE weight processing:

Before (3-4 copies of weights in memory):
1. Original checkpoint weights in layer.w13_weight (copy 1)
2. Per-expert permuted copies (copy 2)
3. torch.stack() in runner._ensure_stacked (copy 3)
4. make_b_k_major re-stride (copy 4)
5. Scales: permute then assemble_scales_3d_side un-permutes (wasted)

After (1-2 copies):
1. View checkpoint as fp4 (NO copy — byte-preserving view)
2. Pass (E, N, K) stacked tensor directly to runner
3. Runner permutes to (E, K, N) contiguous (copy 1), frees stacked ref
4. make_b_k_major re-strides (copy 2), frees (E, K, N) ref
5. Scales: already (N, K_sf) from checkpoint, call assembly directly
6. Free layer.w13_weight etc. immediately after extracting views

Also: assemble_scales_3d_side transposes (K_sf, N)→(N, K_sf) internally,
but checkpoint scales are ALREADY (N, K_sf). Skip the double-transpose
by calling assemble_raw_scales_2d3d_3d_side directly.
This commit is contained in:
2026-05-19 02:16:43 +00:00
parent 0612c1ab54
commit 48fa64dfda
2 changed files with 104 additions and 46 deletions

View File

@@ -114,46 +114,46 @@ class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular):
else:
w2_input_scale_orig = None
# Extract and convert weights for CuTeDSL runner
# Extract weights from the layer — checkpoint format, no copies yet.
# w13_weight: (E, 2*intermediate, hidden//2) uint8 — gate + up fused
# w2_weight: (E, hidden, intermediate//2) uint8 — down
l1_fp4_list = []
l1_sf_list = []
l1_gs_list = []
l2_fp4_list = []
l2_sf_list = []
l2_gs_list = []
# w13_weight_scale: (E, 2*intermediate, hidden//16) fp8
# w2_weight_scale: (E, hidden, intermediate//16) fp8
w13_uint8 = layer.w13_weight.data # (E, 2*inter, hidden//2)
w2_uint8 = layer.w2_weight.data # (E, hidden, intermediate//2)
w13_sf = layer.w13_weight_scale.data # (E, 2*inter, hidden//16) = (E, N, K_sf)
w2_sf = layer.w2_weight_scale.data # (E, hidden, intermediate//16) = (E, N, K_sf)
w13_gs = layer.w13_weight_scale_2.data # (E,) or (E, 2)
w2_gs = layer.w2_weight_scale_2.data # (E,) or (E, 2)
for expert_id in range(num_experts):
# L1: gate + up (w13)
w13_uint8 = layer.w13_weight.data[expert_id] # (2*inter, hidden//2)
w13_sf = layer.w13_weight_scale.data[expert_id] # (2*inter, hidden//16) fp8
w13_gs = layer.w13_weight_scale_2.data[expert_id].item() # float32
# View as fp4 — byte-preserving, NO copy
l1_fp4 = w13_uint8.view(torch.float4_e2m1fn_x2) # (E, N, K_packed)
l2_fp4 = w2_uint8.view(torch.float4_e2m1fn_x2) # (E, N, K_packed)
# uint8 → float4_e2m1fn_x2, permute to (K_packed, N) for CuTeDSL
l1_w = w13_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
# Block scales: (N, K_sf) → (K_sf, N) for CuTeDSL
l1_s = w13_sf.permute(1, 0).contiguous()
if l1_s.dtype != torch.float8_e4m3fn:
l1_s = l1_s.to(torch.float8_e4m3fn)
# Ensure scales are float8_e4m3fn (no copy if already correct dtype)
if w13_sf.dtype != torch.float8_e4m3fn:
w13_sf = w13_sf.to(torch.float8_e4m3fn)
if w2_sf.dtype != torch.float8_e4m3fn:
w2_sf = w2_sf.to(torch.float8_e4m3fn)
l1_fp4_list.append(l1_w)
l1_sf_list.append(l1_s)
l1_gs_list.append(w13_gs)
# Global scales
l1_gs_list = w13_gs.tolist()
l2_gs_list = w2_gs.tolist()
# L2: down (w2)
w2_uint8 = layer.w2_weight.data[expert_id] # (hidden, intermediate//2)
w2_sf = layer.w2_weight_scale.data[expert_id] # (hidden, intermediate//16) fp8
w2_gs = layer.w2_weight_scale_2.data[expert_id].item() # float32
l2_w = w2_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
l2_s = w2_sf.permute(1, 0).contiguous()
if l2_s.dtype != torch.float8_e4m3fn:
l2_s = l2_s.to(torch.float8_e4m3fn)
l2_fp4_list.append(l2_w)
l2_sf_list.append(l2_s)
l2_gs_list.append(w2_gs)
# Free original weight tensors IMMEDIATELY.
# We have views into the same memory (l1_fp4, l2_fp4), but the runner
# will create its own copies in _ensure_stacked. Free the layer refs
# now so the memory can be reclaimed when the views are no longer held.
layer.w13_weight = None
layer.w2_weight = None
layer.w13_weight_scale = None
layer.w2_weight_scale = None
layer.w13_weight_scale_2 = None
layer.w2_weight_scale_2 = None
if hasattr(layer, 'w13_input_scale'):
layer.w13_input_scale = None
if hasattr(layer, 'w2_input_scale'):
layer.w2_input_scale = None
# Create the CuTeDSL runner
self._runner = CuTeDSLMoERunner(
@@ -165,9 +165,10 @@ class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular):
device=str(device),
experts_start_idx=self.local_expert_offset,
)
self._runner.prepare_weights_direct(
l1_fp4_list, l1_sf_list, l1_gs_list,
l2_fp4_list, l2_sf_list, l2_gs_list,
# Pass stacked tensors in checkpoint format (E, N, K) — no copies needed
self._runner.prepare_weights_from_stacked(
l1_fp4, w13_sf, l1_gs_list,
l2_fp4, w2_sf, l2_gs_list,
)
if self._swiglu_limit is not None:
self._runner.set_swiglu_limit(float(self._swiglu_limit))