Eliminate weight copies: pass stacked checkpoint tensors directly

Memory optimization for MoE weight processing:

Before (3-4 copies of weights in memory):
1. Original checkpoint weights in layer.w13_weight (copy 1)
2. Per-expert permuted copies (copy 2)
3. torch.stack() in runner._ensure_stacked (copy 3)
4. make_b_k_major re-stride (copy 4)
5. Scales: permute then assemble_scales_3d_side un-permutes (wasted)

After (1-2 copies):
1. View checkpoint as fp4 (NO copy — byte-preserving view)
2. Pass (E, N, K) stacked tensor directly to runner
3. Runner permutes to (E, K, N) contiguous (copy 1), frees stacked ref
4. make_b_k_major re-strides (copy 2), frees (E, K, N) ref
5. Scales: already (N, K_sf) from checkpoint, call assembly directly
6. Free layer.w13_weight etc. immediately after extracting views

Also: assemble_scales_3d_side transposes (K_sf, N)→(N, K_sf) internally,
but checkpoint scales are ALREADY (N, K_sf). Skip the double-transpose
by calling assemble_raw_scales_2d3d_3d_side directly.
This commit is contained in:
2026-05-19 02:16:43 +00:00
parent 0612c1ab54
commit 48fa64dfda
2 changed files with 104 additions and 46 deletions

View File

@@ -190,13 +190,45 @@ class CuTeDSLMoERunner:
if self._l1_mat_b is not None: if self._l1_mat_b is not None:
return return
# Stack and prepare weight tensors FIRST (triggers CuTeDSL JIT compilation) # Convert weights to kernel format
if hasattr(self, 'l1_fp4_stacked') and self.l1_fp4_stacked is not None:
# Fast path: pre-stacked 3D tensors in checkpoint format (E, N, K)
# Permute to (E, K, N) then make K-major
l1_fp4_ekn = self.l1_fp4_stacked.permute(0, 2, 1).contiguous()
l2_fp4_ekn = self.l2_fp4_stacked.permute(0, 2, 1).contiguous()
# Free stacked checkpoints before make_b_k_major (saves one copy)
self.l1_fp4_stacked = None
self.l2_fp4_stacked = None
torch.cuda.empty_cache()
self._l1_mat_b = make_b_k_major(l1_fp4_ekn)
self._l2_mat_b = make_b_k_major(l2_fp4_ekn)
del l1_fp4_ekn, l2_fp4_ekn
torch.cuda.empty_cache()
# Scales: checkpoint is (E, N, K_sf) — the kernel expects (N, K_sf)
# per expert for swizzle. Split into views (no copy), then assemble.
l1_sf_list = [self.l1_sf_stacked[i] for i in range(self.num_experts)]
l2_sf_list = [self.l2_sf_stacked[i] for i in range(self.num_experts)]
self.l1_sf_stacked = None
self.l2_sf_stacked = None
torch.cuda.empty_cache()
# assemble_scales_3d_side expects (K_sf, N) per expert and transposes
# to (N, K_sf) internally. But our scales are already (N, K_sf) from
# the checkpoint! Skip the transpose by calling the assembly directly.
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
assemble_raw_scales_2d3d_3d_side,
)
self._l1_scale_b = assemble_raw_scales_2d3d_3d_side(l1_sf_list)
self._l2_scale_b = assemble_raw_scales_2d3d_3d_side(l2_sf_list)
del l1_sf_list, l2_sf_list
else:
# Legacy path: per-expert lists
self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4)) self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4))
self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4)) self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
self._l1_scale_b = assemble_scales_3d_side(self.l1_sf) self._l1_scale_b = assemble_scales_3d_side(self.l1_sf)
self._l2_scale_b = assemble_scales_3d_side(self.l2_sf) self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)
self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device)
self.l1_fp4 = None self.l1_fp4 = None
self.l1_sf = None self.l1_sf = None
self.l1_gs = None self.l1_gs = None
@@ -204,6 +236,11 @@ class CuTeDSLMoERunner:
self.l2_sf = None self.l2_sf = None
self.l2_gs = None self.l2_gs = None
self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)
self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device)
self.l1_gs = None
self.l2_gs = None
# Allocate buffers AFTER JIT compilation # Allocate buffers AFTER JIT compilation
# (CuTeDSL's cute.compile corrupts GPU memory during JIT; # (CuTeDSL's cute.compile corrupts GPU memory during JIT;
# tensors allocated before/during compilation may be zeroed) # tensors allocated before/during compilation may be zeroed)
@@ -236,6 +273,26 @@ class CuTeDSLMoERunner:
self.l2_gs = l2_gs self.l2_gs = l2_gs
self._l1_mat_b = None self._l1_mat_b = None
def prepare_weights_from_stacked(self, l1_fp4_stacked, l1_sf_stacked,
l1_gs, l2_fp4_stacked, l2_sf_stacked,
l2_gs):
"""Prepare weights from pre-stacked 3D tensors (checkpoint format).
Takes (E, N, K_packed) fp4 and (E, N, K_sf) scale tensors directly
from the checkpoint, avoiding the per-expert list→stack round-trip.
The conversion to K-major and swizzled layout happens in _ensure_stacked.
This just stores the tensors for deferred processing.
"""
# Store in checkpoint format (E, N, K) — _ensure_stacked will convert
self.l1_fp4_stacked = l1_fp4_stacked
self.l1_sf_stacked = l1_sf_stacked
self.l1_gs = l1_gs
self.l2_fp4_stacked = l2_fp4_stacked
self.l2_sf_stacked = l2_sf_stacked
self.l2_gs = l2_gs
self._l1_mat_b = None
def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16): def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16):
self.l1_fp4, self.l1_sf, self.l1_gs = [], [], [] self.l1_fp4, self.l1_sf, self.l1_gs = [], [], []
self.l2_fp4, self.l2_sf, self.l2_gs = [], [], [] self.l2_fp4, self.l2_sf, self.l2_gs = [], [], []

View File

@@ -114,46 +114,46 @@ class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular):
else: else:
w2_input_scale_orig = None w2_input_scale_orig = None
# Extract and convert weights for CuTeDSL runner # Extract weights from the layer — checkpoint format, no copies yet.
# w13_weight: (E, 2*intermediate, hidden//2) uint8 — gate + up fused # w13_weight: (E, 2*intermediate, hidden//2) uint8 — gate + up fused
# w2_weight: (E, hidden, intermediate//2) uint8 — down # w2_weight: (E, hidden, intermediate//2) uint8 — down
l1_fp4_list = [] # w13_weight_scale: (E, 2*intermediate, hidden//16) fp8
l1_sf_list = [] # w2_weight_scale: (E, hidden, intermediate//16) fp8
l1_gs_list = [] w13_uint8 = layer.w13_weight.data # (E, 2*inter, hidden//2)
l2_fp4_list = [] w2_uint8 = layer.w2_weight.data # (E, hidden, intermediate//2)
l2_sf_list = [] w13_sf = layer.w13_weight_scale.data # (E, 2*inter, hidden//16) = (E, N, K_sf)
l2_gs_list = [] w2_sf = layer.w2_weight_scale.data # (E, hidden, intermediate//16) = (E, N, K_sf)
w13_gs = layer.w13_weight_scale_2.data # (E,) or (E, 2)
w2_gs = layer.w2_weight_scale_2.data # (E,) or (E, 2)
for expert_id in range(num_experts): # View as fp4 — byte-preserving, NO copy
# L1: gate + up (w13) l1_fp4 = w13_uint8.view(torch.float4_e2m1fn_x2) # (E, N, K_packed)
w13_uint8 = layer.w13_weight.data[expert_id] # (2*inter, hidden//2) l2_fp4 = w2_uint8.view(torch.float4_e2m1fn_x2) # (E, N, K_packed)
w13_sf = layer.w13_weight_scale.data[expert_id] # (2*inter, hidden//16) fp8
w13_gs = layer.w13_weight_scale_2.data[expert_id].item() # float32
# uint8 → float4_e2m1fn_x2, permute to (K_packed, N) for CuTeDSL # Ensure scales are float8_e4m3fn (no copy if already correct dtype)
l1_w = w13_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() if w13_sf.dtype != torch.float8_e4m3fn:
# Block scales: (N, K_sf) → (K_sf, N) for CuTeDSL w13_sf = w13_sf.to(torch.float8_e4m3fn)
l1_s = w13_sf.permute(1, 0).contiguous() if w2_sf.dtype != torch.float8_e4m3fn:
if l1_s.dtype != torch.float8_e4m3fn: w2_sf = w2_sf.to(torch.float8_e4m3fn)
l1_s = l1_s.to(torch.float8_e4m3fn)
l1_fp4_list.append(l1_w) # Global scales
l1_sf_list.append(l1_s) l1_gs_list = w13_gs.tolist()
l1_gs_list.append(w13_gs) l2_gs_list = w2_gs.tolist()
# L2: down (w2) # Free original weight tensors IMMEDIATELY.
w2_uint8 = layer.w2_weight.data[expert_id] # (hidden, intermediate//2) # We have views into the same memory (l1_fp4, l2_fp4), but the runner
w2_sf = layer.w2_weight_scale.data[expert_id] # (hidden, intermediate//16) fp8 # will create its own copies in _ensure_stacked. Free the layer refs
w2_gs = layer.w2_weight_scale_2.data[expert_id].item() # float32 # now so the memory can be reclaimed when the views are no longer held.
layer.w13_weight = None
l2_w = w2_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() layer.w2_weight = None
l2_s = w2_sf.permute(1, 0).contiguous() layer.w13_weight_scale = None
if l2_s.dtype != torch.float8_e4m3fn: layer.w2_weight_scale = None
l2_s = l2_s.to(torch.float8_e4m3fn) layer.w13_weight_scale_2 = None
layer.w2_weight_scale_2 = None
l2_fp4_list.append(l2_w) if hasattr(layer, 'w13_input_scale'):
l2_sf_list.append(l2_s) layer.w13_input_scale = None
l2_gs_list.append(w2_gs) if hasattr(layer, 'w2_input_scale'):
layer.w2_input_scale = None
# Create the CuTeDSL runner # Create the CuTeDSL runner
self._runner = CuTeDSLMoERunner( self._runner = CuTeDSLMoERunner(
@@ -165,9 +165,10 @@ class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular):
device=str(device), device=str(device),
experts_start_idx=self.local_expert_offset, experts_start_idx=self.local_expert_offset,
) )
self._runner.prepare_weights_direct( # Pass stacked tensors in checkpoint format (E, N, K) — no copies needed
l1_fp4_list, l1_sf_list, l1_gs_list, self._runner.prepare_weights_from_stacked(
l2_fp4_list, l2_sf_list, l2_gs_list, l1_fp4, w13_sf, l1_gs_list,
l2_fp4, w2_sf, l2_gs_list,
) )
if self._swiglu_limit is not None: if self._swiglu_limit is not None:
self._runner.set_swiglu_limit(float(self._swiglu_limit)) self._runner.set_swiglu_limit(float(self._swiglu_limit))