Fix GEMM scale layout: pad to 128 tokens per expert
Root cause of garbage output: the GEMM reads scale_a according to expert_offsets (e.g. [0, 500, 1024, ...]) but scale_a had data at fixed e*128 offsets. When expert 0 has 500 tokens, the GEMM reads scale_a[0:500] but only rows 0-127 had valid data. Fix: pad slot_hidden to num_experts*128 rows (128 per expert) and pass padded_expert_offsets=[0, 128, 256, ...] to the GEMM. Scale assembly's fixed 128-row layout now matches the GEMM's expectations. Padding tokens' GEMM output is discarded (scatter_add only uses sorted_token_ids for real tokens).
This commit is contained in:
@@ -80,6 +80,9 @@ class CuTeDSLMoERunner:
|
||||
self._l2_gsa_buf = None
|
||||
self._output_buf = None
|
||||
self._row_indices_buf = None
|
||||
self._padded_hidden_buf = None
|
||||
self._padded_activated_buf = None
|
||||
self._padded_expert_offsets_buf = None
|
||||
self._buffers_allocated = False
|
||||
|
||||
def _fill_token_indices(self):
|
||||
@@ -132,6 +135,23 @@ class CuTeDSLMoERunner:
|
||||
self.max_num_tokens * self.top_k, device=self.device
|
||||
)
|
||||
|
||||
# Padded hidden/activated buffers: num_experts * 128 rows
|
||||
padded_size = self.num_experts * 128
|
||||
self._padded_hidden_buf = torch.zeros(
|
||||
padded_size, self.hidden_size, dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
self._padded_activated_buf = torch.zeros(
|
||||
padded_size, self.intermediate_size, dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
|
||||
# Padded expert offsets: [0, 128, 256, ...] (fixed, 128 tokens per expert)
|
||||
self._padded_expert_offsets_buf = torch.zeros(
|
||||
self.num_experts + 1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self._padded_expert_offsets_buf[1:] = torch.arange(
|
||||
1, self.num_experts + 1, dtype=torch.int32, device=self.device
|
||||
) * 128
|
||||
|
||||
self._buffers_allocated = True
|
||||
|
||||
def _ensure_stacked(self):
|
||||
@@ -333,6 +353,12 @@ class CuTeDSLMoERunner:
|
||||
(harmless — zero-weighted output contributes nothing to scatter_add).
|
||||
|
||||
Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes.
|
||||
|
||||
PADDING STRATEGY: Each expert is padded to exactly 128 tokens in
|
||||
the GEMM input. This ensures scale_a has a fixed layout (128 rows
|
||||
per expert) that matches the padded expert_offsets. The GEMM output
|
||||
for padding tokens is discarded by scatter_add (sorted_token_ids
|
||||
only refers to real tokens).
|
||||
"""
|
||||
num_tokens = hidden_states.shape[0]
|
||||
top_k = topk_ids.shape[1]
|
||||
@@ -341,8 +367,6 @@ class CuTeDSLMoERunner:
|
||||
self._ensure_stacked()
|
||||
|
||||
# -- Remap global expert IDs to local IDs --
|
||||
# topk_ids are global: remap by subtracting experts_start_idx.
|
||||
# Tokens for non-local experts get clamped to 0 with zero weight.
|
||||
local_ids = topk_ids - self.experts_start_idx
|
||||
local_mask = (local_ids >= 0) & (local_ids < self.num_experts)
|
||||
safe_ids = local_ids.clamp(0, self.num_experts - 1)
|
||||
@@ -366,16 +390,28 @@ class CuTeDSLMoERunner:
|
||||
expert_offsets.zero_()
|
||||
expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)
|
||||
|
||||
# -- Gather hidden states into slot order --
|
||||
# -- Pad to 128 tokens per expert --
|
||||
# The GEMM and scale assembly expect a fixed layout where each expert
|
||||
# gets exactly 128 rows. Pad slot_hidden to num_experts * 128 rows,
|
||||
# and set expert_offsets to [0, 128, 256, ...].
|
||||
padded_expert_offsets = self._padded_expert_offsets_buf
|
||||
# Already filled in _allocate_buffers with [0, 128, 256, ...]
|
||||
|
||||
# Gather hidden states into slot order, then pad to 128 per expert
|
||||
slot_hidden = hidden_states[sorted_token_ids]
|
||||
padded_size = self.num_experts * 128
|
||||
padded_hidden = self._padded_hidden_buf[:padded_size]
|
||||
padded_hidden.zero_()
|
||||
padded_hidden[:num_slots] = slot_hidden
|
||||
|
||||
# === L1: gate + up ===
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(
|
||||
slot_hidden, self._l1_activation_global_scale
|
||||
padded_hidden, self._l1_activation_global_scale
|
||||
)
|
||||
|
||||
l1_scale_a = self._assemble_scales_cudagraph_safe(
|
||||
x_sf, expert_offsets[:self.num_experts + 1],
|
||||
x_sf[:num_slots], # Only real tokens have scale data
|
||||
expert_offsets[:self.num_experts + 1],
|
||||
self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
|
||||
)
|
||||
l1_gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale)
|
||||
@@ -383,22 +419,31 @@ class CuTeDSLMoERunner:
|
||||
l1_out = run_nvfp4_grouped_gemm(
|
||||
mat_a=x_fp4, mat_b=self._l1_mat_b,
|
||||
scale_a=l1_scale_a, scale_b=self._l1_scale_b,
|
||||
expert_offsets=expert_offsets[1:self.num_experts + 1],
|
||||
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
|
||||
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
|
||||
)
|
||||
|
||||
# Only keep real token outputs (padding rows are garbage)
|
||||
l1_out = l1_out[:num_slots]
|
||||
|
||||
# === SiLU(gate) * up ===
|
||||
gate = l1_out[:, :self.intermediate_size]
|
||||
up = l1_out[:, self.intermediate_size:]
|
||||
activated = torch.nn.functional.silu(gate) * up
|
||||
|
||||
# === L2: down ===
|
||||
# Pad activated to 128 tokens per expert for L2 GEMM
|
||||
padded_activated = self._padded_activated_buf[:padded_size]
|
||||
padded_activated.zero_()
|
||||
padded_activated[:num_slots] = activated
|
||||
|
||||
l2_x_fp4, l2_x_sf = quantize_activation_nvfp4(
|
||||
activated, self._l2_activation_global_scale
|
||||
padded_activated, self._l2_activation_global_scale
|
||||
)
|
||||
|
||||
l2_scale_a = self._assemble_scales_cudagraph_safe(
|
||||
l2_x_sf, expert_offsets[:self.num_experts + 1],
|
||||
l2_x_sf[:num_slots],
|
||||
expert_offsets[:self.num_experts + 1],
|
||||
self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2
|
||||
)
|
||||
l2_gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale)
|
||||
@@ -406,10 +451,13 @@ class CuTeDSLMoERunner:
|
||||
l2_out = run_nvfp4_grouped_gemm(
|
||||
mat_a=l2_x_fp4, mat_b=self._l2_mat_b,
|
||||
scale_a=l2_scale_a, scale_b=self._l2_scale_b,
|
||||
expert_offsets=expert_offsets[1:self.num_experts + 1],
|
||||
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
|
||||
global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,
|
||||
)
|
||||
|
||||
# Only keep real token outputs
|
||||
l2_out = l2_out[:num_slots]
|
||||
|
||||
# === Scatter -> final output ===
|
||||
y = self._output_buf[:num_tokens]
|
||||
y.zero_()
|
||||
|
||||
Reference in New Issue
Block a user