[Hybrid]: Decouple Kernel Block Size from KV Page Size (#24486)
Signed-off-by: lizhiyuan <uniartisan2017@gmail.com>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
@@ -22,22 +22,64 @@ class BlockTable:
         max_num_batched_tokens: int,
         pin_memory: bool,
         device: torch.device,
+        kernel_block_size: int,
     ):
-        self.block_size = block_size
+        """
+        Args:
+            block_size: Block size used for KV cache memory allocation.
+            max_num_reqs: Maximum number of concurrent requests supported.
+            max_num_blocks_per_req: Maximum number of blocks per request.
+            max_num_batched_tokens: Maximum number of tokens in a batch.
+            pin_memory: Whether to pin memory for faster GPU transfers.
+            device: Target device for the block table.
+            kernel_block_size: The block_size of the underlying attention
+                kernel. Will be the same as `block_size` if `block_size` is
+                supported by the attention kernel.
+        """
         self.max_num_reqs = max_num_reqs
         self.max_num_blocks_per_req = max_num_blocks_per_req
         self.max_num_batched_tokens = max_num_batched_tokens
         self.pin_memory = pin_memory
         self.device = device
+
+        if kernel_block_size == block_size:
+            # Standard case: allocation and computation use the same block
+            # size. No block splitting is needed; the mapping is direct.
+            self.block_size = block_size
+            self.blocks_per_kv_block = 1
+            self.use_hybrid_blocks = False
+        else:
+            # Hybrid case: the allocation block size differs from the kernel
+            # block size, so memory blocks are subdivided to match the kernel.
+            # Example: 32-token memory blocks with 16-token kernel blocks
+            # → each memory block corresponds to 2 kernel blocks.
+            if block_size % kernel_block_size != 0:
+                raise ValueError(
+                    f"kernel_block_size {kernel_block_size} must divide "
+                    f"kv_manager_block_size {block_size} evenly"
+                )
+
+            self.block_size = kernel_block_size
+            self.blocks_per_kv_block = block_size // kernel_block_size
+            self.use_hybrid_blocks = True
+
+        self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block
+
         self.block_table = self._make_buffer(
-            max_num_reqs, max_num_blocks_per_req, dtype=torch.int32
+            self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
         )
         self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

         self.slot_mapping = self._make_buffer(
             self.max_num_batched_tokens, dtype=torch.int64
         )
+
+        if self.use_hybrid_blocks:
+            self._kernel_block_arange = np.arange(0, self.blocks_per_kv_block).reshape(
+                1, -1
+            )
+        else:
+            self._kernel_block_arange = None

         try:
             self.dcp_world_size = get_dcp_group().world_size
             self.dcp_rank = get_dcp_group().rank_in_group
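The standard/hybrid branch above reduces to a small amount of arithmetic. A minimal standalone sketch, not part of the diff (the helper name `split_block_size` is invented for illustration):

```python
def split_block_size(block_size: int, kernel_block_size: int) -> tuple[int, int, bool]:
    # Returns (effective block_size, blocks_per_kv_block, use_hybrid_blocks),
    # mirroring the constructor branch in the hunk above.
    if kernel_block_size == block_size:
        # Standard case: one kernel block per KV-manager block.
        return block_size, 1, False
    if block_size % kernel_block_size != 0:
        raise ValueError(
            f"kernel_block_size {kernel_block_size} must divide "
            f"kv_manager_block_size {block_size} evenly"
        )
    # Hybrid case: each allocated page is addressed as several kernel blocks.
    return kernel_block_size, block_size // kernel_block_size, True


print(split_block_size(32, 16))  # (16, 2, True):  2 kernel blocks per page
print(split_block_size(16, 16))  # (16, 1, False): direct mapping
print(split_block_size(48, 16))  # (16, 3, True):  any exact multiple works
```

Note that `max_num_blocks_per_req` is also scaled by `blocks_per_kv_block`, so a request that needed N allocator pages now occupies N * blocks_per_kv_block entries in its block-table row.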
@@ -53,6 +95,10 @@ class BlockTable:
     ) -> None:
         if not block_ids:
             return
+
+        if self.use_hybrid_blocks:
+            block_ids = self._map_to_kernel_blocks(np.array(block_ids))
+
         num_blocks = len(block_ids)
         start = self.num_blocks_per_row[row_idx]
         self.num_blocks_per_row[row_idx] += num_blocks
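A hedged numpy sketch, not from the diff, of what the new hybrid path in `append_row` does to incoming IDs before the existing bookkeeping runs; the row is filled with kernel block IDs, so `num_blocks_per_row` grows by the expanded count:

```python
import numpy as np

blocks_per_kv_block = 2  # e.g. 32-token pages served to a 16-token kernel
kernel_block_arange = np.arange(blocks_per_kv_block).reshape(1, -1)

# Two KV-manager block IDs handed over by the allocator for this request.
block_ids = [7, 9]
expanded = (
    np.array(block_ids).reshape(-1, 1) * blocks_per_kv_block + kernel_block_arange
).reshape(-1)

print(expanded)       # [14 15 18 19]
print(len(expanded))  # 4 -> the row advances by 4 entries, not 2
```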
@@ -94,6 +140,7 @@ class BlockTable:
             req_indices * self.max_num_blocks_per_req
             + positions // virtual_block_size
         )
+
         block_numbers = self.block_table.np.ravel()[block_table_indices]
         # Use virtual_block_size for mask calculation, which marks local
         # tokens.
@@ -111,6 +158,7 @@ class BlockTable:
         block_table_indices = (
             req_indices * self.max_num_blocks_per_req + positions // self.block_size
         )
+
         block_numbers = self.block_table.np.ravel()[block_table_indices]
         block_offsets = positions % self.block_size
         np.add(
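Both slot-mapping hunks use the same arithmetic: a token position is split into a block index within the row (`positions // block_size`) and an offset inside that block (`positions % block_size`), and the physical slot is then `block_number * block_size + offset`. A toy numpy sketch under assumed values (one request at row 0, a hand-made flat block table):

```python
import numpy as np

block_size = 16
max_num_blocks_per_req = 4

# Flattened (num_reqs, max_num_blocks_per_req) block table; row 0 owns
# physical kernel blocks 10 and 11.
block_table = np.array([[10, 11, 0, 0]], dtype=np.int32)

positions = np.array([0, 15, 16, 20])    # token positions within the request
req_indices = np.zeros_like(positions)   # all tokens belong to request 0

block_table_indices = req_indices * max_num_blocks_per_req + positions // block_size
block_numbers = block_table.ravel()[block_table_indices]
block_offsets = positions % block_size
slots = block_numbers * block_size + block_offsets

print(slots)  # [160 175 176 180]
```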
@@ -129,6 +177,31 @@ class BlockTable:
         self.block_table.gpu.fill_(0)
         self.block_table.cpu.fill_(0)

+    def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray:
+        """Convert KV manager block IDs to kernel block IDs.
+
+        Example:
+            # KV manager block size: 32 tokens
+            # Kernel block size: 16 tokens
+            # blocks_per_kv_block = 2
+            >>> kv_manager_block_ids = np.array([0, 1, 2])
+            >>> Result: [0, 1, 2, 3, 4, 5]
+
+            # Each kv_manager_block_id maps to 2 kernel block IDs:
+            # kv_manager_block_id 0 → kernel block IDs [0, 1]
+            # kv_manager_block_id 1 → kernel block IDs [2, 3]
+            # kv_manager_block_id 2 → kernel block IDs [4, 5]
+        """
+        if not self.use_hybrid_blocks:
+            return kv_manager_block_ids
+
+        kernel_block_ids = (
+            kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block
+            + self._kernel_block_arange
+        )
+
+        return kernel_block_ids.reshape(-1)
+
     def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
         """Returns the device tensor of the block table."""
         return self.block_table.gpu[:num_reqs]
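Why `kv_id * blocks_per_kv_block + arange` is the right mapping: a token addressed through the original 32-token KV-manager pages and through the derived 16-token kernel blocks must land on the same physical slot. A small consistency check, with block IDs invented for illustration:

```python
import numpy as np

kv_block_size, kernel_block_size = 32, 16
blocks_per_kv_block = kv_block_size // kernel_block_size  # 2

kv_row = np.array([5, 2])  # KV-manager block IDs for one request
kernel_row = (
    kv_row.reshape(-1, 1) * blocks_per_kv_block + np.arange(blocks_per_kv_block)
).reshape(-1)
print(kernel_row)  # [10 11  4  5]

for pos in (0, 20, 40, 63):
    slot_kv = kv_row[pos // kv_block_size] * kv_block_size + pos % kv_block_size
    slot_kernel = (
        kernel_row[pos // kernel_block_size] * kernel_block_size
        + pos % kernel_block_size
    )
    assert slot_kv == slot_kernel, (pos, slot_kv, slot_kernel)
print("KV-manager and kernel addressing agree on every position")
```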
@@ -160,6 +233,7 @@ class MultiGroupBlockTable:
         pin_memory: bool,
         device: torch.device,
         block_sizes: list[int],
+        kernel_block_sizes: list[int],
         num_speculative_tokens: int = 0,
     ) -> None:
         # Note(hc): each dcp rank only store
@@ -172,6 +246,12 @@ class MultiGroupBlockTable:
             # DCP might not be initialized in testing
            dcp_world_size = 1

+        if len(kernel_block_sizes) != len(block_sizes):
+            raise ValueError(
+                f"kernel_block_sizes length ({len(kernel_block_sizes)}) "
+                f"must match block_sizes length ({len(block_sizes)})"
+            )
+
         self.block_tables = [
             BlockTable(
                 block_size,
@@ -183,8 +263,9 @@ class MultiGroupBlockTable:
                 max_num_batched_tokens,
                 pin_memory,
                 device,
+                kernel_block_size,
             )
-            for block_size in block_sizes
+            for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
         ]

     def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
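Each KV-cache group now carries its own (block_size, kernel_block_size) pair; the length check and the `zip` above keep the two lists aligned. A plain-Python sketch of that pairing (no vLLM imports; the group values are invented):

```python
block_sizes = [32, 64]         # KV-manager page size per KV-cache group
kernel_block_sizes = [16, 64]  # block size each group's attention kernel supports

if len(kernel_block_sizes) != len(block_sizes):
    raise ValueError(
        f"kernel_block_sizes length ({len(kernel_block_sizes)}) "
        f"must match block_sizes length ({len(block_sizes)})"
    )

for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes):
    blocks_per_kv_block = block_size // kernel_block_size
    hybrid = blocks_per_kv_block != 1
    print(f"group: page={block_size}, kernel={kernel_block_size}, "
          f"splits={blocks_per_kv_block}, hybrid={hybrid}")
# group: page=32, kernel=16, splits=2, hybrid=True
# group: page=64, kernel=64, splits=1, hybrid=False
```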