fix: separate L1/L2 scale buffers (different K_sf), fix assembly calls

This commit is contained in:
2026-05-17 07:43:05 +00:00
parent b824b838a9
commit 37fecb588f
2 changed files with 36 additions and 20 deletions

View File

@@ -70,7 +70,10 @@ def test_scale_assembly():
# Path 2: _assemble_scales_cudagraph_safe (GPU-only)
expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32, device=DEVICE)
expert_offsets[1:] = torch.tensor(tokens_per_expert, dtype=torch.int32).cumsum(0)
scale_a_cudagraph = runner._assemble_scales_cudagraph_safe(x_sf, expert_offsets)
scale_a_cudagraph = runner._assemble_scales_cudagraph_safe(
x_sf, expert_offsets,
runner._padded_x_sf_buf_l1, runner._per_expert_scale_bufs_l1
)
# Compare
# Note: shapes may differ due to padding, but the data in the

View File

@@ -69,8 +69,10 @@ class CuTeDSLMoERunner:
self._token_indices = None
self._expert_id_range = None
self._expert_offsets_buf = None
self._per_expert_scale_bufs = None
self._padded_x_sf_buf = None
self._per_expert_scale_bufs_l1 = None
self._per_expert_scale_bufs_l2 = None
self._padded_x_sf_buf_l1 = None
self._padded_x_sf_buf_l2 = None
self._buffers_allocated = False
def _allocate_buffers(self):
@@ -90,17 +92,27 @@ class CuTeDSLMoERunner:
self.num_experts + 1, dtype=torch.int32, device=self.device
)
# Per-expert scale buffers: each expert gets a 128-row block
# This matches assemble_scales_2d_side which pads+swizzles each expert independently
self._per_expert_scale_bufs = [
torch.zeros(128, padded_cols, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn)
# Per-expert scale buffers: separate L1/L2 since K_sf differs
K_sf_l1 = cutedsl_ceil_div(self.hidden_size, 16)
padded_cols_l1 = cutedsl_ceil_div(K_sf_l1, 4) * 4
K_sf_l2 = cutedsl_ceil_div(self.intermediate_size, 16)
padded_cols_l2 = cutedsl_ceil_div(K_sf_l2, 4) * 4
self._per_expert_scale_bufs_l1 = [
torch.zeros(128, padded_cols_l1, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn)
for _ in range(self.num_experts)
]
self._per_expert_scale_bufs_l2 = [
torch.zeros(128, padded_cols_l2, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn)
for _ in range(self.num_experts)
]
# Padded x_sf buffer: num_experts * 128 rows so that fixed-shape slices
# x_sf[start:start+128] always have 128 rows (extra rows are zeros)
self._padded_x_sf_buf = torch.zeros(
self.num_experts * 128, padded_cols, dtype=torch.float16, device=self.device
# Padded x_sf buffers: num_experts * 128 rows, separate L1/L2
self._padded_x_sf_buf_l1 = torch.zeros(
self.num_experts * 128, padded_cols_l1, dtype=torch.float16, device=self.device
).to(torch.float8_e4m3fn)
self._padded_x_sf_buf_l2 = torch.zeros(
self.num_experts * 128, padded_cols_l2, dtype=torch.float16, device=self.device
).to(torch.float8_e4m3fn)
self._buffers_allocated = True
@@ -147,21 +159,20 @@ class CuTeDSLMoERunner:
self.l2_gs.append(w_gs)
self._l1_mat_b = None
def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets):
def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets,
padded_x_sf_buf, per_expert_bufs):
"""Assemble 2D-side activation scales (cudagraph-safe, no CPU sync).
Matches the working assemble_scales_2d_side: pads each expert's scales
to 128 rows, swizzles each expert block independently, then concatenates.
No .item(), no .tolist(), no Python control flow on GPU data.
Each expert's data is placed at 128-row-aligned offsets in padded_x_sf
so that a fixed 128-row slice always contains only that expert's data.
Each expert's data is placed at 128-row-aligned offsets so that a
fixed 128-row slice always contains only that expert's data.
"""
num_experts = self.num_experts
K_sf = x_sf.shape[1]
# Zero the padded buffer, then scatter each expert's rows at 128-aligned offsets
padded_x_sf = self._padded_x_sf_buf
padded_x_sf = padded_x_sf_buf
padded_x_sf.zero_()
for e in range(num_experts):
start = expert_offsets[e]
@@ -173,7 +184,7 @@ class CuTeDSLMoERunner:
# For each expert: zero the per-expert buf, copy from padded, swizzle
swizzled_parts = []
for e in range(num_experts):
buf = self._per_expert_scale_bufs[e]
buf = per_expert_bufs[e]
buf.zero_()
# Copy 128 rows starting at this expert's aligned offset
buf[:, :K_sf] = padded_x_sf[e * 128:e * 128 + 128]
@@ -243,7 +254,8 @@ class CuTeDSLMoERunner:
)
l1_scale_a = self._assemble_scales_cudagraph_safe(
x_sf, expert_offsets[:self.num_experts + 1]
x_sf, expert_offsets[:self.num_experts + 1],
self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
)
l1_gsa = torch.full(
(self.num_experts,), self._l1_activation_global_scale,
@@ -268,7 +280,8 @@ class CuTeDSLMoERunner:
)
l2_scale_a = self._assemble_scales_cudagraph_safe(
l2_x_sf, expert_offsets[:self.num_experts + 1]
l2_x_sf, expert_offsets[:self.num_experts + 1],
self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2
)
l2_gsa = torch.full(
(self.num_experts,), self._l2_activation_global_scale,