debug: assert token indices are correct after allocation
This commit is contained in:
@@ -84,7 +84,11 @@ class CuTeDSLMoERunner:
|
||||
# Slot -> token mapping: [0,0,...,0, 1,1,...,1, ...] (top_k repeats)
|
||||
self._token_indices = torch.arange(
|
||||
self.max_num_tokens, device=self.device
|
||||
).unsqueeze(1).expand(-1, self.top_k).reshape(-1)
|
||||
).unsqueeze(1).expand(-1, self.top_k).reshape(-1).clone()
|
||||
|
||||
# Debug: verify token indices are correct
|
||||
assert self._token_indices[:8].tolist() == [0, 0, 1, 1, 2, 2, 3, 3], \
|
||||
f"Token indices corrupted: {self._token_indices[:8].tolist()}"
|
||||
|
||||
self._expert_id_range = torch.arange(self.num_experts, device=self.device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user