From 7073daaffa59c7fcd31e856701c03546ee9456b2 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 17 May 2026 08:22:51 +0000 Subject: [PATCH] fix: allocate token_indices on CPU, move to GPU AFTER JIT compilation CuTeDSL's cute.compile corrupts GPU memory during JIT compilation. Tensors allocated on GPU before/during compilation get zeroed. Fix: create token_indices on CPU, then .to(device) after JIT is done. --- vllm/nvfp4_cutedsl.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/vllm/nvfp4_cutedsl.py b/vllm/nvfp4_cutedsl.py index 427599e6..e6b4567b 100644 --- a/vllm/nvfp4_cutedsl.py +++ b/vllm/nvfp4_cutedsl.py @@ -109,22 +109,7 @@ class CuTeDSLMoERunner: if self._l1_mat_b is not None: return - # Allocate token indices FIRST, before any CUDA JIT compilation - # (CuTeDSL's cute.compile can corrupt GPU memory state during JIT) - # Use .clone() to ensure a fresh allocation and cuda.synchronize() to - # ensure the kernel completes before any JIT compilation starts - if self._token_indices is None: - self._token_indices = torch.arange( - self.max_num_tokens, device=self.device, dtype=torch.int32 - ).unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1).clone() - torch.cuda.synchronize() - if self._expert_id_range is None: - self._expert_id_range = torch.arange(self.num_experts, device=self.device) - if self._expert_offsets_buf is None: - self._expert_offsets_buf = torch.zeros( - self.num_experts + 1, dtype=torch.int32, device=self.device - ) - + # Stack and prepare weight tensors FIRST (triggers CuTeDSL JIT compilation) self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4)) self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4)) self._l1_scale_b = assemble_scales_3d_side(self.l1_sf) @@ -137,6 +122,19 @@ class CuTeDSLMoERunner: self.l2_fp4 = None self.l2_sf = None self.l2_gs = None + + # Allocate buffers AFTER JIT compilation + # (CuTeDSL's cute.compile corrupts GPU memory during JIT; + # tensors allocated before/during compilation may be zeroed) + self._token_indices = torch.arange( + self.max_num_tokens, dtype=torch.int32 + ).unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1).to(self.device) + self._expert_id_range = torch.arange( + self.num_experts, dtype=torch.int32 + ).to(self.device) + self._expert_offsets_buf = torch.zeros( + self.num_experts + 1, dtype=torch.int32, device=self.device + ) self._allocate_buffers() def prepare_weights_direct(self, l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs):