[Bugfix] Mamba cache Cuda Graph padding (#6214)
This commit is contained in:
@@ -788,12 +788,12 @@ class JambaForCausalLM(nn.Module):
|
||||
key in kwargs
|
||||
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
|
||||
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
|
||||
batch_size = len(request_ids_to_seq_ids)
|
||||
cg_batch_size = input_buffers['input_ids'].shape[0]
|
||||
(
|
||||
current_mamba_cache,
|
||||
indices,
|
||||
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
|
||||
batch_size)
|
||||
cg_batch_size)
|
||||
self.current_indices = indices
|
||||
finished_requests_ids = kwargs["finished_requests_ids"]
|
||||
self._release_mamba_cache(finished_requests_ids)
|
||||
|
||||
Reference in New Issue
Block a user