[Bugfix] Mamba cache CUDA Graph padding (#6214)

This commit is contained in:
tomeras91
2024-07-08 21:25:51 +03:00
committed by GitHub
parent 185ad31f37
commit ddc369fba1
2 changed files with 30 additions and 2 deletions

View File

@@ -788,12 +788,12 @@ class JambaForCausalLM(nn.Module):
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
batch_size = len(request_ids_to_seq_ids)
cg_batch_size = input_buffers['input_ids'].shape[0]
(
current_mamba_cache,
indices,
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
batch_size)
cg_batch_size)
self.current_indices = indices
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids)