diff --git a/cutedsl/runner.py b/cutedsl/runner.py index 20ceb93d..61bca6ae 100644 --- a/cutedsl/runner.py +++ b/cutedsl/runner.py @@ -393,7 +393,7 @@ class CuTeDSLMoERunner: slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(slot_hidden, l1_gs) expert_id_range = self._expert_id_range - tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int() + tokens_per_expert = torch.bincount(sorted_ids, minlength=self.num_experts)[:self.num_experts].int() expert_offsets = self._expert_offsets_buf expert_offsets.zero_() expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0) @@ -496,7 +496,7 @@ class CuTeDSLMoERunner: # Expert offsets (real token counts) expert_id_range = self._expert_id_range - tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int() + tokens_per_expert = torch.bincount(sorted_ids, minlength=self.num_experts)[:self.num_experts].int() expert_offsets = self._expert_offsets_buf expert_offsets.zero_() expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)