fix: separate L1/L2 scale buffers (different K_sf), fix assembly calls
This commit is contained in:
@@ -70,7 +70,10 @@ def test_scale_assembly():
|
||||
# Path 2: _assemble_scales_cudagraph_safe (GPU-only)
|
||||
expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32, device=DEVICE)
|
||||
expert_offsets[1:] = torch.tensor(tokens_per_expert, dtype=torch.int32).cumsum(0)
|
||||
scale_a_cudagraph = runner._assemble_scales_cudagraph_safe(x_sf, expert_offsets)
|
||||
scale_a_cudagraph = runner._assemble_scales_cudagraph_safe(
|
||||
x_sf, expert_offsets,
|
||||
runner._padded_x_sf_buf_l1, runner._per_expert_scale_bufs_l1
|
||||
)
|
||||
|
||||
# Compare
|
||||
# Note: shapes may differ due to padding, but the data in the
|
||||
|
||||
@@ -69,8 +69,10 @@ class CuTeDSLMoERunner:
|
||||
self._token_indices = None
|
||||
self._expert_id_range = None
|
||||
self._expert_offsets_buf = None
|
||||
self._per_expert_scale_bufs = None
|
||||
self._padded_x_sf_buf = None
|
||||
self._per_expert_scale_bufs_l1 = None
|
||||
self._per_expert_scale_bufs_l2 = None
|
||||
self._padded_x_sf_buf_l1 = None
|
||||
self._padded_x_sf_buf_l2 = None
|
||||
self._buffers_allocated = False
|
||||
|
||||
def _allocate_buffers(self):
|
||||
@@ -90,17 +92,27 @@ class CuTeDSLMoERunner:
|
||||
self.num_experts + 1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
# Per-expert scale buffers: each expert gets a 128-row block
|
||||
# This matches assemble_scales_2d_side which pads+swizzles each expert independently
|
||||
self._per_expert_scale_bufs = [
|
||||
torch.zeros(128, padded_cols, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn)
|
||||
# Per-expert scale buffers: separate L1/L2 since K_sf differs
|
||||
K_sf_l1 = cutedsl_ceil_div(self.hidden_size, 16)
|
||||
padded_cols_l1 = cutedsl_ceil_div(K_sf_l1, 4) * 4
|
||||
K_sf_l2 = cutedsl_ceil_div(self.intermediate_size, 16)
|
||||
padded_cols_l2 = cutedsl_ceil_div(K_sf_l2, 4) * 4
|
||||
|
||||
self._per_expert_scale_bufs_l1 = [
|
||||
torch.zeros(128, padded_cols_l1, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn)
|
||||
for _ in range(self.num_experts)
|
||||
]
|
||||
self._per_expert_scale_bufs_l2 = [
|
||||
torch.zeros(128, padded_cols_l2, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn)
|
||||
for _ in range(self.num_experts)
|
||||
]
|
||||
|
||||
# Padded x_sf buffer: num_experts * 128 rows so that fixed-shape slices
|
||||
# x_sf[start:start+128] always have 128 rows (extra rows are zeros)
|
||||
self._padded_x_sf_buf = torch.zeros(
|
||||
self.num_experts * 128, padded_cols, dtype=torch.float16, device=self.device
|
||||
# Padded x_sf buffers: num_experts * 128 rows, separate L1/L2
|
||||
self._padded_x_sf_buf_l1 = torch.zeros(
|
||||
self.num_experts * 128, padded_cols_l1, dtype=torch.float16, device=self.device
|
||||
).to(torch.float8_e4m3fn)
|
||||
self._padded_x_sf_buf_l2 = torch.zeros(
|
||||
self.num_experts * 128, padded_cols_l2, dtype=torch.float16, device=self.device
|
||||
).to(torch.float8_e4m3fn)
|
||||
|
||||
self._buffers_allocated = True
|
||||
@@ -147,21 +159,20 @@ class CuTeDSLMoERunner:
|
||||
self.l2_gs.append(w_gs)
|
||||
self._l1_mat_b = None
|
||||
|
||||
def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets):
|
||||
def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets,
|
||||
padded_x_sf_buf, per_expert_bufs):
|
||||
"""Assemble 2D-side activation scales (cudagraph-safe, no CPU sync).
|
||||
|
||||
Matches the working assemble_scales_2d_side: pads each expert's scales
|
||||
to 128 rows, swizzles each expert block independently, then concatenates.
|
||||
No .item(), no .tolist(), no Python control flow on GPU data.
|
||||
|
||||
Each expert's data is placed at 128-row-aligned offsets in padded_x_sf
|
||||
so that a fixed 128-row slice always contains only that expert's data.
|
||||
Each expert's data is placed at 128-row-aligned offsets so that a
|
||||
fixed 128-row slice always contains only that expert's data.
|
||||
"""
|
||||
num_experts = self.num_experts
|
||||
K_sf = x_sf.shape[1]
|
||||
|
||||
# Zero the padded buffer, then scatter each expert's rows at 128-aligned offsets
|
||||
padded_x_sf = self._padded_x_sf_buf
|
||||
padded_x_sf = padded_x_sf_buf
|
||||
padded_x_sf.zero_()
|
||||
for e in range(num_experts):
|
||||
start = expert_offsets[e]
|
||||
@@ -173,7 +184,7 @@ class CuTeDSLMoERunner:
|
||||
# For each expert: zero the per-expert buf, copy from padded, swizzle
|
||||
swizzled_parts = []
|
||||
for e in range(num_experts):
|
||||
buf = self._per_expert_scale_bufs[e]
|
||||
buf = per_expert_bufs[e]
|
||||
buf.zero_()
|
||||
# Copy 128 rows starting at this expert's aligned offset
|
||||
buf[:, :K_sf] = padded_x_sf[e * 128:e * 128 + 128]
|
||||
@@ -243,7 +254,8 @@ class CuTeDSLMoERunner:
|
||||
)
|
||||
|
||||
l1_scale_a = self._assemble_scales_cudagraph_safe(
|
||||
x_sf, expert_offsets[:self.num_experts + 1]
|
||||
x_sf, expert_offsets[:self.num_experts + 1],
|
||||
self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
|
||||
)
|
||||
l1_gsa = torch.full(
|
||||
(self.num_experts,), self._l1_activation_global_scale,
|
||||
@@ -268,7 +280,8 @@ class CuTeDSLMoERunner:
|
||||
)
|
||||
|
||||
l2_scale_a = self._assemble_scales_cudagraph_safe(
|
||||
l2_x_sf, expert_offsets[:self.num_experts + 1]
|
||||
l2_x_sf, expert_offsets[:self.num_experts + 1],
|
||||
self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2
|
||||
)
|
||||
l2_gsa = torch.full(
|
||||
(self.num_experts,), self._l2_activation_global_scale,
|
||||
|
||||
Reference in New Issue
Block a user