[Feature] Support Decode Context Parallel (DCP) for MLA (#23734)
Signed-off-by: hongchao <hongchao@msh.team>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: hongchao <hongchao@msh.team>
Co-authored-by: youkaichao <youkaichao@gmail.com>
@@ -4,6 +4,7 @@
 import numpy as np
 import torch
 
+from vllm.distributed import get_dcp_group
 from vllm.logger import init_logger
 from vllm.utils import cdiv
 
@@ -50,6 +51,13 @@ class BlockTable:
         self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
                                         dtype=torch.int64,
                                         device=self.device)
+        try:
+            self.dcp_world_size = get_dcp_group().world_size
+            self.dcp_rank = get_dcp_group().rank_in_group
+        except AssertionError:
+            # DCP might not be initialized in testing
+            self.dcp_world_size = 1
+            self.dcp_rank = 0
 
     def append_row(
         self,
@@ -89,13 +97,36 @@ class BlockTable:
         # NOTE(woosuk): We can't simply use `token_indices // block_size`
         # here because M (max_model_len) is not necessarily divisible by
         # block_size.
-        block_table_indices = (req_indices * self.max_num_blocks_per_req +
-                               positions // self.block_size)
-        block_numbers = self.block_table_np.ravel()[block_table_indices]
-        block_offsets = positions % self.block_size
-        np.add(block_numbers * self.block_size,
-               block_offsets,
-               out=self.slot_mapping_np[:req_indices.shape[0]])
+        if self.dcp_world_size > 1:
+            # Note(hc): DCP stores the KV cache in an interleaved style:
+            # the KV cache of the token whose token_idx is i is always
+            # stored on the GPU whose dcp_rank equals i % dcp_world_size.
+
+            # Use a "virtual block" whose size equals dcp_world_size *
+            # block_size for the block_table_indices calculation.
+            virtual_block_size = self.block_size * self.dcp_world_size
+            block_table_indices = (req_indices * self.max_num_blocks_per_req +
+                                   positions // virtual_block_size)
+            block_numbers = self.block_table_np.ravel()[block_table_indices]
+            # Use virtual_block_size for the mask calculation, which marks
+            # local tokens.
+            virtual_block_offsets = positions % virtual_block_size
+            mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
+            # Calculate the local block_offsets.
+            block_offsets = virtual_block_offsets // self.dcp_world_size
+            # Calculate the slot_mapping.
+            slot_mapping = block_numbers * self.block_size + block_offsets
+            # Write the final slots; use -1 for non-local tokens.
+            self.slot_mapping_np[:req_indices.shape[0]] = np.where(
+                mask, slot_mapping, -1)
+        else:
+            block_table_indices = (req_indices * self.max_num_blocks_per_req +
+                                   positions // self.block_size)
+            block_numbers = self.block_table_np.ravel()[block_table_indices]
+            block_offsets = positions % self.block_size
+            np.add(block_numbers * self.block_size,
+                   block_offsets,
+                   out=self.slot_mapping_np[:req_indices.shape[0]])
 
     def commit_block_table(self, num_reqs: int) -> None:
         self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
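To make the interleaving concrete, here is a minimal standalone numpy sketch of the DCP branch above. The numbers are hypothetical (block_size = 4, dcp_world_size = 2, one request whose block table is [7, 9], positions 0..9), and it reuses a single block table for both ranks purely for illustration, whereas in vLLM each rank indexes its own block_table_np:

import numpy as np

block_size = 4
dcp_world_size = 2
block_table = np.array([7, 9])  # hypothetical block numbers for one request
positions = np.arange(10)       # token positions 0..9

# One virtual block covers dcp_world_size physical blocks' worth of tokens.
virtual_block_size = block_size * dcp_world_size

for dcp_rank in range(dcp_world_size):
    block_numbers = block_table[positions // virtual_block_size]
    virtual_block_offsets = positions % virtual_block_size
    # Token i is local to this rank iff i % dcp_world_size == dcp_rank.
    mask = virtual_block_offsets % dcp_world_size == dcp_rank
    block_offsets = virtual_block_offsets // dcp_world_size
    slots = np.where(mask, block_numbers * block_size + block_offsets, -1)
    print(dcp_rank, slots)
    # rank 0: [28 -1 29 -1 30 -1 31 -1 36 -1]
    # rank 1: [-1 28 -1 29 -1 30 -1 31 -1 36]

Each position receives a real slot on exactly one rank and -1 everywhere else, so every rank writes only the KV-cache entries it owns.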
@@ -128,9 +159,19 @@ class MultiGroupBlockTable:
     def __init__(self, max_num_reqs: int, max_model_len: int,
                  max_num_batched_tokens: int, pin_memory: bool,
                  device: torch.device, block_sizes: list[int]) -> None:
+        # Note(hc): each DCP rank only stores
+        # (max_model_len // dcp_world_size) tokens of KV cache, so the
+        # block_size used to compute max_num_blocks_per_req must be
+        # multiplied by dcp_world_size.
+        try:
+            dcp_world_size = get_dcp_group().world_size
+        except AssertionError:
+            # DCP might not be initialized in testing
+            dcp_world_size = 1
+
         self.block_tables = [
-            BlockTable(block_size, max_num_reqs, cdiv(max_model_len,
-                                                      block_size),
+            BlockTable(block_size, max_num_reqs,
+                       cdiv(max_model_len, block_size * dcp_world_size),
                        max_num_batched_tokens, pin_memory, device)
             for block_size in block_sizes
         ]
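As a sizing sanity check for the cdiv change above, with hypothetical numbers (max_model_len = 8192, block_size = 16, dcp_world_size = 4) each rank holds only every fourth token, so the per-request block-table width shrinks accordingly:

def cdiv(a: int, b: int) -> int:
    # Ceiling division; vllm.utils.cdiv behaves the same for positive ints.
    return (a + b - 1) // b

max_model_len, block_size, dcp_world_size = 8192, 16, 4
assert cdiv(max_model_len, block_size) == 512                    # without DCP
assert cdiv(max_model_len, block_size * dcp_world_size) == 128   # per DCP rank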