fix: keep token_indices on CPU, index with CPU sort_idx

CuTeDSL's cute.compile corrupts GPU memory during JIT compilation:
tensors allocated on the GPU before or during compilation may be zeroed.
Keeping token_indices on the CPU and indexing it with sort_idx.cpu()
avoids the corruption. The .to(device) call after indexing moves the
gathered token ids back to the GPU, where they index hidden_states.
This commit is contained in:
2026-05-17 08:29:18 +00:00
parent 235d5b314f
commit d635dcbbb6

View File

@@ -129,17 +129,9 @@ class CuTeDSLMoERunner:
# tensors allocated before/during compilation may be zeroed)
self._token_indices = torch.arange(
self.max_num_tokens, dtype=torch.int32
).unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1).to(self.device)
# Verify the tensor was transferred correctly (CuTeDSL JIT can corrupt GPU state)
_verify = self._token_indices[:8].cpu().tolist()
if _verify != [0, 0, 1, 1, 2, 2, 3, 3]:
# Fallback: allocate directly on GPU
self._token_indices = torch.zeros(
self.max_num_tokens * self.top_k, dtype=torch.int32, device=self.device
)
for i in range(self.max_num_tokens):
self._token_indices[i * self.top_k] = i
self._token_indices[i * self.top_k + 1] = i
).unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1)
# Keep on CPU to avoid CuTeDSL JIT GPU memory corruption.
# It is indexed with the CPU copy of sort_idx during slot mapping, and
# the gathered result is then moved back to the device.
self._expert_id_range = torch.arange(
self.num_experts, dtype=torch.int32
).to(self.device)
@@ -247,7 +239,7 @@ class CuTeDSLMoERunner:
token_indices = self._token_indices[:num_slots]
sort_idx = flat_ids.argsort(stable=True)
sorted_ids = flat_ids[sort_idx]
sorted_token_ids = token_indices[sort_idx]
sorted_token_ids = token_indices[sort_idx.cpu()].to(device)
slot_hidden = hidden_states_sample[sorted_token_ids]
# Debug: verify slot_hidden
@@ -329,7 +321,7 @@ class CuTeDSLMoERunner:
sort_idx = flat_ids.argsort(stable=True)
sorted_ids = flat_ids[sort_idx]
sorted_weights = flat_weights[sort_idx]
sorted_token_ids = token_indices[sort_idx]
sorted_token_ids = token_indices[sort_idx.cpu()].to(device)
# Expert offsets (GPU-only, never touches CPU)
expert_id_range = self._expert_id_range