[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -16,6 +16,7 @@ from vllm.v1.attention.backend import (
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
PAD_SLOT_ID,
|
||||
compute_causal_conv1d_metadata,
|
||||
mamba_get_block_table_tensor,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
@@ -158,6 +159,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
query_start_loc_cpu = m.query_start_loc_cpu
|
||||
context_lens_tensor = m.compute_num_computed_tokens()
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
block_table_tensor = mamba_get_block_table_tensor(
|
||||
m.block_table_tensor,
|
||||
m.seq_lens,
|
||||
self.kv_cache_spec,
|
||||
self.vllm_config.cache_config.mamba_cache_mode,
|
||||
)
|
||||
|
||||
spec_sequence_masks_cpu: torch.Tensor | None = None
|
||||
if (
|
||||
@@ -189,7 +196,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
spec_token_indx = None
|
||||
non_spec_token_indx = None
|
||||
spec_state_indices_tensor = None
|
||||
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
|
||||
non_spec_state_indices_tensor = block_table_tensor[:, 0]
|
||||
spec_query_start_loc = None
|
||||
non_spec_query_start_loc = query_start_loc
|
||||
non_spec_query_start_loc_cpu = query_start_loc_cpu
|
||||
@@ -221,7 +228,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
non_spec_token_indx = torch.empty(
|
||||
0, dtype=torch.int32, device=query_start_loc.device
|
||||
)
|
||||
spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
|
||||
spec_state_indices_tensor = block_table_tensor[:, : self.num_spec + 1]
|
||||
non_spec_state_indices_tensor = None
|
||||
spec_query_start_loc = query_start_loc
|
||||
non_spec_query_start_loc = None
|
||||
@@ -235,10 +242,10 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
non_spec_token_indx = index[:num_non_spec_tokens]
|
||||
spec_token_indx = index[num_non_spec_tokens:]
|
||||
|
||||
spec_state_indices_tensor = m.block_table_tensor[
|
||||
spec_state_indices_tensor = block_table_tensor[
|
||||
spec_sequence_masks, : self.num_spec + 1
|
||||
]
|
||||
non_spec_state_indices_tensor = m.block_table_tensor[
|
||||
non_spec_state_indices_tensor = block_table_tensor[
|
||||
~spec_sequence_masks, 0
|
||||
]
|
||||
|
||||
|
||||
@@ -11,7 +11,10 @@ from vllm.v1.attention.backend import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import split_decodes_and_prefills
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
mamba_get_block_table_tensor,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
|
||||
@@ -61,7 +64,12 @@ class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMet
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||
state_indices_tensor = mamba_get_block_table_tensor(
|
||||
common_attn_metadata.block_table_tensor,
|
||||
common_attn_metadata.seq_lens,
|
||||
self.kv_cache_spec,
|
||||
self.vllm_config.cache_config.mamba_cache_mode,
|
||||
)[:, 0]
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
|
||||
@@ -18,6 +18,7 @@ from vllm.v1.attention.backend import (
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
PAD_SLOT_ID,
|
||||
compute_causal_conv1d_metadata,
|
||||
mamba_get_block_table_tensor,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
@@ -41,11 +42,15 @@ class BaseMambaAttentionMetadata:
|
||||
|
||||
state_indices_tensor: torch.Tensor
|
||||
|
||||
# The following tensors are only used for prefix caching and are None if disabled
|
||||
# The following tensors are only used for prefix caching in all mode and
|
||||
# are None if disabled
|
||||
block_idx_last_scheduled_token: torch.Tensor | None
|
||||
block_idx_first_scheduled_token_p: torch.Tensor | None
|
||||
block_idx_last_computed_token: torch.Tensor | None
|
||||
|
||||
# The following tensor is only used for prefix caching in align mode
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
# The following attributes are for triton implementation of causal_conv1d
|
||||
nums_dict: dict | None = None
|
||||
batch_ptr: torch.Tensor | None = None
|
||||
@@ -78,7 +83,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
self.compilation_config.max_cudagraph_capture_size,
|
||||
)
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
if self.vllm_config.cache_config.mamba_cache_mode == "all":
|
||||
self.state_indices_tensor = torch.empty(
|
||||
(
|
||||
self.decode_cudagraph_max_bs,
|
||||
@@ -198,7 +203,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
# for causal_conv1d
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
if self.vllm_config.cache_config.mamba_cache_mode == "all":
|
||||
num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
|
||||
|
||||
# Return a tensor of shape (#requests, #max blocks)
|
||||
@@ -214,7 +219,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
)
|
||||
else:
|
||||
# Always return just a single block per each request:
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||
state_indices_tensor = mamba_get_block_table_tensor(
|
||||
common_attn_metadata.block_table_tensor,
|
||||
common_attn_metadata.seq_lens,
|
||||
self.kv_cache_spec,
|
||||
self.vllm_config.cache_config.mamba_cache_mode,
|
||||
)[:, 0]
|
||||
|
||||
if num_prefills > 0:
|
||||
if num_computed_tokens is None:
|
||||
@@ -239,7 +249,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
)
|
||||
)
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
if self.vllm_config.cache_config.mamba_cache_mode == "all":
|
||||
assert num_computed_tokens is not None
|
||||
num_computed_tokens_p = num_computed_tokens[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
@@ -258,7 +268,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
|
||||
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
if self.vllm_config.cache_config.mamba_cache_mode == "all":
|
||||
self.block_idx_last_scheduled_token[:num_decodes].copy_(
|
||||
block_idx_last_scheduled_token, non_blocking=True
|
||||
)
|
||||
@@ -286,6 +296,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
block_idx_last_computed_token=block_idx_last_computed_token,
|
||||
num_computed_tokens_p=num_computed_tokens_p,
|
||||
num_reqs=num_reqs,
|
||||
seq_lens=common_attn_metadata.seq_lens,
|
||||
nums_dict=nums_dict,
|
||||
batch_ptr=batch_ptr,
|
||||
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
||||
@@ -298,8 +309,16 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> M:
|
||||
new_metadata = copy.copy(metadata)
|
||||
prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
|
||||
state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
|
||||
state_indices_t = mamba_get_block_table_tensor(
|
||||
blk_table,
|
||||
metadata.seq_lens,
|
||||
self.kv_cache_spec,
|
||||
self.vllm_config.cache_config.mamba_cache_mode,
|
||||
)
|
||||
if self.vllm_config.cache_config.mamba_cache_mode in ("none", "align"):
|
||||
# Only needs the block that saves the running state
|
||||
state_indices_t = state_indices_t[:, 0]
|
||||
|
||||
num_reqs = blk_table.shape[0]
|
||||
|
||||
# For CUDA graphs, copy to persistent buffer
|
||||
|
||||
@@ -17,6 +17,7 @@ from typing_extensions import runtime_checkable
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
@@ -854,3 +855,40 @@ def extend_all_queries_by_1(
|
||||
slot_mapping=new_slot_mapping,
|
||||
)
|
||||
return new_cad
|
||||
|
||||
|
||||
def mamba_get_block_table_tensor(
    block_table: torch.Tensor,
    seq_lens: torch.Tensor,
    kv_cache_spec: KVCacheSpec,
    mamba_cache_mode: str,
) -> torch.Tensor:
    """
    Get the block table tensor for mamba kernels from the input
    common_attn_metadata.block_table_tensor given different mamba cache modes.

    - "all":   input  (#requests, cdiv(max_model_len, block_size));
               output (#requests, cdiv(max_model_len, block_size)).

    - "none":  input  (#requests, 1 + num_speculative_blocks);
               output (#requests, 1 + num_speculative_blocks).

    - "align": input  (#requests, cdiv(max_model_len, block_size));
               output (#requests, 1 + num_speculative_blocks), which are the
               last 1 + num_speculative_blocks of each request.

    Args:
        block_table: Per-request block table; returned unchanged for the
            "all" and "none" modes.
        seq_lens: Per-request sequence lengths. Only read in "align" mode,
            where it selects the column of the first gathered block.
        kv_cache_spec: Cache spec of the layer. Must be a ``MambaSpec`` in
            "align" mode, because its ``block_size`` and
            ``num_speculative_blocks`` are read.
        mamba_cache_mode: One of "all", "none" or "align".

    Returns:
        ``block_table`` itself for "all"/"none"; for "align", a new tensor
        gathered along dim 1 with ``1 + num_speculative_blocks`` columns.

    Raises:
        ValueError: If ``mamba_cache_mode`` is not one of the modes above.
            Previously any unrecognised value silently took the "align"
            path.
    """
    if mamba_cache_mode in ("all", "none"):
        # The kernels consume the table as-is in these modes.
        return block_table

    if mamba_cache_mode != "align":
        # Fail loudly instead of treating a typo / unsupported mode as "align".
        raise ValueError(
            f"Unsupported mamba_cache_mode: {mamba_cache_mode!r}; "
            "expected one of 'all', 'none', 'align'."
        )

    assert isinstance(kv_cache_spec, MambaSpec)
    # Column of the block that holds each request's last token.
    # NOTE: For 0-length requests in CUDA graph, use a start_index of 0
    # to handle the invalid block table.
    start_indices = torch.clamp(
        (seq_lens - 1) // kv_cache_spec.block_size,
        min=0,
    )
    # Consecutive column offsets: the running-state block followed by the
    # speculative blocks.
    offsets = torch.arange(
        1 + kv_cache_spec.num_speculative_blocks, device=block_table.device
    )
    # Broadcast (#requests, 1) + (1 + num_speculative_blocks,) into a
    # (#requests, 1 + num_speculative_blocks) index matrix.
    indices_to_gather = start_indices.unsqueeze(1) + offsets
    return torch.gather(block_table, 1, indices_to_gather)
|
||||
|
||||
Reference in New Issue
Block a user