From bde81b95f4a43e63f0f1c97e196e8407d3b682d4 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 17 May 2026 13:19:31 +0000 Subject: [PATCH] Fix GEMM scale layout: pad to 128 tokens per expert Root cause of garbage output: the GEMM reads scale_a according to expert_offsets (e.g. [0, 500, 1024, ...]) but scale_a had data at fixed e*128 offsets. When expert 0 has 500 tokens, the GEMM reads scale_a[0:500] but only rows 0-127 had valid data. Fix: pad slot_hidden to num_experts*128 rows (128 per expert) and pass padded_expert_offsets=[0, 128, 256, ...] to the GEMM. Scale assembly's fixed 128-row layout now matches the GEMM's expectations. Padding tokens' GEMM output is discarded (scatter_add only uses sorted_token_ids for real tokens). --- vllm/nvfp4_cutedsl.py | 66 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 57 insertions(+), 9 deletions(-) diff --git a/vllm/nvfp4_cutedsl.py b/vllm/nvfp4_cutedsl.py index de93657b..8c7024d2 100644 --- a/vllm/nvfp4_cutedsl.py +++ b/vllm/nvfp4_cutedsl.py @@ -80,6 +80,9 @@ class CuTeDSLMoERunner: self._l2_gsa_buf = None self._output_buf = None self._row_indices_buf = None + self._padded_hidden_buf = None + self._padded_activated_buf = None + self._padded_expert_offsets_buf = None self._buffers_allocated = False def _fill_token_indices(self): @@ -132,6 +135,23 @@ class CuTeDSLMoERunner: self.max_num_tokens * self.top_k, device=self.device ) + # Padded hidden/activated buffers: num_experts * 128 rows + padded_size = self.num_experts * 128 + self._padded_hidden_buf = torch.zeros( + padded_size, self.hidden_size, dtype=torch.bfloat16, device=self.device + ) + self._padded_activated_buf = torch.zeros( + padded_size, self.intermediate_size, dtype=torch.bfloat16, device=self.device + ) + + # Padded expert offsets: [0, 128, 256, ...] (fixed, 128 tokens per expert) + self._padded_expert_offsets_buf = torch.zeros( + self.num_experts + 1, dtype=torch.int32, device=self.device + ) + self._padded_expert_offsets_buf[1:] = torch.arange( + 1, self.num_experts + 1, dtype=torch.int32, device=self.device + ) * 128 + self._buffers_allocated = True def _ensure_stacked(self): @@ -333,6 +353,12 @@ class CuTeDSLMoERunner: (harmless — zero-weighted output contributes nothing to scatter_add). Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes. + + PADDING STRATEGY: Each expert is padded to exactly 128 tokens in + the GEMM input. This ensures scale_a has a fixed layout (128 rows + per expert) that matches the padded expert_offsets. The GEMM output + for padding tokens is discarded by scatter_add (sorted_token_ids + only refers to real tokens). """ num_tokens = hidden_states.shape[0] top_k = topk_ids.shape[1] @@ -341,8 +367,6 @@ class CuTeDSLMoERunner: self._ensure_stacked() # -- Remap global expert IDs to local IDs -- - # topk_ids are global: remap by subtracting experts_start_idx. - # Tokens for non-local experts get clamped to 0 with zero weight. local_ids = topk_ids - self.experts_start_idx local_mask = (local_ids >= 0) & (local_ids < self.num_experts) safe_ids = local_ids.clamp(0, self.num_experts - 1) @@ -366,16 +390,28 @@ class CuTeDSLMoERunner: expert_offsets.zero_() expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0) - # -- Gather hidden states into slot order -- + # -- Pad to 128 tokens per expert -- + # The GEMM and scale assembly expect a fixed layout where each expert + # gets exactly 128 rows. Pad slot_hidden to num_experts * 128 rows, + # and set expert_offsets to [0, 128, 256, ...]. + padded_expert_offsets = self._padded_expert_offsets_buf + # Already filled in _allocate_buffers with [0, 128, 256, ...] + + # Gather hidden states into slot order, then pad to 128 per expert slot_hidden = hidden_states[sorted_token_ids] + padded_size = self.num_experts * 128 + padded_hidden = self._padded_hidden_buf[:padded_size] + padded_hidden.zero_() + padded_hidden[:num_slots] = slot_hidden # === L1: gate + up === x_fp4, x_sf = quantize_activation_nvfp4( - slot_hidden, self._l1_activation_global_scale + padded_hidden, self._l1_activation_global_scale ) l1_scale_a = self._assemble_scales_cudagraph_safe( - x_sf, expert_offsets[:self.num_experts + 1], + x_sf[:num_slots], # Only real tokens have scale data + expert_offsets[:self.num_experts + 1], self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1 ) l1_gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale) @@ -383,22 +419,31 @@ class CuTeDSLMoERunner: l1_out = run_nvfp4_grouped_gemm( mat_a=x_fp4, mat_b=self._l1_mat_b, scale_a=l1_scale_a, scale_b=self._l1_scale_b, - expert_offsets=expert_offsets[1:self.num_experts + 1], + expert_offsets=padded_expert_offsets[1:self.num_experts + 1], global_scale_a=l1_gsa, global_scale_b=self._l1_gsb, ) + # Only keep real token outputs (padding rows are garbage) + l1_out = l1_out[:num_slots] + # === SiLU(gate) * up === gate = l1_out[:, :self.intermediate_size] up = l1_out[:, self.intermediate_size:] activated = torch.nn.functional.silu(gate) * up # === L2: down === + # Pad activated to 128 tokens per expert for L2 GEMM + padded_activated = self._padded_activated_buf[:padded_size] + padded_activated.zero_() + padded_activated[:num_slots] = activated + l2_x_fp4, l2_x_sf = quantize_activation_nvfp4( - activated, self._l2_activation_global_scale + padded_activated, self._l2_activation_global_scale ) l2_scale_a = self._assemble_scales_cudagraph_safe( - l2_x_sf, expert_offsets[:self.num_experts + 1], + l2_x_sf[:num_slots], + expert_offsets[:self.num_experts + 1], self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2 ) l2_gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale) @@ -406,10 +451,13 @@ class CuTeDSLMoERunner: l2_out = run_nvfp4_grouped_gemm( mat_a=l2_x_fp4, mat_b=self._l2_mat_b, scale_a=l2_scale_a, scale_b=self._l2_scale_b, - expert_offsets=expert_offsets[1:self.num_experts + 1], + expert_offsets=padded_expert_offsets[1:self.num_experts + 1], global_scale_a=l2_gsa, global_scale_b=self._l2_gsb, ) + # Only keep real token outputs + l2_out = l2_out[:num_slots] + # === Scatter -> final output === y = self._output_buf[:num_tokens] y.zero_()