fix: allocate token_indices on CPU, move to GPU AFTER JIT compilation
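CuTeDSL's `cute.compile` can corrupt GPU memory during JIT compilation: tensors allocated on the GPU before or during compilation may end up zeroed. Fix: build `token_indices` on the CPU and move it to the GPU with `.to(device)` only after JIT compilation has finished. The same change also moves the `_expert_id_range` and `_expert_offsets_buf` allocations to after compilation, and `_expert_id_range` is now created with an explicit `dtype=torch.int32`.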
This commit is contained in:
@@ -109,22 +109,7 @@ class CuTeDSLMoERunner:
|
||||
if self._l1_mat_b is not None:
|
||||
return
|
||||
|
||||
# Allocate token indices FIRST, before any CUDA JIT compilation
|
||||
# (CuTeDSL's cute.compile can corrupt GPU memory state during JIT)
|
||||
# Use .clone() to ensure a fresh allocation and cuda.synchronize() to
|
||||
# ensure the kernel completes before any JIT compilation starts
|
||||
if self._token_indices is None:
|
||||
self._token_indices = torch.arange(
|
||||
self.max_num_tokens, device=self.device, dtype=torch.int32
|
||||
).unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1).clone()
|
||||
torch.cuda.synchronize()
|
||||
if self._expert_id_range is None:
|
||||
self._expert_id_range = torch.arange(self.num_experts, device=self.device)
|
||||
if self._expert_offsets_buf is None:
|
||||
self._expert_offsets_buf = torch.zeros(
|
||||
self.num_experts + 1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
# Stack and prepare weight tensors FIRST (triggers CuTeDSL JIT compilation)
|
||||
self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4))
|
||||
self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
|
||||
self._l1_scale_b = assemble_scales_3d_side(self.l1_sf)
|
||||
@@ -137,6 +122,19 @@ class CuTeDSLMoERunner:
|
||||
self.l2_fp4 = None
|
||||
self.l2_sf = None
|
||||
self.l2_gs = None
|
||||
|
||||
# Allocate buffers AFTER JIT compilation
|
||||
# (CuTeDSL's cute.compile corrupts GPU memory during JIT;
|
||||
# tensors allocated before/during compilation may be zeroed)
|
||||
self._token_indices = torch.arange(
|
||||
self.max_num_tokens, dtype=torch.int32
|
||||
).unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1).to(self.device)
|
||||
self._expert_id_range = torch.arange(
|
||||
self.num_experts, dtype=torch.int32
|
||||
).to(self.device)
|
||||
self._expert_offsets_buf = torch.zeros(
|
||||
self.num_experts + 1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self._allocate_buffers()
|
||||
|
||||
def prepare_weights_direct(self, l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs):
|
||||
|
||||
Reference in New Issue
Block a user