diff --git a/vllm/nvfp4_cutedsl.py b/vllm/nvfp4_cutedsl.py index 8c7024d2..81f3a7e7 100644 --- a/vllm/nvfp4_cutedsl.py +++ b/vllm/nvfp4_cutedsl.py @@ -113,12 +113,13 @@ class CuTeDSLMoERunner: ] # Padded x_sf buffers for Phase 1 scatter. - # Fixed 128 rows per expert → num_experts * 128 total rows. + # Sized for max_num_tokens * top_k rows (worst case: all tokens in one expert). + max_sf_rows = self.max_num_tokens * self.top_k self._padded_x_sf_buf_l1 = torch.zeros( - self.num_experts * 128, padded_cols_l1, dtype=torch.float16, device=self.device + max_sf_rows, padded_cols_l1, dtype=torch.float16, device=self.device ).to(torch.float8_e4m3fn) self._padded_x_sf_buf_l2 = torch.zeros( - self.num_experts * 128, padded_cols_l2, dtype=torch.float16, device=self.device + max_sf_rows, padded_cols_l2, dtype=torch.float16, device=self.device ).to(torch.float8_e4m3fn) # Pre-allocated global_scale_a buffers (filled via .fill_(), no torch.full during capture) @@ -135,22 +136,20 @@ class CuTeDSLMoERunner: self.max_num_tokens * self.top_k, device=self.device ) - # Padded hidden/activated buffers: num_experts * 128 rows - padded_size = self.num_experts * 128 + # Padded hidden/activated buffers: max_num_tokens * top_k rows (rounded to 128) + max_slots = self.max_num_tokens * self.top_k + padded_max_slots = ((max_slots + 127) // 128) * 128 self._padded_hidden_buf = torch.zeros( - padded_size, self.hidden_size, dtype=torch.bfloat16, device=self.device + padded_max_slots, self.hidden_size, dtype=torch.bfloat16, device=self.device ) self._padded_activated_buf = torch.zeros( - padded_size, self.intermediate_size, dtype=torch.bfloat16, device=self.device + padded_max_slots, self.intermediate_size, dtype=torch.bfloat16, device=self.device ) - # Padded expert offsets: [0, 128, 256, ...] (fixed, 128 tokens per expert) + # Padded expert offsets buffer (num_experts + 1) self._padded_expert_offsets_buf = torch.zeros( self.num_experts + 1, dtype=torch.int32, device=self.device ) - self._padded_expert_offsets_buf[1:] = torch.arange( - 1, self.num_experts + 1, dtype=torch.int32, device=self.device - ) * 128 self._buffers_allocated = True @@ -224,58 +223,57 @@ class CuTeDSLMoERunner: padded_x_sf_buf, per_expert_bufs): """Assemble 2D-side activation scales (cudagraph-safe, no CPU sync). - Per-expert swizzle using pre-allocated buffers, then concatenate. - The per-expert loop is a fixed-size Python loop (num_experts is constant), - so cudagraph captures it as a static unrolled sequence. - - Each expert's scale data is swizzled independently and stacked. - The output shape depends on expert_offsets (GPU tensor), but during - cudagraph capture, expert_offsets is deterministic (fixed token budget). + Each expert's scale rows are padded to 128, then swizzled independently. + The GEMM reads scale_a according to padded_expert_offsets, which + matches the layout produced here. """ num_experts = self.num_experts K_sf = x_sf.shape[1] padded_x_sf = padded_x_sf_buf padded_x_sf.zero_() - # Phase 1: Scatter x_sf into 128-aligned per-expert sections - # Each expert gets a fixed 128-row slot in padded_x_sf (at offset e*128). - # Tokens beyond 128 per expert wrap to zero rows (harmless — zero scale - # means zero contribution in GEMM). All indexing is fixed-shape. + # Compute padded expert offsets (each expert padded to 128 rows) + tokens_per_expert = expert_offsets[1:] - expert_offsets[:-1] + padded_rows_per_expert = ((tokens_per_expert + 127) // 128) * 128 + padded_expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32, device=x_sf.device) + padded_expert_offsets[1:] = padded_rows_per_expert.cumsum(0) + + # Phase 1: Scatter x_sf into padded per-expert sections total_rows = x_sf.shape[0] row_indices = self._row_indices_buf[:total_rows] expert_assign = torch.searchsorted( expert_offsets[1:], row_indices, right=True ).clamp(max=num_experts - 1) local_row = row_indices - expert_offsets[expert_assign] - # Clamp local_row to [0, 127] — rows beyond 128 go to row 0 (overwriting, - # but row 0 already has valid data, and the extra row is ignored by GEMM) - clamped_local = local_row.clamp(max=127) - dst_rows = expert_assign * 128 + clamped_local + dst_rows = padded_expert_offsets[expert_assign] + local_row padded_x_sf[dst_rows, :K_sf] = x_sf # Phase 2: Per-expert swizzle and concatenate - # Each expert gets at most padded_x_sf[e*128 : (e+1)*128] for the first 128 rows. - # For experts with >128 tokens, we'd need multiple chunks, but during - # cudagraph capture the token budget is fixed, and the GEMM uses expert_offsets - # to determine how many rows each expert gets. - # - # Strategy: always swizzle 128 rows per expert (fixed loop), zero-pad shorter experts. - # The GEMM only reads the rows indicated by expert_offsets. swizzled_parts = [] for e in range(num_experts): + n_padded = padded_rows_per_expert[e] + start = padded_expert_offsets[e] buf = per_expert_bufs[e] - buf.zero_() - # Copy from padded_x_sf at this expert's 128-aligned offset - # Always copy 128 rows (fixed shape for cudagraph) - src_start = e * 128 - buf[:, :K_sf] = padded_x_sf[src_start:src_start + 128] - swizzled = pad_and_swizzle_single(buf) - swizzled_parts.append(swizzled) + # Process in 128-row chunks + offset = start + remaining = n_padded + while remaining > 0: + buf.zero_() + chunk = min(remaining, 128) + buf[:chunk, :K_sf] = padded_x_sf[offset:offset + chunk] + swizzled = pad_and_swizzle_single(buf) + swizzled_parts.append(swizzled) + offset += 128 + remaining -= 128 + if n_padded == 0: + buf.zero_() + swizzled = pad_and_swizzle_single(buf) + swizzled_parts.append(swizzled) - # Concatenate all expert blocks (byte-reinterpretable) all_flat = torch.cat([p.view(torch.uint8) for p in swizzled_parts], dim=0) all_flat = all_flat.view(torch.float8_e4m3fn) - return all_flat.reshape(num_experts * 128, -1) + total_padded = padded_expert_offsets[num_experts] + return all_flat.reshape(total_padded, -1) def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids): """Compute activation global scales from a warmup forward pass. @@ -345,20 +343,11 @@ class CuTeDSLMoERunner: """Run the NVFP4 MoE forward pass. Handles global→local expert ID remapping for expert parallelism. - topk_ids contains GLOBAL expert IDs (0..n_routed_experts-1). - This runner only handles local experts - [experts_start_idx, experts_start_idx + num_experts). - - Non-local tokens get zero weight and are clamped to expert 0 - (harmless — zero-weighted output contributes nothing to scatter_add). - Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes. - PADDING STRATEGY: Each expert is padded to exactly 128 tokens in - the GEMM input. This ensures scale_a has a fixed layout (128 rows - per expert) that matches the padded expert_offsets. The GEMM output - for padding tokens is discarded by scatter_add (sorted_token_ids - only refers to real tokens). + Each expert's slots are padded to multiples of 128 for the GEMM. + expert_offsets is [0, padded_e0, padded_e0+padded_e1, ...]. + scale_a is produced at those same offsets. """ num_tokens = hidden_states.shape[0] top_k = topk_ids.shape[1] @@ -381,28 +370,36 @@ class CuTeDSLMoERunner: sort_idx = flat_ids.argsort(stable=True) sorted_ids = flat_ids[sort_idx] sorted_weights = flat_weights[sort_idx] - sorted_token_ids = token_indices[sort_idx] # GPU tensor, no .cpu() + sorted_token_ids = token_indices[sort_idx] - # Expert offsets (GPU-only, never touches CPU) + # Expert offsets (real token counts) expert_id_range = self._expert_id_range tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int() expert_offsets = self._expert_offsets_buf expert_offsets.zero_() expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0) - # -- Pad to 128 tokens per expert -- - # The GEMM and scale assembly expect a fixed layout where each expert - # gets exactly 128 rows. Pad slot_hidden to num_experts * 128 rows, - # and set expert_offsets to [0, 128, 256, ...]. + # Padded expert offsets (each expert padded to 128) + padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128 padded_expert_offsets = self._padded_expert_offsets_buf - # Already filled in _allocate_buffers with [0, 128, 256, ...] + padded_expert_offsets.zero_() + padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0) + total_padded_slots = padded_expert_offsets[self.num_experts] - # Gather hidden states into slot order, then pad to 128 per expert + # -- Gather hidden states into slot order, pad to 128 per expert -- slot_hidden = hidden_states[sorted_token_ids] - padded_size = self.num_experts * 128 - padded_hidden = self._padded_hidden_buf[:padded_size] + padded_hidden = self._padded_hidden_buf[:total_padded_slots] padded_hidden.zero_() - padded_hidden[:num_slots] = slot_hidden + # Scatter real tokens into padded positions + # Each expert e: tokens are at [expert_offsets[e], expert_offsets[e+1]) + # In padded buffer: tokens are at [padded_expert_offsets[e], padded_expert_offsets[e]+tokens_per_expert[e]) + row_indices = self._row_indices_buf[:num_slots] + expert_assign = torch.searchsorted( + expert_offsets[1:], row_indices, right=True + ).clamp(max=self.num_experts - 1) + local_row = row_indices - expert_offsets[expert_assign] + padded_dst = padded_expert_offsets[expert_assign] + local_row + padded_hidden[padded_dst] = slot_hidden # === L1: gate + up === x_fp4, x_sf = quantize_activation_nvfp4( @@ -410,8 +407,7 @@ class CuTeDSLMoERunner: ) l1_scale_a = self._assemble_scales_cudagraph_safe( - x_sf[:num_slots], # Only real tokens have scale data - expert_offsets[:self.num_experts + 1], + x_sf, expert_offsets[:self.num_experts + 1], self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1 ) l1_gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale) @@ -423,27 +419,25 @@ class CuTeDSLMoERunner: global_scale_a=l1_gsa, global_scale_b=self._l1_gsb, ) - # Only keep real token outputs (padding rows are garbage) - l1_out = l1_out[:num_slots] + # Extract real token outputs from padded GEMM output + l1_out_real = l1_out[padded_dst] # === SiLU(gate) * up === - gate = l1_out[:, :self.intermediate_size] - up = l1_out[:, self.intermediate_size:] + gate = l1_out_real[:, :self.intermediate_size] + up = l1_out_real[:, self.intermediate_size:] activated = torch.nn.functional.silu(gate) * up # === L2: down === - # Pad activated to 128 tokens per expert for L2 GEMM - padded_activated = self._padded_activated_buf[:padded_size] + padded_activated = self._padded_activated_buf[:total_padded_slots] padded_activated.zero_() - padded_activated[:num_slots] = activated + padded_activated[padded_dst] = activated l2_x_fp4, l2_x_sf = quantize_activation_nvfp4( padded_activated, self._l2_activation_global_scale ) l2_scale_a = self._assemble_scales_cudagraph_safe( - l2_x_sf[:num_slots], - expert_offsets[:self.num_experts + 1], + l2_x_sf, expert_offsets[:self.num_experts + 1], self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2 ) l2_gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale) @@ -455,13 +449,12 @@ class CuTeDSLMoERunner: global_scale_a=l2_gsa, global_scale_b=self._l2_gsb, ) - # Only keep real token outputs - l2_out = l2_out[:num_slots] + l2_out_real = l2_out[padded_dst] # === Scatter -> final output === y = self._output_buf[:num_tokens] y.zero_() - weighted_out = l2_out * sorted_weights.unsqueeze(1).to(l2_out.dtype) + weighted_out = l2_out_real * sorted_weights.unsqueeze(1).to(l2_out_real.dtype) y.scatter_add_( 0, sorted_token_ids.unsqueeze(1).expand(-1, self.hidden_size),