[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -8,6 +8,7 @@ from vllm.distributed import get_dcp_group, get_pcp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.cp_utils import get_total_cp_world_size
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -261,47 +262,45 @@ class MultiGroupBlockTable:
|
||||
device: torch.device,
|
||||
block_sizes: list[int],
|
||||
kernel_block_sizes: list[int],
|
||||
num_speculative_tokens: int = 0,
|
||||
max_num_blocks: list[int] | None = None,
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
) -> None:
|
||||
# Note(hc): each dcp rank only store
|
||||
# (max_model_len//dcp_world_size) tokens in kvcache,
|
||||
# so the block_size which used for calc max_num_blocks_per_req
|
||||
# must be multiplied by dcp_world_size.
|
||||
try:
|
||||
pcp_world_size = get_pcp_group().world_size
|
||||
except AssertionError:
|
||||
# PCP might not be initialized in testing
|
||||
pcp_world_size = 1
|
||||
try:
|
||||
dcp_world_size = get_dcp_group().world_size
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
dcp_world_size = 1
|
||||
|
||||
if len(kernel_block_sizes) != len(block_sizes):
|
||||
raise ValueError(
|
||||
f"kernel_block_sizes length ({len(kernel_block_sizes)}) "
|
||||
f"must match block_sizes length ({len(block_sizes)})"
|
||||
)
|
||||
if max_num_blocks is None:
|
||||
# Note(hc): each dcp rank only store
|
||||
# (max_model_len//dcp_world_size) tokens in kvcache,
|
||||
# so the block_size which used for calc max_num_blocks_per_req
|
||||
# must be multiplied by dcp_world_size.
|
||||
total_cp_world_size = get_total_cp_world_size()
|
||||
max_num_blocks = [
|
||||
cdiv(max_model_len, block_size * total_cp_world_size)
|
||||
for block_size in block_sizes
|
||||
]
|
||||
|
||||
total_cp_world_size = dcp_world_size * pcp_world_size
|
||||
if len(max_num_blocks) != len(block_sizes):
|
||||
raise ValueError(
|
||||
f"max_num_blocks length ({len(max_num_blocks)}) "
|
||||
f"must match block_sizes length ({len(block_sizes)})"
|
||||
)
|
||||
|
||||
self.block_tables = [
|
||||
BlockTable(
|
||||
block_size,
|
||||
max_num_reqs,
|
||||
max(
|
||||
cdiv(max_model_len, block_size * total_cp_world_size),
|
||||
1 + num_speculative_tokens,
|
||||
),
|
||||
max_num_blocks_per_req,
|
||||
max_num_batched_tokens,
|
||||
pin_memory,
|
||||
device,
|
||||
kernel_block_size,
|
||||
cp_kv_cache_interleave_size,
|
||||
)
|
||||
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
|
||||
for block_size, kernel_block_size, max_num_blocks_per_req in zip(
|
||||
block_sizes, kernel_block_sizes, max_num_blocks
|
||||
)
|
||||
]
|
||||
|
||||
def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
|
||||
|
||||
Reference in New Issue
Block a user