[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)

Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Author: Harry Huang
Date: 2026-01-24 01:56:48 +08:00
Committed by: GitHub
Parent: fec9da0af4
Commit: 5206e5e28c
42 changed files with 1774 additions and 128 deletions


@@ -8,6 +8,7 @@ from vllm.distributed import get_dcp_group, get_pcp_group
 from vllm.logger import init_logger
 from vllm.utils.math_utils import cdiv
 from vllm.v1.utils import CpuGpuBuffer
+from vllm.v1.worker.cp_utils import get_total_cp_world_size

 logger = init_logger(__name__)
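The newly imported get_total_cp_world_size lives in vllm/v1/worker/cp_utils.py, whose body is not part of the excerpt shown here. A minimal sketch, assuming it simply centralizes the inline PCP/DCP probing that this commit deletes from __init__ below (the real helper may differ in detail):

# Hypothetical reconstruction of vllm/v1/worker/cp_utils.get_total_cp_world_size,
# based on the inline logic removed in the second hunk of this commit.
from vllm.distributed import get_dcp_group, get_pcp_group


def get_total_cp_world_size() -> int:
    """Total context-parallel degree: DCP world size times PCP world size."""
    try:
        pcp_world_size = get_pcp_group().world_size
    except AssertionError:
        # PCP might not be initialized in testing.
        pcp_world_size = 1
    try:
        dcp_world_size = get_dcp_group().world_size
    except AssertionError:
        # DCP might not be initialized in testing.
        dcp_world_size = 1
    return dcp_world_size * pcp_world_size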
@@ -261,47 +262,45 @@ class MultiGroupBlockTable:
         device: torch.device,
         block_sizes: list[int],
         kernel_block_sizes: list[int],
         num_speculative_tokens: int = 0,
+        max_num_blocks: list[int] | None = None,
         cp_kv_cache_interleave_size: int = 1,
     ) -> None:
-        # Note(hc): each dcp rank only stores
-        # (max_model_len // dcp_world_size) tokens in the kv cache,
-        # so the block_size used to calculate max_num_blocks_per_req
-        # must be multiplied by dcp_world_size.
-        try:
-            pcp_world_size = get_pcp_group().world_size
-        except AssertionError:
-            # PCP might not be initialized in testing
-            pcp_world_size = 1
-        try:
-            dcp_world_size = get_dcp_group().world_size
-        except AssertionError:
-            # DCP might not be initialized in testing
-            dcp_world_size = 1
         if len(kernel_block_sizes) != len(block_sizes):
             raise ValueError(
                 f"kernel_block_sizes length ({len(kernel_block_sizes)}) "
                 f"must match block_sizes length ({len(block_sizes)})"
             )
-        total_cp_world_size = dcp_world_size * pcp_world_size
+        if max_num_blocks is None:
+            # Note(hc): each dcp rank only stores
+            # (max_model_len // dcp_world_size) tokens in the kv cache,
+            # so the block_size used to calculate max_num_blocks_per_req
+            # must be multiplied by dcp_world_size.
+            total_cp_world_size = get_total_cp_world_size()
+            max_num_blocks = [
+                cdiv(max_model_len, block_size * total_cp_world_size)
+                for block_size in block_sizes
+            ]
+        if len(max_num_blocks) != len(block_sizes):
+            raise ValueError(
+                f"max_num_blocks length ({len(max_num_blocks)}) "
+                f"must match block_sizes length ({len(block_sizes)})"
+            )
         self.block_tables = [
             BlockTable(
                 block_size,
                 max_num_reqs,
-                max(
-                    cdiv(max_model_len, block_size * total_cp_world_size),
-                    1 + num_speculative_tokens,
-                ),
+                max_num_blocks_per_req,
                 max_num_batched_tokens,
                 pin_memory,
                 device,
                 kernel_block_size,
                 cp_kv_cache_interleave_size,
             )
-            for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
+            for block_size, kernel_block_size, max_num_blocks_per_req in zip(
+                block_sizes, kernel_block_sizes, max_num_blocks
+            )
         ]

     def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
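The effect of the new default is easiest to see with concrete numbers. The sketch below is illustrative only: the model length, block sizes, and parallel degrees are made up, and only the default branch of the new __init__ logic is reproduced.

from vllm.utils.math_utils import cdiv

# Made-up configuration, not taken from the commit.
max_model_len = 8192
block_sizes = [16, 64]       # one entry per KV cache group
total_cp_world_size = 2 * 2  # e.g. dcp_world_size=2, pcp_world_size=2

# Default branch of the new __init__: each context-parallel rank holds only
# max_model_len // total_cp_world_size tokens, so the effective per-rank
# block size is block_size * total_cp_world_size.
max_num_blocks = [
    cdiv(max_model_len, block_size * total_cp_world_size)
    for block_size in block_sizes
]
assert max_num_blocks == [128, 32]  # cdiv(8192, 64), cdiv(8192, 256)

Callers can instead pass max_num_blocks explicitly, one entry per KV cache group, which is presumably what lets hybrid Mamba models size each group's block table independently for prefix caching.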