[DCP] Support dcp kv_cache interleave size > 1 (#26696)
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: Qiu <qiuchunshuo@huawei.com>
Co-authored-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
@@ -22,6 +22,7 @@ class BlockTable:
|
||||
pin_memory: bool,
|
||||
device: torch.device,
|
||||
kernel_block_size: int,
|
||||
+ dcp_kv_cache_interleave_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -86,6 +87,7 @@ class BlockTable:
|
||||
# DCP might not be initialized in testing
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
+ self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size
|
||||
|
||||
def append_row(
|
||||
self,
|
||||
@@ -144,9 +146,19 @@ class BlockTable:
|
||||
# Use virtual_block_size for mask calculation, which marks local
|
||||
# tokens.
|
||||
virtual_block_offsets = positions % virtual_block_size
|
||||
- mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
|
||||
+ mask = (
|
||||
+ virtual_block_offsets
|
||||
+ // self.dcp_kv_cache_interleave_size
|
||||
+ % self.dcp_world_size
|
||||
+ == self.dcp_rank
|
||||
+ )
|
||||
# Calculate local block_offsets
|
||||
- block_offsets = virtual_block_offsets // self.dcp_world_size
|
||||
+ block_offsets = (
|
||||
+ virtual_block_offsets
|
||||
+ // (self.dcp_world_size * self.dcp_kv_cache_interleave_size)
|
||||
+ * self.dcp_kv_cache_interleave_size
|
||||
+ + virtual_block_offsets % self.dcp_kv_cache_interleave_size
|
||||
+ )
|
||||
# Calculate slot_mapping
|
||||
slot_mapping = block_numbers * self.block_size + block_offsets
|
||||
# Write final slots, use -1 for not-local
|
||||
@@ -234,6 +246,7 @@ class MultiGroupBlockTable:
|
||||
block_sizes: list[int],
|
||||
kernel_block_sizes: list[int],
|
||||
num_speculative_tokens: int = 0,
|
||||
+ dcp_kv_cache_interleave_size: int = 1,
|
||||
) -> None:
|
||||
# Note(hc): each dcp rank only store
|
||||
# (max_model_len//dcp_world_size) tokens in kvcache,
|
||||
@@ -263,6 +276,7 @@ class MultiGroupBlockTable:
|
||||
pin_memory,
|
||||
device,
|
||||
kernel_block_size,
|
||||
+ dcp_kv_cache_interleave_size,
|
||||
)
|
||||
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user