diff --git a/vllm/nvfp4_cutedsl.py b/vllm/nvfp4_cutedsl.py index 78b51590..b7fd25e7 100644 --- a/vllm/nvfp4_cutedsl.py +++ b/vllm/nvfp4_cutedsl.py @@ -1,9 +1,17 @@ """ vLLM integration for the CuTeDSL NVFP4 MoE kernel. -CUDA-graph-compatible: no .item() calls, no Python loops over tokens, -no dynamic shapes, no CPU-GPU syncs, no torch.cuda.synchronize(). -All buffers pre-allocated at max_num_tokens size. +CUDA-graph-compatible design: +- All intermediate buffers pre-allocated at max_num_tokens * top_k size +- No .item(), .tolist(), .cpu() — zero CPU-GPU syncs +- No dynamic slicing with GPU scalars — always operate on full pre-allocated buffers +- Extra slots (beyond real tokens) are zero and contribute nothing to output +- Fixed-shape tensors throughout the forward pass + +vLLM cudagraph captures at fixed token budgets (1,2,4,8,...,8192). +During capture, num_tokens equals the budget — all shapes are fixed. +During replay, inputs are padded to the budget size. Our runner always +processes max_slots = budget * top_k rows; padding rows are zeros. """ import torch @@ -16,7 +24,6 @@ from cutedsl.bridge import ( ) from cutedsl.kernel.moe.torch_scaled_grouped_mm import ( ceil_div as cutedsl_ceil_div, - round_up as cutedsl_round_up, pad_and_swizzle_single, ) @@ -24,7 +31,8 @@ from cutedsl.kernel.moe.torch_scaled_grouped_mm import ( class CuTeDSLMoERunner: """Manages NVFP4 MoE execution via the CuTeDSL kernel. - CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs. + CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs, + no dynamic shapes. Always computes at max_num_tokens * top_k capacity. """ def __init__(self, num_experts, hidden_size, intermediate_size, @@ -36,6 +44,7 @@ class CuTeDSLMoERunner: self.top_k = top_k self.device = device + # Weight storage (set before _ensure_stacked) self.l1_fp4 = None self.l1_sf = None self.l1_gs = None @@ -43,6 +52,7 @@ class CuTeDSLMoERunner: self.l2_sf = None self.l2_gs = None + # Stacked weight tensors (set in _ensure_stacked) self._l1_mat_b = None self._l2_mat_b = None self._l1_scale_b = None @@ -53,12 +63,11 @@ class CuTeDSLMoERunner: self._l1_activation_global_scale = 1.0 / 2688.0 self._l2_activation_global_scale = 1.0 / 2688.0 - # Pre-allocated buffers (set in _allocate_buffers) + # Pre-allocated cudagraph buffers (set in _allocate_buffers) self._token_indices = None self._expert_id_range = None - self._output_buf = None - self._padded_scales_buf = None self._expert_offsets_buf = None + self._padded_scales_buf = None self._padded_expert_offsets_buf = None self._buffers_allocated = False @@ -67,8 +76,10 @@ class CuTeDSLMoERunner: max_slots = self.max_num_tokens * self.top_k K_sf = cutedsl_ceil_div(self.hidden_size, 16) padded_cols = cutedsl_ceil_div(K_sf, 4) * 4 - max_padded_rows = self.num_experts * 128 # worst case: 1 token per expert, each padded to 128 + # Worst case: 1 token per expert, each padded to 128 rows + max_padded_rows = self.num_experts * 128 + # Slot -> token mapping: [0,0,...,0, 1,1,...,1, ...] (top_k repeats) self._token_indices = torch.arange( self.max_num_tokens, device=self.device ).unsqueeze(1).expand(-1, self.top_k).reshape(-1) @@ -81,14 +92,9 @@ class CuTeDSLMoERunner: self._padded_expert_offsets_buf = torch.zeros( self.num_experts + 1, dtype=torch.int32, device=self.device ) - - self._output_buf = torch.zeros( - max_slots, self.hidden_size, dtype=torch.bfloat16, device=self.device - ) - self._padded_scales_buf = torch.zeros( - max_padded_rows, padded_cols, dtype=torch.float8_e4m3fn, device=self.device - ) + max_padded_rows, padded_cols, dtype=torch.float16, device=self.device + ).to(torch.float8_e4m3fn) self._buffers_allocated = True @@ -137,10 +143,8 @@ class CuTeDSLMoERunner: def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets): """Assemble 2D-side activation scales (cudagraph-safe, no CPU sync). - Pre-allocates a padded buffer at max size. Uses index_copy_ with - GPU-computed indices to scatter scale data into padded positions. - Then applies the swizzle to the whole buffer. - + Uses GPU-computed indices to scatter scale data into padded positions, + then applies the swizzle. Returns 2D tensor. No .item(), no .tolist(), no Python control flow on GPU data. """ num_experts = self.num_experts @@ -164,16 +168,12 @@ class CuTeDSLMoERunner: padded_scales = self._padded_scales_buf[:total_padded_rows, :padded_cols] padded_scales.zero_() - # Build index mapping: for each row in x_sf, where does it go in padded_scales? - # Row i in x_sf belongs to expert e where expert_offsets[e] <= i < expert_offsets[e+1] - # Its destination is padded_expert_offsets[e] + (i - expert_offsets[e]) - - # Use searchsorted to find which expert each row belongs to + # Build index mapping: for each row in x_sf, which expert does it belong to? total_rows = x_sf.shape[0] - # Use pre-allocated token indices (sliced to actual size) row_indices = self._token_indices[:total_rows] - # expert_assign[i] = which expert row i belongs to - expert_assign = torch.searchsorted(expert_offsets[1:], row_indices, right=False).clamp(max=num_experts - 1) + expert_assign = torch.searchsorted( + expert_offsets[1:], row_indices, right=False + ).clamp(max=num_experts - 1) # Destination row in padded buffer local_row = row_indices - expert_offsets[expert_assign] @@ -182,91 +182,111 @@ class CuTeDSLMoERunner: # Scatter x_sf into padded_scales padded_scales[dst_rows, :K_sf] = x_sf - # Apply swizzle to the whole padded tensor, return 2D for 2D-side scale_a - # to_blocked preserves element count, so reshape to match padded shape + # Apply swizzle, reshape to 2D (element count preserved by swizzle) swizzled = pad_and_swizzle_single(padded_scales) return swizzled.reshape(padded_scales.shape[0], -1) def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None): - num_tokens, hidden_size = hidden_states.shape + """Run the NVFP4 MoE forward pass. + + Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes. + + expert_offsets are computed from the actual token distribution + via GPU-only ops (argsort, broadcast ==, cumsum). These offsets + are passed to the GEMM as a GPU tensor, never converted to Python. + + The GEMM and quantize functions see the full slot buffer. + Padding rows are zeros that produce zero output, contributing + nothing to the final scatter_add. + + Args: + hidden_states: (num_tokens, hidden_size) bf16 + topk_weights: (num_tokens, top_k) float32 + topk_ids: (num_tokens, top_k) int + expert_indices: ignored (uses all experts) + + Returns: + (num_tokens, hidden_size) bf16 - MoE output + """ + num_tokens = hidden_states.shape[0] top_k = topk_ids.shape[1] device = hidden_states.device - if expert_indices is None: - expert_indices = list(range(self.num_experts)) - - num_experts = len(expert_indices) self._ensure_stacked() - # ── Build slot mapping ── + # -- Build slot mapping -- flat_ids = topk_ids.reshape(-1) flat_weights = topk_weights.reshape(-1) - token_indices = self._token_indices[:num_tokens * top_k] + num_slots = num_tokens * top_k + token_indices = self._token_indices[:num_slots] sort_idx = flat_ids.argsort(stable=True) sorted_ids = flat_ids[sort_idx] sorted_weights = flat_weights[sort_idx] sorted_token_ids = token_indices[sort_idx] - # Expert offsets (GPU-only) - expert_id_range = self._expert_id_range[:num_experts] + # Expert offsets (GPU-only, never touches CPU) + expert_id_range = self._expert_id_range tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int() expert_offsets = self._expert_offsets_buf expert_offsets.zero_() - expert_offsets[1:num_experts + 1] = tokens_per_expert.cumsum(0) + expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0) - total_slots = expert_offsets[num_experts] + # -- Gather hidden states into slot order -- + slot_hidden = hidden_states[sorted_token_ids] - slot_hidden = hidden_states[sorted_token_ids[:total_slots]] - - # ════════════════════════════════════════════════════════════ - # L1: gate + up - # ════════════════════════════════════════════════════════════ + # === L1: gate + up === x_fp4, x_sf = quantize_activation_nvfp4( slot_hidden, self._l1_activation_global_scale ) - l1_scale_a = self._assemble_scales_cudagraph_safe(x_sf, expert_offsets[:num_experts + 1]) - l1_gsa = torch.full((num_experts,), self._l1_activation_global_scale, - dtype=torch.float32, device=device) + l1_scale_a = self._assemble_scales_cudagraph_safe( + x_sf, expert_offsets[:self.num_experts + 1] + ) + l1_gsa = torch.full( + (self.num_experts,), self._l1_activation_global_scale, + dtype=torch.float32, device=device + ) l1_out = run_nvfp4_grouped_gemm( mat_a=x_fp4, mat_b=self._l1_mat_b, scale_a=l1_scale_a, scale_b=self._l1_scale_b, - expert_offsets=expert_offsets[:num_experts + 1], + expert_offsets=expert_offsets[:self.num_experts + 1], global_scale_a=l1_gsa, global_scale_b=self._l1_gsb, ) - # ════════════════════════════════════════════════════════════ - # SiLU(gate) * up - # ════════════════════════════════════════════════════════════ + # === SiLU(gate) * up === gate = l1_out[:, :self.intermediate_size] up = l1_out[:, self.intermediate_size:] activated = torch.nn.functional.silu(gate) * up - # ════════════════════════════════════════════════════════════ - # L2: down - # ════════════════════════════════════════════════════════════ + # === L2: down === l2_x_fp4, l2_x_sf = quantize_activation_nvfp4( activated, self._l2_activation_global_scale ) - l2_scale_a = self._assemble_scales_cudagraph_safe(l2_x_sf, expert_offsets[:num_experts + 1]) - l2_gsa = torch.full((num_experts,), self._l2_activation_global_scale, - dtype=torch.float32, device=device) + l2_scale_a = self._assemble_scales_cudagraph_safe( + l2_x_sf, expert_offsets[:self.num_experts + 1] + ) + l2_gsa = torch.full( + (self.num_experts,), self._l2_activation_global_scale, + dtype=torch.float32, device=device + ) l2_out = run_nvfp4_grouped_gemm( mat_a=l2_x_fp4, mat_b=self._l2_mat_b, scale_a=l2_scale_a, scale_b=self._l2_scale_b, - expert_offsets=expert_offsets[:num_experts + 1], + expert_offsets=expert_offsets[:self.num_experts + 1], global_scale_a=l2_gsa, global_scale_b=self._l2_gsb, ) - # ════════════════════════════════════════════════════════════ - # Scatter → final output - # ════════════════════════════════════════════════════════════ - y = torch.zeros(num_tokens, hidden_size, dtype=torch.bfloat16, device=device) - weighted_out = l2_out * sorted_weights[:total_slots].unsqueeze(1).to(l2_out.dtype) - y.scatter_add_(0, sorted_token_ids[:total_slots].unsqueeze(1).expand(-1, hidden_size), weighted_out) + # === Scatter -> final output === + y = torch.zeros(num_tokens, self.hidden_size, dtype=torch.bfloat16, device=device) + weighted_out = l2_out * sorted_weights.unsqueeze(1).to(l2_out.dtype) + y.scatter_add_( + 0, + sorted_token_ids.unsqueeze(1).expand(-1, self.hidden_size), + weighted_out, + ) return y