Eliminate weight copies: pass stacked checkpoint tensors directly
Memory optimization for MoE weight processing: Before (3-4 copies of weights in memory): 1. Original checkpoint weights in layer.w13_weight (copy 1) 2. Per-expert permuted copies (copy 2) 3. torch.stack() in runner._ensure_stacked (copy 3) 4. make_b_k_major re-stride (copy 4) 5. Scales: permute then assemble_scales_3d_side un-permutes (wasted) After (1-2 copies): 1. View checkpoint as fp4 (NO copy — byte-preserving view) 2. Pass (E, N, K) stacked tensor directly to runner 3. Runner permutes to (E, K, N) contiguous (copy 1), frees stacked ref 4. make_b_k_major re-strides (copy 2), frees (E, K, N) ref 5. Scales: already (N, K_sf) from checkpoint, call assembly directly 6. Free layer.w13_weight etc. immediately after extracting views Also: assemble_scales_3d_side transposes (K_sf, N)→(N, K_sf) internally, but checkpoint scales are ALREADY (N, K_sf). Skip the double-transpose by calling assemble_raw_scales_2d3d_3d_side directly.
This commit is contained in:
@@ -190,18 +190,55 @@ class CuTeDSLMoERunner:
|
||||
if self._l1_mat_b is not None:
|
||||
return
|
||||
|
||||
# Stack and prepare weight tensors FIRST (triggers CuTeDSL JIT compilation)
|
||||
self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4))
|
||||
self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
|
||||
self._l1_scale_b = assemble_scales_3d_side(self.l1_sf)
|
||||
self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
|
||||
# Convert weights to kernel format
|
||||
if hasattr(self, 'l1_fp4_stacked') and self.l1_fp4_stacked is not None:
|
||||
# Fast path: pre-stacked 3D tensors in checkpoint format (E, N, K)
|
||||
# Permute to (E, K, N) then make K-major
|
||||
l1_fp4_ekn = self.l1_fp4_stacked.permute(0, 2, 1).contiguous()
|
||||
l2_fp4_ekn = self.l2_fp4_stacked.permute(0, 2, 1).contiguous()
|
||||
# Free stacked checkpoints before make_b_k_major (saves one copy)
|
||||
self.l1_fp4_stacked = None
|
||||
self.l2_fp4_stacked = None
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self._l1_mat_b = make_b_k_major(l1_fp4_ekn)
|
||||
self._l2_mat_b = make_b_k_major(l2_fp4_ekn)
|
||||
del l1_fp4_ekn, l2_fp4_ekn
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Scales: checkpoint is (E, N, K_sf) — the kernel expects (N, K_sf)
|
||||
# per expert for swizzle. Split into views (no copy), then assemble.
|
||||
l1_sf_list = [self.l1_sf_stacked[i] for i in range(self.num_experts)]
|
||||
l2_sf_list = [self.l2_sf_stacked[i] for i in range(self.num_experts)]
|
||||
self.l1_sf_stacked = None
|
||||
self.l2_sf_stacked = None
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# assemble_scales_3d_side expects (K_sf, N) per expert and transposes
|
||||
# to (N, K_sf) internally. But our scales are already (N, K_sf) from
|
||||
# the checkpoint! Skip the transpose by calling the assembly directly.
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
||||
assemble_raw_scales_2d3d_3d_side,
|
||||
)
|
||||
self._l1_scale_b = assemble_raw_scales_2d3d_3d_side(l1_sf_list)
|
||||
self._l2_scale_b = assemble_raw_scales_2d3d_3d_side(l2_sf_list)
|
||||
del l1_sf_list, l2_sf_list
|
||||
else:
|
||||
# Legacy path: per-expert lists
|
||||
self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4))
|
||||
self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
|
||||
self._l1_scale_b = assemble_scales_3d_side(self.l1_sf)
|
||||
self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
|
||||
self.l1_fp4 = None
|
||||
self.l1_sf = None
|
||||
self.l1_gs = None
|
||||
self.l2_fp4 = None
|
||||
self.l2_sf = None
|
||||
self.l2_gs = None
|
||||
|
||||
self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)
|
||||
self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device)
|
||||
self.l1_fp4 = None
|
||||
self.l1_sf = None
|
||||
self.l1_gs = None
|
||||
self.l2_fp4 = None
|
||||
self.l2_sf = None
|
||||
self.l2_gs = None
|
||||
|
||||
# Allocate buffers AFTER JIT compilation
|
||||
@@ -236,6 +273,26 @@ class CuTeDSLMoERunner:
|
||||
self.l2_gs = l2_gs
|
||||
self._l1_mat_b = None
|
||||
|
||||
def prepare_weights_from_stacked(self, l1_fp4_stacked, l1_sf_stacked,
|
||||
l1_gs, l2_fp4_stacked, l2_sf_stacked,
|
||||
l2_gs):
|
||||
"""Prepare weights from pre-stacked 3D tensors (checkpoint format).
|
||||
|
||||
Takes (E, N, K_packed) fp4 and (E, N, K_sf) scale tensors directly
|
||||
from the checkpoint, avoiding the per-expert list→stack round-trip.
|
||||
|
||||
The conversion to K-major and swizzled layout happens in _ensure_stacked.
|
||||
This just stores the tensors for deferred processing.
|
||||
"""
|
||||
# Store in checkpoint format (E, N, K) — _ensure_stacked will convert
|
||||
self.l1_fp4_stacked = l1_fp4_stacked
|
||||
self.l1_sf_stacked = l1_sf_stacked
|
||||
self.l1_gs = l1_gs
|
||||
self.l2_fp4_stacked = l2_fp4_stacked
|
||||
self.l2_sf_stacked = l2_sf_stacked
|
||||
self.l2_gs = l2_gs
|
||||
self._l1_mat_b = None
|
||||
|
||||
def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16):
|
||||
self.l1_fp4, self.l1_sf, self.l1_gs = [], [], []
|
||||
self.l2_fp4, self.l2_sf, self.l2_gs = [], [], []
|
||||
|
||||
@@ -114,46 +114,46 @@ class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular):
|
||||
else:
|
||||
w2_input_scale_orig = None
|
||||
|
||||
# Extract and convert weights for CuTeDSL runner
|
||||
# Extract weights from the layer — checkpoint format, no copies yet.
|
||||
# w13_weight: (E, 2*intermediate, hidden//2) uint8 — gate + up fused
|
||||
# w2_weight: (E, hidden, intermediate//2) uint8 — down
|
||||
l1_fp4_list = []
|
||||
l1_sf_list = []
|
||||
l1_gs_list = []
|
||||
l2_fp4_list = []
|
||||
l2_sf_list = []
|
||||
l2_gs_list = []
|
||||
# w13_weight_scale: (E, 2*intermediate, hidden//16) fp8
|
||||
# w2_weight_scale: (E, hidden, intermediate//16) fp8
|
||||
w13_uint8 = layer.w13_weight.data # (E, 2*inter, hidden//2)
|
||||
w2_uint8 = layer.w2_weight.data # (E, hidden, intermediate//2)
|
||||
w13_sf = layer.w13_weight_scale.data # (E, 2*inter, hidden//16) = (E, N, K_sf)
|
||||
w2_sf = layer.w2_weight_scale.data # (E, hidden, intermediate//16) = (E, N, K_sf)
|
||||
w13_gs = layer.w13_weight_scale_2.data # (E,) or (E, 2)
|
||||
w2_gs = layer.w2_weight_scale_2.data # (E,) or (E, 2)
|
||||
|
||||
for expert_id in range(num_experts):
|
||||
# L1: gate + up (w13)
|
||||
w13_uint8 = layer.w13_weight.data[expert_id] # (2*inter, hidden//2)
|
||||
w13_sf = layer.w13_weight_scale.data[expert_id] # (2*inter, hidden//16) fp8
|
||||
w13_gs = layer.w13_weight_scale_2.data[expert_id].item() # float32
|
||||
# View as fp4 — byte-preserving, NO copy
|
||||
l1_fp4 = w13_uint8.view(torch.float4_e2m1fn_x2) # (E, N, K_packed)
|
||||
l2_fp4 = w2_uint8.view(torch.float4_e2m1fn_x2) # (E, N, K_packed)
|
||||
|
||||
# uint8 → float4_e2m1fn_x2, permute to (K_packed, N) for CuTeDSL
|
||||
l1_w = w13_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
|
||||
# Block scales: (N, K_sf) → (K_sf, N) for CuTeDSL
|
||||
l1_s = w13_sf.permute(1, 0).contiguous()
|
||||
if l1_s.dtype != torch.float8_e4m3fn:
|
||||
l1_s = l1_s.to(torch.float8_e4m3fn)
|
||||
# Ensure scales are float8_e4m3fn (no copy if already correct dtype)
|
||||
if w13_sf.dtype != torch.float8_e4m3fn:
|
||||
w13_sf = w13_sf.to(torch.float8_e4m3fn)
|
||||
if w2_sf.dtype != torch.float8_e4m3fn:
|
||||
w2_sf = w2_sf.to(torch.float8_e4m3fn)
|
||||
|
||||
l1_fp4_list.append(l1_w)
|
||||
l1_sf_list.append(l1_s)
|
||||
l1_gs_list.append(w13_gs)
|
||||
# Global scales
|
||||
l1_gs_list = w13_gs.tolist()
|
||||
l2_gs_list = w2_gs.tolist()
|
||||
|
||||
# L2: down (w2)
|
||||
w2_uint8 = layer.w2_weight.data[expert_id] # (hidden, intermediate//2)
|
||||
w2_sf = layer.w2_weight_scale.data[expert_id] # (hidden, intermediate//16) fp8
|
||||
w2_gs = layer.w2_weight_scale_2.data[expert_id].item() # float32
|
||||
|
||||
l2_w = w2_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
|
||||
l2_s = w2_sf.permute(1, 0).contiguous()
|
||||
if l2_s.dtype != torch.float8_e4m3fn:
|
||||
l2_s = l2_s.to(torch.float8_e4m3fn)
|
||||
|
||||
l2_fp4_list.append(l2_w)
|
||||
l2_sf_list.append(l2_s)
|
||||
l2_gs_list.append(w2_gs)
|
||||
# Free original weight tensors IMMEDIATELY.
|
||||
# We have views into the same memory (l1_fp4, l2_fp4), but the runner
|
||||
# will create its own copies in _ensure_stacked. Free the layer refs
|
||||
# now so the memory can be reclaimed when the views are no longer held.
|
||||
layer.w13_weight = None
|
||||
layer.w2_weight = None
|
||||
layer.w13_weight_scale = None
|
||||
layer.w2_weight_scale = None
|
||||
layer.w13_weight_scale_2 = None
|
||||
layer.w2_weight_scale_2 = None
|
||||
if hasattr(layer, 'w13_input_scale'):
|
||||
layer.w13_input_scale = None
|
||||
if hasattr(layer, 'w2_input_scale'):
|
||||
layer.w2_input_scale = None
|
||||
|
||||
# Create the CuTeDSL runner
|
||||
self._runner = CuTeDSLMoERunner(
|
||||
@@ -165,9 +165,10 @@ class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular):
|
||||
device=str(device),
|
||||
experts_start_idx=self.local_expert_offset,
|
||||
)
|
||||
self._runner.prepare_weights_direct(
|
||||
l1_fp4_list, l1_sf_list, l1_gs_list,
|
||||
l2_fp4_list, l2_sf_list, l2_gs_list,
|
||||
# Pass stacked tensors in checkpoint format (E, N, K) — no copies needed
|
||||
self._runner.prepare_weights_from_stacked(
|
||||
l1_fp4, w13_sf, l1_gs_list,
|
||||
l2_fp4, w2_sf, l2_gs_list,
|
||||
)
|
||||
if self._swiglu_limit is not None:
|
||||
self._runner.set_swiglu_limit(float(self._swiglu_limit))
|
||||
|
||||
Reference in New Issue
Block a user