[Bugfix] Fix mamba2 prefill chunking (#23279)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Signed-off-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
tomeras91
2025-09-08 14:42:41 +03:00
committed by GitHub
parent 5e537f45b4
commit e041314184
5 changed files with 349 additions and 35 deletions

View File

@@ -16,9 +16,58 @@ from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
from vllm.v1.kv_cache_interface import AttentionSpec
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
chunk_size: int,
total_seqlens: int):
def _query_start_loc_to_chunk_indices_offsets(
query_start_loc: torch.Tensor, chunk_size: int,
total_seqlens: int) -> tuple[torch.Tensor, torch.Tensor]:
"""
Args:
query_start_loc (torch.Tensor): 1D tensor of cumulative sequence
lengths, shape (num_seqs + 1,).
The first element should be 0. Each entry represents the starting
index of a sequence in the flattened token array.
chunk_size (int): The size of each physical mamba chunk
(number of tokens per chunk).
total_seqlens (int): The total number of tokens in the batch.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- chunk_indices (torch.Tensor): 1D tensor of indices
indicating the physical chunk for each logical chunk.
- chunk_offsets (torch.Tensor): 1D tensor of offsets
indicating the starting index of each logical chunk within
its physical chunk.
This function computes the chunk indices and offsets for the given
query_start_loc and chunk_size. Both are tensors of integers with length N,
where N is the number of logical (pseudo) chunks.
A logical chunk is a sequence of tokens that are all part of the same
sequence and are all in the same physical mamba chunk.
In other words, a logical chunk changes every time we cross a sequence
boundary or a physical mamba chunk boundary.
Logical chunks are needed to handle batched requests with initial states
(see _state_passing_fwd and _chunk_scan_fwd).
The chunk_indices tensor contains the index of the physical chunk for each
logical chunk.
The chunk_offsets tensor contains the offset (AKA starting index) of the
logical chunk in the physical chunk.
Example:
query_start_loc = [0, 5, 10]
chunk_size = 8
total_seqlens = 10
-> chunk_indices = [0, 0, 1]
-> chunk_offsets = [0, 5, 0]
In this example, we have 2 sequences, each with 5 tokens. The physical
chunk size is 8 tokens.
We have three logical chunks:
- the first logical chunk starts at token 0 in the first physical chunk
and contains all 5 tokens from the first sequence
- the second logical chunk starts at token 5 in the first physical chunk
and contains first 3 tokens from the second sequence
- the third logical chunk starts at token 0 in the second physical chunk
and contains the remaining 2 tokens from the second sequence
"""
cu_seqlens = query_start_loc[1:] # remove prepended 0