From d635dcbbb6ac5f6a68b37aa09baa322af6ab9afe Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 17 May 2026 08:29:18 +0000 Subject: [PATCH] fix: keep token_indices on CPU, index with CPU sort_idx CuTeDSL's cute.compile corrupts GPU memory during JIT compilation. Keeping token_indices on CPU and using sort_idx.cpu() for indexing avoids the corruption. The .to(device) call after indexing moves the result back to GPU for the hidden_states indexing. --- vllm/nvfp4_cutedsl.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/vllm/nvfp4_cutedsl.py b/vllm/nvfp4_cutedsl.py index e662e930..196d5781 100644 --- a/vllm/nvfp4_cutedsl.py +++ b/vllm/nvfp4_cutedsl.py @@ -129,17 +129,9 @@ class CuTeDSLMoERunner: # tensors allocated before/during compilation may be zeroed) self._token_indices = torch.arange( self.max_num_tokens, dtype=torch.int32 - ).unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1).to(self.device) - # Verify the tensor was transferred correctly (CuTeDSL JIT can corrupt GPU state) - _verify = self._token_indices[:8].cpu().tolist() - if _verify != [0, 0, 1, 1, 2, 2, 3, 3]: - # Fallback: allocate directly on GPU - self._token_indices = torch.zeros( - self.max_num_tokens * self.top_k, dtype=torch.int32, device=self.device - ) - for i in range(self.max_num_tokens): - self._token_indices[i * self.top_k] = i - self._token_indices[i * self.top_k + 1] = i + ).unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1) + # Keep on CPU to avoid CuTeDSL JIT GPU memory corruption + # Will be indexed with CPU offsets during slot mapping self._expert_id_range = torch.arange( self.num_experts, dtype=torch.int32 ).to(self.device) @@ -247,7 +239,7 @@ class CuTeDSLMoERunner: token_indices = self._token_indices[:num_slots] sort_idx = flat_ids.argsort(stable=True) sorted_ids = flat_ids[sort_idx] - sorted_token_ids = token_indices[sort_idx] + sorted_token_ids = token_indices[sort_idx.cpu()].to(device) slot_hidden = hidden_states_sample[sorted_token_ids] # Debug: verify slot_hidden @@ -329,7 +321,7 @@ class CuTeDSLMoERunner: sort_idx = flat_ids.argsort(stable=True) sorted_ids = flat_ids[sort_idx] sorted_weights = flat_weights[sort_idx] - sorted_token_ids = token_indices[sort_idx] + sorted_token_ids = token_indices[sort_idx.cpu()].to(device) # Expert offsets (GPU-only, never touches CPU) expert_id_range = self._expert_id_range