[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)

Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Author: Harry Huang
Date: 2026-01-24 01:56:48 +08:00
Committed by: GitHub
parent fec9da0af4
commit 5206e5e28c
42 changed files with 1774 additions and 128 deletions

View File

@@ -16,6 +16,7 @@ from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
compute_causal_conv1d_metadata,
mamba_get_block_table_tensor,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -158,6 +159,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
query_start_loc_cpu = m.query_start_loc_cpu
context_lens_tensor = m.compute_num_computed_tokens()
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
block_table_tensor = mamba_get_block_table_tensor(
m.block_table_tensor,
m.seq_lens,
self.kv_cache_spec,
self.vllm_config.cache_config.mamba_cache_mode,
)
spec_sequence_masks_cpu: torch.Tensor | None = None
if (
@@ -189,7 +196,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
spec_token_indx = None
non_spec_token_indx = None
spec_state_indices_tensor = None
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
non_spec_state_indices_tensor = block_table_tensor[:, 0]
spec_query_start_loc = None
non_spec_query_start_loc = query_start_loc
non_spec_query_start_loc_cpu = query_start_loc_cpu
@@ -221,7 +228,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
non_spec_token_indx = torch.empty(
0, dtype=torch.int32, device=query_start_loc.device
)
spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
spec_state_indices_tensor = block_table_tensor[:, : self.num_spec + 1]
non_spec_state_indices_tensor = None
spec_query_start_loc = query_start_loc
non_spec_query_start_loc = None
@@ -235,10 +242,10 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
non_spec_token_indx = index[:num_non_spec_tokens]
spec_token_indx = index[num_non_spec_tokens:]
spec_state_indices_tensor = m.block_table_tensor[
spec_state_indices_tensor = block_table_tensor[
spec_sequence_masks, : self.num_spec + 1
]
non_spec_state_indices_tensor = m.block_table_tensor[
non_spec_state_indices_tensor = block_table_tensor[
~spec_sequence_masks, 0
]
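
A minimal standalone sketch of the spec / non-spec split used in the hunks above (the block ids, the mask, and num_spec are made-up illustrative values, not the real vLLM tensors):

import torch

num_spec = 2  # hypothetical number of speculative tokens per request
# One made-up row per request; columns hold mamba state block ids
# (running-state slot followed by speculative slots).
block_table_tensor = torch.tensor(
    [[10, 11, 12, 13],
     [20, 21, 22, 23],
     [30, 31, 32, 33]]
)
spec_sequence_masks = torch.tensor([True, False, True])

# Spec requests keep num_spec + 1 slots; non-spec requests keep only slot 0.
spec_state_indices_tensor = block_table_tensor[spec_sequence_masks, : num_spec + 1]
non_spec_state_indices_tensor = block_table_tensor[~spec_sequence_masks, 0]

print(spec_state_indices_tensor)      # tensor([[10, 11, 12], [30, 31, 32]])
print(non_spec_state_indices_tensor)  # tensor([20])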

View File

@@ -11,7 +11,10 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import split_decodes_and_prefills
from vllm.v1.attention.backends.utils import (
mamba_get_block_table_tensor,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -61,7 +64,12 @@ class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMet
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
state_indices_tensor = mamba_get_block_table_tensor(
common_attn_metadata.block_table_tensor,
common_attn_metadata.seq_lens,
self.kv_cache_spec,
self.vllm_config.cache_config.mamba_cache_mode,
)[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
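
For context, the arithmetic behind the decode/prefill split used here, as an illustrative sketch only (the real split_decodes_and_prefills helper has its own signature and supports a configurable decode threshold; this assumes the V1 convention that decode requests are ordered before prefills and that a decode schedules exactly one token):

import torch

query_start_loc = torch.tensor([0, 1, 2, 7, 12])              # 4 requests
tokens_per_req = query_start_loc[1:] - query_start_loc[:-1]   # [1, 1, 5, 5]

# Requests scheduling a single token are decodes; the rest are prefills.
num_decodes = int((tokens_per_req == 1).sum())                 # 2
num_prefills = tokens_per_req.numel() - num_decodes            # 2
num_decode_tokens = int(tokens_per_req[:num_decodes].sum())    # 2
num_prefill_tokens = int(tokens_per_req[num_decodes:].sum())   # 10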

View File

@@ -18,6 +18,7 @@ from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
compute_causal_conv1d_metadata,
mamba_get_block_table_tensor,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -41,11 +42,15 @@ class BaseMambaAttentionMetadata:
state_indices_tensor: torch.Tensor
# The following tensors are only used for prefix caching and are None if disabled
# The following tensors are only used for prefix caching in all mode and
# are None if disabled
block_idx_last_scheduled_token: torch.Tensor | None
block_idx_first_scheduled_token_p: torch.Tensor | None
block_idx_last_computed_token: torch.Tensor | None
# The following tensor is only used for prefix caching in align mode
seq_lens: torch.Tensor
# The following attributes are for triton implementation of causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
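
As a reading aid, a small worked example of the block-index arithmetic these fields imply (block_size and the token counts are made-up values; the exact semantics are defined by the builder code below, and the first-scheduled-token index is stored only for prefill requests, hence the _p suffix above):

block_size = 4
num_computed_tokens = 6   # tokens already cached for the request
seq_len = 9               # computed tokens + tokens scheduled this step

block_idx_last_computed_token = (num_computed_tokens - 1) // block_size   # 1
block_idx_first_scheduled_token = num_computed_tokens // block_size       # 1
block_idx_last_scheduled_token = (seq_len - 1) // block_size              # 2
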
@@ -78,7 +83,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
self.compilation_config.max_cudagraph_capture_size,
)
if self.vllm_config.cache_config.enable_prefix_caching:
if self.vllm_config.cache_config.mamba_cache_mode == "all":
self.state_indices_tensor = torch.empty(
(
self.decode_cudagraph_max_bs,
@@ -198,7 +203,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if self.vllm_config.cache_config.enable_prefix_caching:
if self.vllm_config.cache_config.mamba_cache_mode == "all":
num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
# Return a tensor of shape (#requests, #max blocks)
@@ -214,7 +219,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
)
else:
# Always return just a single block per request:
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
state_indices_tensor = mamba_get_block_table_tensor(
common_attn_metadata.block_table_tensor,
common_attn_metadata.seq_lens,
self.kv_cache_spec,
self.vllm_config.cache_config.mamba_cache_mode,
)[:, 0]
if num_prefills > 0:
if num_computed_tokens is None:
@@ -239,7 +249,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
)
)
if self.vllm_config.cache_config.enable_prefix_caching:
if self.vllm_config.cache_config.mamba_cache_mode == "all":
assert num_computed_tokens is not None
num_computed_tokens_p = num_computed_tokens[
num_reqs - num_prefills : num_reqs
@@ -258,7 +268,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
if self.vllm_config.cache_config.enable_prefix_caching:
if self.vllm_config.cache_config.mamba_cache_mode == "all":
self.block_idx_last_scheduled_token[:num_decodes].copy_(
block_idx_last_scheduled_token, non_blocking=True
)
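
The persistent-buffer padding pattern above, sketched with illustrative sizes (PAD_SLOT_ID is assumed to be -1 here; padded entries tell the mamba kernels to skip those slots):

import torch

PAD_SLOT_ID = -1                     # assumed value for illustration
decode_cudagraph_max_bs = 8
persistent = torch.full((decode_cudagraph_max_bs,), PAD_SLOT_ID, dtype=torch.int32)

num_decodes = 3                      # real decode requests this step
num_decode_tokens = 5                # padded batch size captured by the graph
real_state_indices = torch.tensor([5, 9, 2], dtype=torch.int32)

persistent[:num_decodes].copy_(real_state_indices)
state_indices_tensor = persistent[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
print(state_indices_tensor)          # tensor([ 5,  9,  2, -1, -1], dtype=torch.int32)
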
@@ -286,6 +296,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
block_idx_last_computed_token=block_idx_last_computed_token,
num_computed_tokens_p=num_computed_tokens_p,
num_reqs=num_reqs,
seq_lens=common_attn_metadata.seq_lens,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
@@ -298,8 +309,16 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
slot_mapping: torch.Tensor,
) -> M:
new_metadata = copy.copy(metadata)
prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
state_indices_t = mamba_get_block_table_tensor(
blk_table,
metadata.seq_lens,
self.kv_cache_spec,
self.vllm_config.cache_config.mamba_cache_mode,
)
if self.vllm_config.cache_config.mamba_cache_mode in ("none", "align"):
# Only the block that holds the running state is needed
state_indices_t = state_indices_t[:, 0]
num_reqs = blk_table.shape[0]
# For CUDA graphs, copy to persistent buffer

View File

@@ -17,6 +17,7 @@ from typing_extensions import runtime_checkable
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -854,3 +855,40 @@ def extend_all_queries_by_1(
slot_mapping=new_slot_mapping,
)
return new_cad
def mamba_get_block_table_tensor(
block_table: torch.Tensor,
seq_lens: torch.Tensor,
kv_cache_spec: KVCacheSpec,
mamba_cache_mode: str,
) -> torch.Tensor:
"""
Get the block table tensor for mamba kernels from the input
common_attn_metadata.block_table_tensor, depending on the mamba cache mode:
- "all": input (#requests, cdiv(max_model_len, block_size));
output (#requests, cdiv(max_model_len, block_size)).
- "none": input (#requests, 1 + num_speculative_blocks);
output (#requests, 1 + num_speculative_blocks).
- "align": input (#requests, cdiv(max_model_len, block_size));
output (#requests, 1 + num_speculative_blocks), i.e. the last
1 + num_speculative_blocks blocks of each request.
"""
if mamba_cache_mode in ("all", "none"):
return block_table
else:
assert isinstance(kv_cache_spec, MambaSpec)
# NOTE: For 0-length requests (CUDA graph padding), use a start index of 0
# to avoid gathering with an invalid (negative) block table index.
start_indices = torch.clamp(
(seq_lens - 1) // kv_cache_spec.block_size,
min=0,
)
offsets = torch.arange(
1 + kv_cache_spec.num_speculative_blocks, device=block_table.device
)
indices_to_gather = start_indices.unsqueeze(1) + offsets
return torch.gather(block_table, 1, indices_to_gather)
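
A worked example of the "align" branch, with the MambaSpec fields inlined as hypothetical values (block_size=4, num_speculative_blocks=1) and made-up block ids; it runs the same clamp / arange / gather steps as the function body:

import torch

block_table = torch.tensor(
    [[ 7,  8,  9, 10, 11],
     [20, 21, 22, 23, 24]]
)
seq_lens = torch.tensor([3, 9])   # last tokens fall in blocks 0 and 2
block_size = 4                    # hypothetical MambaSpec.block_size
num_speculative_blocks = 1        # hypothetical MambaSpec.num_speculative_blocks

start_indices = torch.clamp((seq_lens - 1) // block_size, min=0)    # [0, 2]
offsets = torch.arange(1 + num_speculative_blocks)                   # [0, 1]
indices_to_gather = start_indices.unsqueeze(1) + offsets
aligned = torch.gather(block_table, 1, indices_to_gather)
print(aligned)  # tensor([[ 7,  8], [22, 23]])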