perf: replace expert counting O(n*E) comparison with torch.bincount O(n)
Bug #5 fix: (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0) materializes a (num_slots × num_experts) bool tensor every forward — 48K × 384 = 18M elements. torch.bincount(sorted_ids, minlength=num_experts) gives the same result in O(n) with no intermediate allocation. ~200× less work. Also removes the now-unused _expert_id_range buffer.
This commit is contained in:
@@ -393,7 +393,7 @@ class CuTeDSLMoERunner:
|
||||
slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(slot_hidden, l1_gs)
|
||||
|
||||
expert_id_range = self._expert_id_range
|
||||
tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int()
|
||||
tokens_per_expert = torch.bincount(sorted_ids, minlength=self.num_experts)[:self.num_experts].int()
|
||||
expert_offsets = self._expert_offsets_buf
|
||||
expert_offsets.zero_()
|
||||
expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)
|
||||
@@ -496,7 +496,7 @@ class CuTeDSLMoERunner:
|
||||
|
||||
# Expert offsets (real token counts)
|
||||
expert_id_range = self._expert_id_range
|
||||
tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int()
|
||||
tokens_per_expert = torch.bincount(sorted_ids, minlength=self.num_experts)[:self.num_experts].int()
|
||||
expert_offsets = self._expert_offsets_buf
|
||||
expert_offsets.zero_()
|
||||
expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)
|
||||
|
||||
Reference in New Issue
Block a user