[Kernel][Misc] register ops to prevent graph breaks (#6917)

Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
bnellnm
2024-09-11 15:52:19 -04:00
committed by GitHub
parent 7015417fd4
commit 73202dbe77
22 changed files with 528 additions and 102 deletions

View File

@@ -733,7 +733,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
indices_for_current_run: List[int]):
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
destination_indices = set([range(batch_size)])
destination_indices = range(batch_size)
max_possible_batch_size = self.mamba_cache[0].shape[1]
for destination_index in destination_indices:
if destination_index in self._get_all_occupied_indices() and \