[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -16,6 +16,7 @@ from vllm.v1.attention.backend import (
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
PAD_SLOT_ID,
|
||||
compute_causal_conv1d_metadata,
|
||||
mamba_get_block_table_tensor,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
@@ -158,6 +159,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
query_start_loc_cpu = m.query_start_loc_cpu
|
||||
context_lens_tensor = m.compute_num_computed_tokens()
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
block_table_tensor = mamba_get_block_table_tensor(
|
||||
m.block_table_tensor,
|
||||
m.seq_lens,
|
||||
self.kv_cache_spec,
|
||||
self.vllm_config.cache_config.mamba_cache_mode,
|
||||
)
|
||||
|
||||
spec_sequence_masks_cpu: torch.Tensor | None = None
|
||||
if (
|
||||
@@ -189,7 +196,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
spec_token_indx = None
|
||||
non_spec_token_indx = None
|
||||
spec_state_indices_tensor = None
|
||||
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
|
||||
non_spec_state_indices_tensor = block_table_tensor[:, 0]
|
||||
spec_query_start_loc = None
|
||||
non_spec_query_start_loc = query_start_loc
|
||||
non_spec_query_start_loc_cpu = query_start_loc_cpu
|
||||
@@ -221,7 +228,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
non_spec_token_indx = torch.empty(
|
||||
0, dtype=torch.int32, device=query_start_loc.device
|
||||
)
|
||||
spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
|
||||
spec_state_indices_tensor = block_table_tensor[:, : self.num_spec + 1]
|
||||
non_spec_state_indices_tensor = None
|
||||
spec_query_start_loc = query_start_loc
|
||||
non_spec_query_start_loc = None
|
||||
@@ -235,10 +242,10 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
non_spec_token_indx = index[:num_non_spec_tokens]
|
||||
spec_token_indx = index[num_non_spec_tokens:]
|
||||
|
||||
spec_state_indices_tensor = m.block_table_tensor[
|
||||
spec_state_indices_tensor = block_table_tensor[
|
||||
spec_sequence_masks, : self.num_spec + 1
|
||||
]
|
||||
non_spec_state_indices_tensor = m.block_table_tensor[
|
||||
non_spec_state_indices_tensor = block_table_tensor[
|
||||
~spec_sequence_masks, 0
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user