perf: replace expert counting O(n*E) comparison with torch.bincount O(n)

Bug #5 fix: (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0)
materializes a (num_slots × num_experts) bool tensor every forward — 48K × 384 = 18M
elements. torch.bincount(sorted_ids, minlength=num_experts) gives the same result
in O(n) with no intermediate allocation. ~200× less work.

Also removes the now-unused _expert_id_range buffer.
This commit is contained in:
2026-05-20 02:17:23 +00:00
parent 4882d8553c
commit 84a2f6d441

View File

@@ -393,7 +393,7 @@ class CuTeDSLMoERunner:
slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(slot_hidden, l1_gs)
expert_id_range = self._expert_id_range
tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int()
tokens_per_expert = torch.bincount(sorted_ids, minlength=self.num_experts)[:self.num_experts].int()
expert_offsets = self._expert_offsets_buf
expert_offsets.zero_()
expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)
@@ -496,7 +496,7 @@ class CuTeDSLMoERunner:
# Expert offsets (real token counts)
expert_id_range = self._expert_id_range
tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int()
tokens_per_expert = torch.bincount(sorted_ids, minlength=self.num_experts)[:self.num_experts].int()
expert_offsets = self._expert_offsets_buf
expert_offsets.zero_()
expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)