fix: cast expert_offsets to int32 for CuTeDSL kernel

CuTeDSL's grouped GEMM uses int32 for expert offsets internally.
Our cumsum produced int64, causing a type mismatch inside a dynamic
if-branch (prev_off changes from Int32 to Int64).

Also cast tokens_per_expert to int32 before cumsum.
This commit is contained in:
2026-05-16 07:15:57 +00:00
parent 4b0a9557f0
commit e0814eb54e

View File

@@ -126,8 +126,8 @@ class CuTeDSLMoERunner:
# Build expert_offsets: cumulative count of tokens per expert
# Count how many slots each expert gets
expert_id_range = torch.arange(num_experts, device=device)
tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0)
expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int64, device=device)
tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int()
expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32, device=device)
expert_offsets[1:] = tokens_per_expert.cumsum(0)
if expert_offsets[-1] == 0:
return torch.zeros(num_tokens, hidden_size, dtype=torch.bfloat16, device=device)