[Bugfix] Fix mamba2 prefill chunking (#23279)
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Signed-off-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -16,9 +16,58 @@ from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 
-def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
-                                              chunk_size: int,
-                                              total_seqlens: int):
+def _query_start_loc_to_chunk_indices_offsets(
+        query_start_loc: torch.Tensor, chunk_size: int,
+        total_seqlens: int) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Args:
+        query_start_loc (torch.Tensor): 1D tensor of cumulative sequence
+            lengths, shape (num_seqs + 1,).
+            The first element should be 0. Each entry represents the starting
+            index of a sequence in the flattened token array.
+        chunk_size (int): The size of each physical mamba chunk
+            (number of tokens per chunk).
+        total_seqlens (int): The total number of tokens in the batch.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+        - chunk_indices (torch.Tensor): 1D tensor of indices
+            indicating the physical chunk for each logical chunk.
+        - chunk_offsets (torch.Tensor): 1D tensor of offsets
+            indicating the starting index of each logical chunk within
+            its physical chunk.
+
+    This function computes the chunk indices and offsets for the given
+    query_start_loc and chunk_size. Both are tensors of integers with length N,
+    where N is the number of logical (pseudo) chunks.
+    A logical chunk is a sequence of tokens that are all part of the same
+    sequence and are all in the same physical mamba chunk.
+    In other words, a logical chunk changes every time we cross a sequence
+    boundary or a physical mamba chunk boundary.
+    Logical chunks are needed to handle batched requests with initial states
+    (see _state_passing_fwd and _chunk_scan_fwd).
+    The chunk_indices tensor contains the index of the physical chunk for each
+    logical chunk.
+    The chunk_offsets tensor contains the offset (AKA starting index) of the
+    logical chunk in the physical chunk.
+
+    Example:
+        query_start_loc = [0, 5, 10]
+        chunk_size = 8
+        total_seqlens = 10
+        -> chunk_indices = [0, 0, 1]
+        -> chunk_offsets = [0, 5, 0]
+
+    In this example, we have 2 sequences, each with 5 tokens. The physical
+    chunk size is 8 tokens.
+    We have three logical chunks:
+    - the first logical chunk starts at token 0 in the first physical chunk
+        and contains all 5 tokens from the first sequence
+    - the second logical chunk starts at token 5 in the first physical chunk
+        and contains the first 3 tokens from the second sequence
+    - the third logical chunk starts at token 0 in the second physical chunk
+        and contains the remaining 2 tokens from the second sequence
+    """
+
     cu_seqlens = query_start_loc[1:]  # remove prepended 0
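
For illustration, here is a small self-contained sketch of the logical-chunk
mapping the new docstring describes. It is not the vectorized implementation
from this commit; reference_chunk_indices_offsets is a hypothetical helper
that derives both outputs directly from the definition (a new logical chunk
starts at every sequence start and at every physical chunk boundary) and
checks them against the docstring example.

import torch


def reference_chunk_indices_offsets(
        query_start_loc: torch.Tensor, chunk_size: int,
        total_seqlens: int) -> tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical reference sketch, not the implementation in this commit.
    # A new logical chunk starts at every sequence start and at every
    # physical mamba chunk boundary; the set union deduplicates the case
    # where a sequence happens to start exactly on a chunk boundary.
    seq_starts = query_start_loc[:-1].tolist()  # drop the trailing total
    phys_starts = list(range(0, total_seqlens, chunk_size))
    starts = sorted(set(seq_starts) | set(phys_starts))

    # For each logical chunk: which physical chunk it lives in, and the
    # offset at which it begins inside that physical chunk.
    chunk_indices = torch.tensor([s // chunk_size for s in starts],
                                 dtype=torch.int32)
    chunk_offsets = torch.tensor([s % chunk_size for s in starts],
                                 dtype=torch.int32)
    return chunk_indices, chunk_offsets


# The example from the docstring: two sequences of 5 tokens, chunk size 8.
indices, offsets = reference_chunk_indices_offsets(
    torch.tensor([0, 5, 10]), chunk_size=8, total_seqlens=10)
assert indices.tolist() == [0, 0, 1]
assert offsets.tolist() == [0, 5, 0]

The fixed function in this commit computes the same mapping with tensor
operations; the sketch above just makes the definition concrete.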