[Kernel][Misc] register ops to prevent graph breaks (#6917)
Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
@@ -733,7 +733,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
indices_for_current_run: List[int]):
|
||||
# move out all of the occupied but currently not running blocks
|
||||
# outside of the first n blocks
|
||||
destination_indices = set([range(batch_size)])
|
||||
destination_indices = range(batch_size)
|
||||
max_possible_batch_size = self.mamba_cache[0].shape[1]
|
||||
for destination_index in destination_indices:
|
||||
if destination_index in self._get_all_occupied_indices() and \
|
||||
|
||||
Reference in New Issue
Block a user