Fix GEMM scale layout: pad to 128 tokens per expert

Root cause of garbage output: the GEMM reads scale_a according to
expert_offsets (e.g. [0, 500, 1024, ...]) but scale_a had data at
fixed e*128 offsets. When expert 0 has 500 tokens, the GEMM reads
scale_a[0:500] but only rows 0-127 had valid data.

Fix: pad slot_hidden to num_experts*128 rows (128 per expert) and
pass padded_expert_offsets=[0, 128, 256, ...] to the GEMM. Scale
assembly's fixed 128-row layout now matches the GEMM's expectations.
Padding tokens' GEMM output is discarded (scatter_add only uses
sorted_token_ids for real tokens).
This commit is contained in:
2026-05-17 13:19:31 +00:00
parent 7e692c3aec
commit bde81b95f4

View File

@@ -80,6 +80,9 @@ class CuTeDSLMoERunner:
self._l2_gsa_buf = None
self._output_buf = None
self._row_indices_buf = None
self._padded_hidden_buf = None
self._padded_activated_buf = None
self._padded_expert_offsets_buf = None
self._buffers_allocated = False
def _fill_token_indices(self):
@@ -132,6 +135,23 @@ class CuTeDSLMoERunner:
self.max_num_tokens * self.top_k, device=self.device
)
# Padded hidden/activated buffers: num_experts * 128 rows
padded_size = self.num_experts * 128
self._padded_hidden_buf = torch.zeros(
padded_size, self.hidden_size, dtype=torch.bfloat16, device=self.device
)
self._padded_activated_buf = torch.zeros(
padded_size, self.intermediate_size, dtype=torch.bfloat16, device=self.device
)
# Padded expert offsets: [0, 128, 256, ...] (fixed, 128 tokens per expert)
self._padded_expert_offsets_buf = torch.zeros(
self.num_experts + 1, dtype=torch.int32, device=self.device
)
self._padded_expert_offsets_buf[1:] = torch.arange(
1, self.num_experts + 1, dtype=torch.int32, device=self.device
) * 128
self._buffers_allocated = True
def _ensure_stacked(self):
@@ -333,6 +353,12 @@ class CuTeDSLMoERunner:
(harmless — zero-weighted output contributes nothing to scatter_add).
Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes.
PADDING STRATEGY: Each expert is padded to exactly 128 tokens in
the GEMM input. This ensures scale_a has a fixed layout (128 rows
per expert) that matches the padded expert_offsets. The GEMM output
for padding tokens is discarded by scatter_add (sorted_token_ids
only refers to real tokens).
"""
num_tokens = hidden_states.shape[0]
top_k = topk_ids.shape[1]
@@ -341,8 +367,6 @@ class CuTeDSLMoERunner:
self._ensure_stacked()
# -- Remap global expert IDs to local IDs --
# topk_ids are global: remap by subtracting experts_start_idx.
# Tokens for non-local experts get clamped to 0 with zero weight.
local_ids = topk_ids - self.experts_start_idx
local_mask = (local_ids >= 0) & (local_ids < self.num_experts)
safe_ids = local_ids.clamp(0, self.num_experts - 1)
@@ -366,16 +390,28 @@ class CuTeDSLMoERunner:
expert_offsets.zero_()
expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)
# -- Gather hidden states into slot order --
# -- Pad to 128 tokens per expert --
# The GEMM and scale assembly expect a fixed layout where each expert
# gets exactly 128 rows. Pad slot_hidden to num_experts * 128 rows,
# and set expert_offsets to [0, 128, 256, ...].
padded_expert_offsets = self._padded_expert_offsets_buf
# Already filled in _allocate_buffers with [0, 128, 256, ...]
# Gather hidden states into slot order, then pad to 128 per expert
slot_hidden = hidden_states[sorted_token_ids]
padded_size = self.num_experts * 128
padded_hidden = self._padded_hidden_buf[:padded_size]
padded_hidden.zero_()
padded_hidden[:num_slots] = slot_hidden
# === L1: gate + up ===
x_fp4, x_sf = quantize_activation_nvfp4(
slot_hidden, self._l1_activation_global_scale
padded_hidden, self._l1_activation_global_scale
)
l1_scale_a = self._assemble_scales_cudagraph_safe(
x_sf, expert_offsets[:self.num_experts + 1],
x_sf[:num_slots], # Only real tokens have scale data
expert_offsets[:self.num_experts + 1],
self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
)
l1_gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale)
@@ -383,22 +419,31 @@ class CuTeDSLMoERunner:
l1_out = run_nvfp4_grouped_gemm(
mat_a=x_fp4, mat_b=self._l1_mat_b,
scale_a=l1_scale_a, scale_b=self._l1_scale_b,
expert_offsets=expert_offsets[1:self.num_experts + 1],
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
)
# Only keep real token outputs (padding rows are garbage)
l1_out = l1_out[:num_slots]
# === SiLU(gate) * up ===
gate = l1_out[:, :self.intermediate_size]
up = l1_out[:, self.intermediate_size:]
activated = torch.nn.functional.silu(gate) * up
# === L2: down ===
# Pad activated to 128 tokens per expert for L2 GEMM
padded_activated = self._padded_activated_buf[:padded_size]
padded_activated.zero_()
padded_activated[:num_slots] = activated
l2_x_fp4, l2_x_sf = quantize_activation_nvfp4(
activated, self._l2_activation_global_scale
padded_activated, self._l2_activation_global_scale
)
l2_scale_a = self._assemble_scales_cudagraph_safe(
l2_x_sf, expert_offsets[:self.num_experts + 1],
l2_x_sf[:num_slots],
expert_offsets[:self.num_experts + 1],
self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2
)
l2_gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale)
@@ -406,10 +451,13 @@ class CuTeDSLMoERunner:
l2_out = run_nvfp4_grouped_gemm(
mat_a=l2_x_fp4, mat_b=self._l2_mat_b,
scale_a=l2_scale_a, scale_b=self._l2_scale_b,
expert_offsets=expert_offsets[1:self.num_experts + 1],
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,
)
# Only keep real token outputs
l2_out = l2_out[:num_slots]
# === Scatter -> final output ===
y = self._output_buf[:num_tokens]
y.zero_()