From 48fa64dfda170bb0b60dbeb796acdda62c2089e9 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 19 May 2026 02:16:43 +0000 Subject: [PATCH] Eliminate weight copies: pass stacked checkpoint tensors directly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Memory optimization for MoE weight processing: Before (3-4 copies of weights in memory): 1. Original checkpoint weights in layer.w13_weight (copy 1) 2. Per-expert permuted copies (copy 2) 3. torch.stack() in runner._ensure_stacked (copy 3) 4. make_b_k_major re-stride (copy 4) 5. Scales: permute then assemble_scales_3d_side un-permutes (wasted) After (1-2 copies): 1. View checkpoint as fp4 (NO copy — byte-preserving view) 2. Pass (E, N, K) stacked tensor directly to runner 3. Runner permutes to (E, K, N) contiguous (copy 1), frees stacked ref 4. make_b_k_major re-strides (copy 2), frees (E, K, N) ref 5. Scales: already (N, K_sf) from checkpoint, call assembly directly 6. Free layer.w13_weight etc. immediately after extracting views Also: assemble_scales_3d_side transposes (K_sf, N)→(N, K_sf) internally, but checkpoint scales are ALREADY (N, K_sf). Skip the double-transpose by calling assemble_raw_scales_2d3d_3d_side directly. --- cutedsl/runner.py | 75 ++++++++++++++++--- vllm/patches/fused_moe/experts/cutedsl_moe.py | 75 ++++++++++--------- 2 files changed, 104 insertions(+), 46 deletions(-) diff --git a/cutedsl/runner.py b/cutedsl/runner.py index a3b19d1f..36971e48 100644 --- a/cutedsl/runner.py +++ b/cutedsl/runner.py @@ -190,18 +190,55 @@ class CuTeDSLMoERunner: if self._l1_mat_b is not None: return - # Stack and prepare weight tensors FIRST (triggers CuTeDSL JIT compilation) - self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4)) - self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4)) - self._l1_scale_b = assemble_scales_3d_side(self.l1_sf) - self._l2_scale_b = assemble_scales_3d_side(self.l2_sf) + # Convert weights to kernel format + if hasattr(self, 'l1_fp4_stacked') and self.l1_fp4_stacked is not None: + # Fast path: pre-stacked 3D tensors in checkpoint format (E, N, K) + # Permute to (E, K, N) then make K-major + l1_fp4_ekn = self.l1_fp4_stacked.permute(0, 2, 1).contiguous() + l2_fp4_ekn = self.l2_fp4_stacked.permute(0, 2, 1).contiguous() + # Free stacked checkpoints before make_b_k_major (saves one copy) + self.l1_fp4_stacked = None + self.l2_fp4_stacked = None + torch.cuda.empty_cache() + + self._l1_mat_b = make_b_k_major(l1_fp4_ekn) + self._l2_mat_b = make_b_k_major(l2_fp4_ekn) + del l1_fp4_ekn, l2_fp4_ekn + torch.cuda.empty_cache() + + # Scales: checkpoint is (E, N, K_sf) — the kernel expects (N, K_sf) + # per expert for swizzle. Split into views (no copy), then assemble. + l1_sf_list = [self.l1_sf_stacked[i] for i in range(self.num_experts)] + l2_sf_list = [self.l2_sf_stacked[i] for i in range(self.num_experts)] + self.l1_sf_stacked = None + self.l2_sf_stacked = None + torch.cuda.empty_cache() + + # assemble_scales_3d_side expects (K_sf, N) per expert and transposes + # to (N, K_sf) internally. But our scales are already (N, K_sf) from + # the checkpoint! Skip the transpose by calling the assembly directly. + from cutedsl.kernel.moe.torch_scaled_grouped_mm import ( + assemble_raw_scales_2d3d_3d_side, + ) + self._l1_scale_b = assemble_raw_scales_2d3d_3d_side(l1_sf_list) + self._l2_scale_b = assemble_raw_scales_2d3d_3d_side(l2_sf_list) + del l1_sf_list, l2_sf_list + else: + # Legacy path: per-expert lists + self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4)) + self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4)) + self._l1_scale_b = assemble_scales_3d_side(self.l1_sf) + self._l2_scale_b = assemble_scales_3d_side(self.l2_sf) + self.l1_fp4 = None + self.l1_sf = None + self.l1_gs = None + self.l2_fp4 = None + self.l2_sf = None + self.l2_gs = None + self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device) self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device) - self.l1_fp4 = None - self.l1_sf = None self.l1_gs = None - self.l2_fp4 = None - self.l2_sf = None self.l2_gs = None # Allocate buffers AFTER JIT compilation @@ -236,6 +273,26 @@ class CuTeDSLMoERunner: self.l2_gs = l2_gs self._l1_mat_b = None + def prepare_weights_from_stacked(self, l1_fp4_stacked, l1_sf_stacked, + l1_gs, l2_fp4_stacked, l2_sf_stacked, + l2_gs): + """Prepare weights from pre-stacked 3D tensors (checkpoint format). + + Takes (E, N, K_packed) fp4 and (E, N, K_sf) scale tensors directly + from the checkpoint, avoiding the per-expert list→stack round-trip. + + The conversion to K-major and swizzled layout happens in _ensure_stacked. + This just stores the tensors for deferred processing. + """ + # Store in checkpoint format (E, N, K) — _ensure_stacked will convert + self.l1_fp4_stacked = l1_fp4_stacked + self.l1_sf_stacked = l1_sf_stacked + self.l1_gs = l1_gs + self.l2_fp4_stacked = l2_fp4_stacked + self.l2_sf_stacked = l2_sf_stacked + self.l2_gs = l2_gs + self._l1_mat_b = None + def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16): self.l1_fp4, self.l1_sf, self.l1_gs = [], [], [] self.l2_fp4, self.l2_sf, self.l2_gs = [], [], [] diff --git a/vllm/patches/fused_moe/experts/cutedsl_moe.py b/vllm/patches/fused_moe/experts/cutedsl_moe.py index c4c540af..7580efc6 100644 --- a/vllm/patches/fused_moe/experts/cutedsl_moe.py +++ b/vllm/patches/fused_moe/experts/cutedsl_moe.py @@ -114,46 +114,46 @@ class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular): else: w2_input_scale_orig = None - # Extract and convert weights for CuTeDSL runner + # Extract weights from the layer — checkpoint format, no copies yet. # w13_weight: (E, 2*intermediate, hidden//2) uint8 — gate + up fused # w2_weight: (E, hidden, intermediate//2) uint8 — down - l1_fp4_list = [] - l1_sf_list = [] - l1_gs_list = [] - l2_fp4_list = [] - l2_sf_list = [] - l2_gs_list = [] + # w13_weight_scale: (E, 2*intermediate, hidden//16) fp8 + # w2_weight_scale: (E, hidden, intermediate//16) fp8 + w13_uint8 = layer.w13_weight.data # (E, 2*inter, hidden//2) + w2_uint8 = layer.w2_weight.data # (E, hidden, intermediate//2) + w13_sf = layer.w13_weight_scale.data # (E, 2*inter, hidden//16) = (E, N, K_sf) + w2_sf = layer.w2_weight_scale.data # (E, hidden, intermediate//16) = (E, N, K_sf) + w13_gs = layer.w13_weight_scale_2.data # (E,) or (E, 2) + w2_gs = layer.w2_weight_scale_2.data # (E,) or (E, 2) - for expert_id in range(num_experts): - # L1: gate + up (w13) - w13_uint8 = layer.w13_weight.data[expert_id] # (2*inter, hidden//2) - w13_sf = layer.w13_weight_scale.data[expert_id] # (2*inter, hidden//16) fp8 - w13_gs = layer.w13_weight_scale_2.data[expert_id].item() # float32 + # View as fp4 — byte-preserving, NO copy + l1_fp4 = w13_uint8.view(torch.float4_e2m1fn_x2) # (E, N, K_packed) + l2_fp4 = w2_uint8.view(torch.float4_e2m1fn_x2) # (E, N, K_packed) - # uint8 → float4_e2m1fn_x2, permute to (K_packed, N) for CuTeDSL - l1_w = w13_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() - # Block scales: (N, K_sf) → (K_sf, N) for CuTeDSL - l1_s = w13_sf.permute(1, 0).contiguous() - if l1_s.dtype != torch.float8_e4m3fn: - l1_s = l1_s.to(torch.float8_e4m3fn) + # Ensure scales are float8_e4m3fn (no copy if already correct dtype) + if w13_sf.dtype != torch.float8_e4m3fn: + w13_sf = w13_sf.to(torch.float8_e4m3fn) + if w2_sf.dtype != torch.float8_e4m3fn: + w2_sf = w2_sf.to(torch.float8_e4m3fn) - l1_fp4_list.append(l1_w) - l1_sf_list.append(l1_s) - l1_gs_list.append(w13_gs) + # Global scales + l1_gs_list = w13_gs.tolist() + l2_gs_list = w2_gs.tolist() - # L2: down (w2) - w2_uint8 = layer.w2_weight.data[expert_id] # (hidden, intermediate//2) - w2_sf = layer.w2_weight_scale.data[expert_id] # (hidden, intermediate//16) fp8 - w2_gs = layer.w2_weight_scale_2.data[expert_id].item() # float32 - - l2_w = w2_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() - l2_s = w2_sf.permute(1, 0).contiguous() - if l2_s.dtype != torch.float8_e4m3fn: - l2_s = l2_s.to(torch.float8_e4m3fn) - - l2_fp4_list.append(l2_w) - l2_sf_list.append(l2_s) - l2_gs_list.append(w2_gs) + # Free original weight tensors IMMEDIATELY. + # We have views into the same memory (l1_fp4, l2_fp4), but the runner + # will create its own copies in _ensure_stacked. Free the layer refs + # now so the memory can be reclaimed when the views are no longer held. + layer.w13_weight = None + layer.w2_weight = None + layer.w13_weight_scale = None + layer.w2_weight_scale = None + layer.w13_weight_scale_2 = None + layer.w2_weight_scale_2 = None + if hasattr(layer, 'w13_input_scale'): + layer.w13_input_scale = None + if hasattr(layer, 'w2_input_scale'): + layer.w2_input_scale = None # Create the CuTeDSL runner self._runner = CuTeDSLMoERunner( @@ -165,9 +165,10 @@ class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular): device=str(device), experts_start_idx=self.local_expert_offset, ) - self._runner.prepare_weights_direct( - l1_fp4_list, l1_sf_list, l1_gs_list, - l2_fp4_list, l2_sf_list, l2_gs_list, + # Pass stacked tensors in checkpoint format (E, N, K) — no copies needed + self._runner.prepare_weights_from_stacked( + l1_fp4, w13_sf, l1_gs_list, + l2_fp4, w2_sf, l2_gs_list, ) if self._swiglu_limit is not None: self._runner.set_swiglu_limit(float(self._swiglu_limit))