[Mamba1] - Kernel Level Chunk Alignment for Prefix Caching (#34798)
Signed-off-by: Josephasafg <ajgard7@gmail.com>
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Any
|
||||
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.mamba_attn import (
|
||||
BaseMambaAttentionMetadata,
|
||||
BaseMambaAttentionMetadataBuilder,
|
||||
@@ -29,3 +30,31 @@ class Mamba1AttentionMetadataBuilder(
|
||||
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
|
||||
):
|
||||
metadata_cls = Mamba1AttentionMetadata
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Mamba1AttentionMetadata:
|
||||
common = self._compute_common_metadata(common_attn_metadata)
|
||||
|
||||
if (
|
||||
common.num_prefills > 0
|
||||
and self.vllm_config.cache_config.mamba_cache_mode == "all"
|
||||
):
|
||||
cu_chunk_seqlen_p, _, last_chunk_indices_p = (
|
||||
self._build_chunk_metadata_tensors(
|
||||
self.kv_cache_spec.block_size,
|
||||
common,
|
||||
common_attn_metadata,
|
||||
)
|
||||
)
|
||||
return replace(
|
||||
common,
|
||||
cu_chunk_seqlen_p=cu_chunk_seqlen_p,
|
||||
last_chunk_indices_p=last_chunk_indices_p,
|
||||
)
|
||||
|
||||
return common
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
CommonAttentionMetadata,
|
||||
@@ -105,14 +104,6 @@ class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):
|
||||
|
||||
# Chunk-related metadata (only for prefill)
|
||||
seq_idx_p: torch.Tensor | None = None
|
||||
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
|
||||
# each chunk, its offsets into the varlen sequence dimension. It is defined
|
||||
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
|
||||
# cu_chunk_seqlen_p[i+1].
|
||||
cu_chunk_seqlen_p: torch.Tensor | None = None
|
||||
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
|
||||
# index of the last chunk for every sequence in the (prefill) batch.
|
||||
last_chunk_indices_p: torch.Tensor | None = None
|
||||
|
||||
|
||||
class Mamba2AttentionMetadataBuilder(
|
||||
@@ -134,68 +125,6 @@ class Mamba2AttentionMetadataBuilder(
|
||||
)
|
||||
self.chunk_size: int = chunk_size
|
||||
|
||||
def _compute_chunk_metadata(
|
||||
self,
|
||||
num_prefills: int,
|
||||
num_computed_tokens_p_cpu: torch.Tensor,
|
||||
query_start_loc_p_cpu: torch.Tensor,
|
||||
) -> tuple[list[int], list[int], list[int]]:
|
||||
"""
|
||||
Compute chunk-specific metadata for Mamba2.
|
||||
|
||||
The code below carefully constructs the chunks such that:
|
||||
1. Chunks contain tokens from a *single* sequence only.
|
||||
2. For every sequence, we are guaranteed that we can
|
||||
retrieve the mamba state *every* chunk_size tokens.
|
||||
Constraint (1) dramatically simplifies the mamba2 kernels.
|
||||
Constraint (2) dramatically simplifies the implementation
|
||||
of prefix caching for mamba2 (wip). We need to take care
|
||||
of the interaction with chunked prefill in order to
|
||||
satisfy constraint (2).
|
||||
"""
|
||||
# TODO (tdoublep): This code could probably be optimized.
|
||||
cu_chunk_seqlen = []
|
||||
seq_idx = []
|
||||
last_chunk_indices = []
|
||||
seqlen_pos = 0
|
||||
|
||||
for req_idx in range(num_prefills):
|
||||
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
|
||||
this_new_tokens = (
|
||||
query_start_loc_p_cpu[req_idx + 1].item()
|
||||
- query_start_loc_p_cpu[req_idx].item()
|
||||
)
|
||||
|
||||
# if computed tokens are not chunk-aligned, use the first
|
||||
# chunk to finish it off
|
||||
if this_num_computed % self.chunk_size != 0:
|
||||
seq_idx.append(req_idx)
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
# how many tokens to finish the chunk?
|
||||
chunk_len = (
|
||||
cdiv(this_num_computed, self.chunk_size) * self.chunk_size
|
||||
- this_num_computed
|
||||
)
|
||||
# we can only use at most this_new_tokens
|
||||
chunk_len = min(chunk_len, this_new_tokens)
|
||||
seqlen_pos += chunk_len
|
||||
this_new_tokens -= chunk_len
|
||||
|
||||
n_chunks = cdiv(this_new_tokens, self.chunk_size)
|
||||
for chunk in range(n_chunks):
|
||||
seq_idx.append(req_idx)
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
chunk_len = min(self.chunk_size, this_new_tokens)
|
||||
seqlen_pos += chunk_len
|
||||
this_new_tokens -= chunk_len
|
||||
|
||||
assert this_new_tokens == 0
|
||||
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
|
||||
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
|
||||
return cu_chunk_seqlen, seq_idx, last_chunk_indices
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
@@ -220,41 +149,12 @@ class Mamba2AttentionMetadataBuilder(
|
||||
else False
|
||||
)
|
||||
|
||||
num_reqs = common.num_reqs
|
||||
num_prefills = common.num_prefills
|
||||
num_decode_tokens = common.num_decode_tokens
|
||||
|
||||
num_computed_tokens_cpu = (
|
||||
common_attn_metadata.compute_num_computed_tokens().cpu()
|
||||
)
|
||||
num_computed_tokens_p_cpu = num_computed_tokens_cpu[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
query_start_loc_p_cpu = (
|
||||
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
|
||||
- num_decode_tokens
|
||||
)
|
||||
|
||||
cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
|
||||
num_prefills,
|
||||
num_computed_tokens_p_cpu,
|
||||
query_start_loc_p_cpu,
|
||||
)
|
||||
|
||||
seq_idx_p = torch.as_tensor(
|
||||
seq_idx,
|
||||
device=common_attn_metadata.query_start_loc.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_chunk_seqlen_p = torch.as_tensor(
|
||||
cu_chunk_seqlen,
|
||||
device=common_attn_metadata.query_start_loc.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
last_chunk_indices_p = torch.as_tensor(
|
||||
last_chunk_indices,
|
||||
device=common_attn_metadata.query_start_loc.device,
|
||||
dtype=torch.int32,
|
||||
cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p = (
|
||||
self._build_chunk_metadata_tensors(
|
||||
self.chunk_size,
|
||||
common,
|
||||
common_attn_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
return replace(
|
||||
|
||||
@@ -59,6 +59,15 @@ class BaseMambaAttentionMetadata:
|
||||
# The following tensor is only used for prefix caching in align mode
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
|
||||
# each chunk, its offsets into the varlen sequence dimension. It is defined
|
||||
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
|
||||
# cu_chunk_seqlen_p[i+1].
|
||||
cu_chunk_seqlen_p: torch.Tensor | None = None
|
||||
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
|
||||
# index of the last chunk for every sequence in the (prefill) batch.
|
||||
last_chunk_indices_p: torch.Tensor | None = None
|
||||
|
||||
# The following attributes are for triton implementation of causal_conv1d
|
||||
nums_dict: dict | None = None
|
||||
batch_ptr: torch.Tensor | None = None
|
||||
@@ -185,6 +194,118 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
common_attn_metadata, num_accepted_tokens=num_accepted_tokens
|
||||
)
|
||||
|
||||
def _compute_chunk_metadata(
|
||||
self,
|
||||
chunk_size: int,
|
||||
num_prefills: int,
|
||||
num_computed_tokens_p_cpu: torch.Tensor,
|
||||
query_start_loc_p_cpu: torch.Tensor,
|
||||
) -> tuple[list[int], list[int], list[int]]:
|
||||
"""
|
||||
Compute chunk-specific metadata for Mamba models.
|
||||
|
||||
The code below carefully constructs the chunks such that:
|
||||
1. Chunks contain tokens from a *single* sequence only.
|
||||
2. For every sequence, we are guaranteed that we can
|
||||
retrieve the mamba state *every* chunk_size tokens.
|
||||
Constraint (1) dramatically simplifies the mamba kernels.
|
||||
Constraint (2) dramatically simplifies the implementation
|
||||
of prefix caching for mamba (wip). We need to take care
|
||||
of the interaction with chunked prefill in order to
|
||||
satisfy constraint (2).
|
||||
"""
|
||||
# TODO (tdoublep): This code could probably be optimized.
|
||||
cu_chunk_seqlen = []
|
||||
seq_idx = []
|
||||
last_chunk_indices = []
|
||||
seqlen_pos = 0
|
||||
|
||||
for req_idx in range(num_prefills):
|
||||
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
|
||||
this_new_tokens = (
|
||||
query_start_loc_p_cpu[req_idx + 1].item()
|
||||
- query_start_loc_p_cpu[req_idx].item()
|
||||
)
|
||||
|
||||
# if computed tokens are not chunk-aligned, use the first
|
||||
# chunk to finish it off
|
||||
if this_num_computed % chunk_size != 0:
|
||||
seq_idx.append(req_idx)
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
# how many tokens to finish the chunk?
|
||||
chunk_len = (
|
||||
cdiv(this_num_computed, chunk_size) * chunk_size - this_num_computed
|
||||
)
|
||||
# we can only use at most this_new_tokens
|
||||
chunk_len = min(chunk_len, this_new_tokens)
|
||||
seqlen_pos += chunk_len
|
||||
this_new_tokens -= chunk_len
|
||||
|
||||
n_chunks = cdiv(this_new_tokens, chunk_size)
|
||||
for chunk in range(n_chunks):
|
||||
seq_idx.append(req_idx)
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
chunk_len = min(chunk_size, this_new_tokens)
|
||||
seqlen_pos += chunk_len
|
||||
this_new_tokens -= chunk_len
|
||||
|
||||
assert this_new_tokens == 0
|
||||
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
|
||||
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
|
||||
return cu_chunk_seqlen, seq_idx, last_chunk_indices
|
||||
|
||||
def _build_chunk_metadata_tensors(
|
||||
self,
|
||||
chunk_size: int,
|
||||
common: M,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Compute chunk metadata and return as device tensors.
|
||||
Returns (cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p).
|
||||
"""
|
||||
num_reqs = common.num_reqs
|
||||
num_prefills = common.num_prefills
|
||||
num_decode_tokens = common.num_decode_tokens
|
||||
|
||||
num_computed_tokens_cpu = (
|
||||
common_attn_metadata.compute_num_computed_tokens().cpu()
|
||||
)
|
||||
num_computed_tokens_p_cpu = num_computed_tokens_cpu[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
query_start_loc_p_cpu = (
|
||||
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
|
||||
- num_decode_tokens
|
||||
)
|
||||
|
||||
cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
|
||||
chunk_size,
|
||||
num_prefills,
|
||||
num_computed_tokens_p_cpu,
|
||||
query_start_loc_p_cpu,
|
||||
)
|
||||
|
||||
device = common_attn_metadata.query_start_loc.device
|
||||
cu_chunk_seqlen_p = torch.as_tensor(
|
||||
cu_chunk_seqlen,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
seq_idx_p = torch.as_tensor(
|
||||
seq_idx,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
last_chunk_indices_p = torch.as_tensor(
|
||||
last_chunk_indices,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
return cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p
|
||||
|
||||
def _compute_prefix_caching_block_indices(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
|
||||
Reference in New Issue
Block a user