fix: cast expert_offsets to int32 for CuTeDSL kernel
CuTeDSL's grouped GEMM uses int32 for expert offsets internally. Our cumsum produced int64, causing a type mismatch inside a dynamic if-branch (prev_off changes from Int32 to Int64). Also cast tokens_per_expert to int32 before cumsum.
This commit is contained in:
@@ -126,8 +126,8 @@ class CuTeDSLMoERunner:
|
||||
# Build expert_offsets: cumulative count of tokens per expert
|
||||
# Count how many slots each expert gets
|
||||
expert_id_range = torch.arange(num_experts, device=device)
|
||||
tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0)
|
||||
expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int64, device=device)
|
||||
tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int()
|
||||
expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32, device=device)
|
||||
expert_offsets[1:] = tokens_per_expert.cumsum(0)
|
||||
if expert_offsets[-1] == 0:
|
||||
return torch.zeros(num_tokens, hidden_size, dtype=torch.bfloat16, device=device)
|
||||
|
||||
Reference in New Issue
Block a user