Eliminate weight copies: pass stacked checkpoint tensors directly
Memory optimization for MoE weight processing: Before (3-4 copies of weights in memory): 1. Original checkpoint weights in layer.w13_weight (copy 1) 2. Per-expert permuted copies (copy 2) 3. torch.stack() in runner._ensure_stacked (copy 3) 4. make_b_k_major re-stride (copy 4) 5. Scales: permute then assemble_scales_3d_side un-permutes (wasted) After (1-2 copies): 1. View checkpoint as fp4 (NO copy — byte-preserving view) 2. Pass (E, N, K) stacked tensor directly to runner 3. Runner permutes to (E, K, N) contiguous (copy 1), frees stacked ref 4. make_b_k_major re-strides (copy 2), frees (E, K, N) ref 5. Scales: already (N, K_sf) from checkpoint, call assembly directly 6. Free layer.w13_weight etc. immediately after extracting views Also: assemble_scales_3d_side transposes (K_sf, N)→(N, K_sf) internally, but checkpoint scales are ALREADY (N, K_sf). Skip the double-transpose by calling assemble_raw_scales_2d3d_3d_side directly.
This commit is contained in:
@@ -190,18 +190,55 @@ class CuTeDSLMoERunner:
|
|||||||
if self._l1_mat_b is not None:
|
if self._l1_mat_b is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Stack and prepare weight tensors FIRST (triggers CuTeDSL JIT compilation)
|
# Convert weights to kernel format
|
||||||
self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4))
|
if hasattr(self, 'l1_fp4_stacked') and self.l1_fp4_stacked is not None:
|
||||||
self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
|
# Fast path: pre-stacked 3D tensors in checkpoint format (E, N, K)
|
||||||
self._l1_scale_b = assemble_scales_3d_side(self.l1_sf)
|
# Permute to (E, K, N) then make K-major
|
||||||
self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
|
l1_fp4_ekn = self.l1_fp4_stacked.permute(0, 2, 1).contiguous()
|
||||||
|
l2_fp4_ekn = self.l2_fp4_stacked.permute(0, 2, 1).contiguous()
|
||||||
|
# Free stacked checkpoints before make_b_k_major (saves one copy)
|
||||||
|
self.l1_fp4_stacked = None
|
||||||
|
self.l2_fp4_stacked = None
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
self._l1_mat_b = make_b_k_major(l1_fp4_ekn)
|
||||||
|
self._l2_mat_b = make_b_k_major(l2_fp4_ekn)
|
||||||
|
del l1_fp4_ekn, l2_fp4_ekn
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# Scales: checkpoint is (E, N, K_sf) — the kernel expects (N, K_sf)
|
||||||
|
# per expert for swizzle. Split into views (no copy), then assemble.
|
||||||
|
l1_sf_list = [self.l1_sf_stacked[i] for i in range(self.num_experts)]
|
||||||
|
l2_sf_list = [self.l2_sf_stacked[i] for i in range(self.num_experts)]
|
||||||
|
self.l1_sf_stacked = None
|
||||||
|
self.l2_sf_stacked = None
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# assemble_scales_3d_side expects (K_sf, N) per expert and transposes
|
||||||
|
# to (N, K_sf) internally. But our scales are already (N, K_sf) from
|
||||||
|
# the checkpoint! Skip the transpose by calling the assembly directly.
|
||||||
|
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
||||||
|
assemble_raw_scales_2d3d_3d_side,
|
||||||
|
)
|
||||||
|
self._l1_scale_b = assemble_raw_scales_2d3d_3d_side(l1_sf_list)
|
||||||
|
self._l2_scale_b = assemble_raw_scales_2d3d_3d_side(l2_sf_list)
|
||||||
|
del l1_sf_list, l2_sf_list
|
||||||
|
else:
|
||||||
|
# Legacy path: per-expert lists
|
||||||
|
self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4))
|
||||||
|
self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
|
||||||
|
self._l1_scale_b = assemble_scales_3d_side(self.l1_sf)
|
||||||
|
self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
|
||||||
|
self.l1_fp4 = None
|
||||||
|
self.l1_sf = None
|
||||||
|
self.l1_gs = None
|
||||||
|
self.l2_fp4 = None
|
||||||
|
self.l2_sf = None
|
||||||
|
self.l2_gs = None
|
||||||
|
|
||||||
self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)
|
self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)
|
||||||
self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device)
|
self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device)
|
||||||
self.l1_fp4 = None
|
|
||||||
self.l1_sf = None
|
|
||||||
self.l1_gs = None
|
self.l1_gs = None
|
||||||
self.l2_fp4 = None
|
|
||||||
self.l2_sf = None
|
|
||||||
self.l2_gs = None
|
self.l2_gs = None
|
||||||
|
|
||||||
# Allocate buffers AFTER JIT compilation
|
# Allocate buffers AFTER JIT compilation
|
||||||
@@ -236,6 +273,26 @@ class CuTeDSLMoERunner:
|
|||||||
self.l2_gs = l2_gs
|
self.l2_gs = l2_gs
|
||||||
self._l1_mat_b = None
|
self._l1_mat_b = None
|
||||||
|
|
||||||
|
def prepare_weights_from_stacked(self, l1_fp4_stacked, l1_sf_stacked,
|
||||||
|
l1_gs, l2_fp4_stacked, l2_sf_stacked,
|
||||||
|
l2_gs):
|
||||||
|
"""Prepare weights from pre-stacked 3D tensors (checkpoint format).
|
||||||
|
|
||||||
|
Takes (E, N, K_packed) fp4 and (E, N, K_sf) scale tensors directly
|
||||||
|
from the checkpoint, avoiding the per-expert list→stack round-trip.
|
||||||
|
|
||||||
|
The conversion to K-major and swizzled layout happens in _ensure_stacked.
|
||||||
|
This just stores the tensors for deferred processing.
|
||||||
|
"""
|
||||||
|
# Store in checkpoint format (E, N, K) — _ensure_stacked will convert
|
||||||
|
self.l1_fp4_stacked = l1_fp4_stacked
|
||||||
|
self.l1_sf_stacked = l1_sf_stacked
|
||||||
|
self.l1_gs = l1_gs
|
||||||
|
self.l2_fp4_stacked = l2_fp4_stacked
|
||||||
|
self.l2_sf_stacked = l2_sf_stacked
|
||||||
|
self.l2_gs = l2_gs
|
||||||
|
self._l1_mat_b = None
|
||||||
|
|
||||||
def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16):
|
def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16):
|
||||||
self.l1_fp4, self.l1_sf, self.l1_gs = [], [], []
|
self.l1_fp4, self.l1_sf, self.l1_gs = [], [], []
|
||||||
self.l2_fp4, self.l2_sf, self.l2_gs = [], [], []
|
self.l2_fp4, self.l2_sf, self.l2_gs = [], [], []
|
||||||
|
|||||||
@@ -114,46 +114,46 @@ class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular):
|
|||||||
else:
|
else:
|
||||||
w2_input_scale_orig = None
|
w2_input_scale_orig = None
|
||||||
|
|
||||||
# Extract and convert weights for CuTeDSL runner
|
# Extract weights from the layer — checkpoint format, no copies yet.
|
||||||
# w13_weight: (E, 2*intermediate, hidden//2) uint8 — gate + up fused
|
# w13_weight: (E, 2*intermediate, hidden//2) uint8 — gate + up fused
|
||||||
# w2_weight: (E, hidden, intermediate//2) uint8 — down
|
# w2_weight: (E, hidden, intermediate//2) uint8 — down
|
||||||
l1_fp4_list = []
|
# w13_weight_scale: (E, 2*intermediate, hidden//16) fp8
|
||||||
l1_sf_list = []
|
# w2_weight_scale: (E, hidden, intermediate//16) fp8
|
||||||
l1_gs_list = []
|
w13_uint8 = layer.w13_weight.data # (E, 2*inter, hidden//2)
|
||||||
l2_fp4_list = []
|
w2_uint8 = layer.w2_weight.data # (E, hidden, intermediate//2)
|
||||||
l2_sf_list = []
|
w13_sf = layer.w13_weight_scale.data # (E, 2*inter, hidden//16) = (E, N, K_sf)
|
||||||
l2_gs_list = []
|
w2_sf = layer.w2_weight_scale.data # (E, hidden, intermediate//16) = (E, N, K_sf)
|
||||||
|
w13_gs = layer.w13_weight_scale_2.data # (E,) or (E, 2)
|
||||||
|
w2_gs = layer.w2_weight_scale_2.data # (E,) or (E, 2)
|
||||||
|
|
||||||
for expert_id in range(num_experts):
|
# View as fp4 — byte-preserving, NO copy
|
||||||
# L1: gate + up (w13)
|
l1_fp4 = w13_uint8.view(torch.float4_e2m1fn_x2) # (E, N, K_packed)
|
||||||
w13_uint8 = layer.w13_weight.data[expert_id] # (2*inter, hidden//2)
|
l2_fp4 = w2_uint8.view(torch.float4_e2m1fn_x2) # (E, N, K_packed)
|
||||||
w13_sf = layer.w13_weight_scale.data[expert_id] # (2*inter, hidden//16) fp8
|
|
||||||
w13_gs = layer.w13_weight_scale_2.data[expert_id].item() # float32
|
|
||||||
|
|
||||||
# uint8 → float4_e2m1fn_x2, permute to (K_packed, N) for CuTeDSL
|
# Ensure scales are float8_e4m3fn (no copy if already correct dtype)
|
||||||
l1_w = w13_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
|
if w13_sf.dtype != torch.float8_e4m3fn:
|
||||||
# Block scales: (N, K_sf) → (K_sf, N) for CuTeDSL
|
w13_sf = w13_sf.to(torch.float8_e4m3fn)
|
||||||
l1_s = w13_sf.permute(1, 0).contiguous()
|
if w2_sf.dtype != torch.float8_e4m3fn:
|
||||||
if l1_s.dtype != torch.float8_e4m3fn:
|
w2_sf = w2_sf.to(torch.float8_e4m3fn)
|
||||||
l1_s = l1_s.to(torch.float8_e4m3fn)
|
|
||||||
|
|
||||||
l1_fp4_list.append(l1_w)
|
# Global scales
|
||||||
l1_sf_list.append(l1_s)
|
l1_gs_list = w13_gs.tolist()
|
||||||
l1_gs_list.append(w13_gs)
|
l2_gs_list = w2_gs.tolist()
|
||||||
|
|
||||||
# L2: down (w2)
|
# Free original weight tensors IMMEDIATELY.
|
||||||
w2_uint8 = layer.w2_weight.data[expert_id] # (hidden, intermediate//2)
|
# We have views into the same memory (l1_fp4, l2_fp4), but the runner
|
||||||
w2_sf = layer.w2_weight_scale.data[expert_id] # (hidden, intermediate//16) fp8
|
# will create its own copies in _ensure_stacked. Free the layer refs
|
||||||
w2_gs = layer.w2_weight_scale_2.data[expert_id].item() # float32
|
# now so the memory can be reclaimed when the views are no longer held.
|
||||||
|
layer.w13_weight = None
|
||||||
l2_w = w2_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
|
layer.w2_weight = None
|
||||||
l2_s = w2_sf.permute(1, 0).contiguous()
|
layer.w13_weight_scale = None
|
||||||
if l2_s.dtype != torch.float8_e4m3fn:
|
layer.w2_weight_scale = None
|
||||||
l2_s = l2_s.to(torch.float8_e4m3fn)
|
layer.w13_weight_scale_2 = None
|
||||||
|
layer.w2_weight_scale_2 = None
|
||||||
l2_fp4_list.append(l2_w)
|
if hasattr(layer, 'w13_input_scale'):
|
||||||
l2_sf_list.append(l2_s)
|
layer.w13_input_scale = None
|
||||||
l2_gs_list.append(w2_gs)
|
if hasattr(layer, 'w2_input_scale'):
|
||||||
|
layer.w2_input_scale = None
|
||||||
|
|
||||||
# Create the CuTeDSL runner
|
# Create the CuTeDSL runner
|
||||||
self._runner = CuTeDSLMoERunner(
|
self._runner = CuTeDSLMoERunner(
|
||||||
@@ -165,9 +165,10 @@ class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular):
|
|||||||
device=str(device),
|
device=str(device),
|
||||||
experts_start_idx=self.local_expert_offset,
|
experts_start_idx=self.local_expert_offset,
|
||||||
)
|
)
|
||||||
self._runner.prepare_weights_direct(
|
# Pass stacked tensors in checkpoint format (E, N, K) — no copies needed
|
||||||
l1_fp4_list, l1_sf_list, l1_gs_list,
|
self._runner.prepare_weights_from_stacked(
|
||||||
l2_fp4_list, l2_sf_list, l2_gs_list,
|
l1_fp4, w13_sf, l1_gs_list,
|
||||||
|
l2_fp4, w2_sf, l2_gs_list,
|
||||||
)
|
)
|
||||||
if self._swiglu_limit is not None:
|
if self._swiglu_limit is not None:
|
||||||
self._runner.set_swiglu_limit(float(self._swiglu_limit))
|
self._runner.set_swiglu_limit(float(self._swiglu_limit))
|
||||||
|
|||||||
Reference in New Issue
Block a user