fix: keep token_indices on CPU, index with CPU sort_idx
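CuTeDSL's cute.compile corrupts GPU memory during JIT compilation: GPU tensors allocated before or during compilation may be zeroed. Keeping token_indices on the CPU avoids the corruption. Because a CPU tensor cannot be indexed with a GPU index tensor, sort_idx is first copied to the CPU with sort_idx.cpu(). The .to(device) call after indexing then moves the result back to the GPU, where it is used to index hidden_states.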
This commit is contained in:
@@ -129,17 +129,9 @@ class CuTeDSLMoERunner:
|
||||
# tensors allocated before/during compilation may be zeroed)
|
||||
self._token_indices = torch.arange(
|
||||
self.max_num_tokens, dtype=torch.int32
|
||||
).unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1).to(self.device)
|
||||
# Verify the tensor was transferred correctly (CuTeDSL JIT can corrupt GPU state)
|
||||
_verify = self._token_indices[:8].cpu().tolist()
|
||||
if _verify != [0, 0, 1, 1, 2, 2, 3, 3]:
|
||||
# Fallback: allocate directly on GPU
|
||||
self._token_indices = torch.zeros(
|
||||
self.max_num_tokens * self.top_k, dtype=torch.int32, device=self.device
|
||||
)
|
||||
for i in range(self.max_num_tokens):
|
||||
self._token_indices[i * self.top_k] = i
|
||||
self._token_indices[i * self.top_k + 1] = i
|
||||
).unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1)
|
||||
# Keep on CPU to avoid CuTeDSL JIT GPU memory corruption
|
||||
# Will be indexed with CPU offsets during slot mapping
|
||||
self._expert_id_range = torch.arange(
|
||||
self.num_experts, dtype=torch.int32
|
||||
).to(self.device)
|
||||
@@ -247,7 +239,7 @@ class CuTeDSLMoERunner:
|
||||
token_indices = self._token_indices[:num_slots]
|
||||
sort_idx = flat_ids.argsort(stable=True)
|
||||
sorted_ids = flat_ids[sort_idx]
|
||||
sorted_token_ids = token_indices[sort_idx]
|
||||
sorted_token_ids = token_indices[sort_idx.cpu()].to(device)
|
||||
slot_hidden = hidden_states_sample[sorted_token_ids]
|
||||
|
||||
# Debug: verify slot_hidden
|
||||
@@ -329,7 +321,7 @@ class CuTeDSLMoERunner:
|
||||
sort_idx = flat_ids.argsort(stable=True)
|
||||
sorted_ids = flat_ids[sort_idx]
|
||||
sorted_weights = flat_weights[sort_idx]
|
||||
sorted_token_ids = token_indices[sort_idx]
|
||||
sorted_token_ids = token_indices[sort_idx.cpu()].to(device)
|
||||
|
||||
# Expert offsets (GPU-only, never touches CPU)
|
||||
expert_id_range = self._expert_id_range
|
||||
|
||||
Reference in New Issue
Block a user