debug: assert token indices are correct after allocation

This commit is contained in:
2026-05-17 08:16:09 +00:00
parent c0d016a472
commit da02a5dc11

View File

@@ -84,7 +84,11 @@ class CuTeDSLMoERunner:
# Slot -> token mapping: [0,0,...,0, 1,1,...,1, ...] (top_k repeats)
self._token_indices = torch.arange(
self.max_num_tokens, device=self.device
).unsqueeze(1).expand(-1, self.top_k).reshape(-1)
).unsqueeze(1).expand(-1, self.top_k).reshape(-1).clone()
# Debug: verify token indices are correct
assert self._token_indices[:8].tolist() == [0, 0, 1, 1, 2, 2, 3, 3], \
f"Token indices corrupted: {self._token_indices[:8].tolist()}"
self._expert_id_range = torch.arange(self.num_experts, device=self.device)