diff --git a/vllm/nvfp4_cutedsl.py b/vllm/nvfp4_cutedsl.py index 427599e6..e6b4567b 100644 --- a/vllm/nvfp4_cutedsl.py +++ b/vllm/nvfp4_cutedsl.py @@ -109,22 +109,7 @@ class CuTeDSLMoERunner: if self._l1_mat_b is not None: return - # Allocate token indices FIRST, before any CUDA JIT compilation - # (CuTeDSL's cute.compile can corrupt GPU memory state during JIT) - # Use .clone() to ensure a fresh allocation and cuda.synchronize() to - # ensure the kernel completes before any JIT compilation starts - if self._token_indices is None: - self._token_indices = torch.arange( - self.max_num_tokens, device=self.device, dtype=torch.int32 - ).unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1).clone() - torch.cuda.synchronize() - if self._expert_id_range is None: - self._expert_id_range = torch.arange(self.num_experts, device=self.device) - if self._expert_offsets_buf is None: - self._expert_offsets_buf = torch.zeros( - self.num_experts + 1, dtype=torch.int32, device=self.device - ) - + # Stack and prepare weight tensors FIRST (triggers CuTeDSL JIT compilation) self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4)) self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4)) self._l1_scale_b = assemble_scales_3d_side(self.l1_sf) @@ -137,6 +122,19 @@ class CuTeDSLMoERunner: self.l2_fp4 = None self.l2_sf = None self.l2_gs = None + + # Allocate buffers AFTER JIT compilation + # (CuTeDSL's cute.compile corrupts GPU memory during JIT; + # tensors allocated before/during compilation may be zeroed) + self._token_indices = torch.arange( + self.max_num_tokens, dtype=torch.int32 + ).unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1).to(self.device) + self._expert_id_range = torch.arange( + self.num_experts, dtype=torch.int32 + ).to(self.device) + self._expert_offsets_buf = torch.zeros( + self.num_experts + 1, dtype=torch.int32, device=self.device + ) self._allocate_buffers() def prepare_weights_direct(self, l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs):