fix: allocate token_indices on CPU, move to GPU AFTER JIT compilation

CuTeDSL's cute.compile corrupts GPU memory during JIT compilation.
Tensors allocated on GPU before/during compilation get zeroed.
Fix: create token_indices on CPU, then .to(device) after JIT is done.
This commit is contained in:
2026-05-17 08:22:51 +00:00
parent 0e7b06b55c
commit 7073daaffa

View File

@@ -109,22 +109,7 @@ class CuTeDSLMoERunner:
if self._l1_mat_b is not None:
return
# Allocate token indices FIRST, before any CUDA JIT compilation
# (CuTeDSL's cute.compile can corrupt GPU memory state during JIT)
# Use .clone() to ensure a fresh allocation and cuda.synchronize() to
# ensure the kernel completes before any JIT compilation starts
if self._token_indices is None:
self._token_indices = torch.arange(
self.max_num_tokens, device=self.device, dtype=torch.int32
).unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1).clone()
torch.cuda.synchronize()
if self._expert_id_range is None:
self._expert_id_range = torch.arange(self.num_experts, device=self.device)
if self._expert_offsets_buf is None:
self._expert_offsets_buf = torch.zeros(
self.num_experts + 1, dtype=torch.int32, device=self.device
)
# Stack and prepare weight tensors FIRST (triggers CuTeDSL JIT compilation)
self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4))
self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
self._l1_scale_b = assemble_scales_3d_side(self.l1_sf)
@@ -137,6 +122,19 @@ class CuTeDSLMoERunner:
self.l2_fp4 = None
self.l2_sf = None
self.l2_gs = None
# Allocate buffers AFTER JIT compilation
# (CuTeDSL's cute.compile corrupts GPU memory during JIT;
# tensors allocated before/during compilation may be zeroed)
self._token_indices = torch.arange(
self.max_num_tokens, dtype=torch.int32
).unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1).to(self.device)
self._expert_id_range = torch.arange(
self.num_experts, dtype=torch.int32
).to(self.device)
self._expert_offsets_buf = torch.zeros(
self.num_experts + 1, dtype=torch.int32, device=self.device
)
self._allocate_buffers()
def prepare_weights_direct(self, l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs):