[Bugfix][Nixl] Fix kernel physical<>logical block_size issue (#28677)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -98,7 +98,9 @@ class BlockTable:
|
||||
return
|
||||
|
||||
if self.use_hybrid_blocks:
|
||||
block_ids = self._map_to_kernel_blocks(np.array(block_ids))
|
||||
block_ids = self.map_to_kernel_blocks(
|
||||
np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange
|
||||
)
|
||||
|
||||
num_blocks = len(block_ids)
|
||||
start = self.num_blocks_per_row[row_idx]
|
||||
@@ -188,7 +190,12 @@ class BlockTable:
|
||||
self.block_table.gpu.fill_(0)
|
||||
self.block_table.cpu.fill_(0)
|
||||
|
||||
def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray:
|
||||
@staticmethod
|
||||
def map_to_kernel_blocks(
|
||||
kv_manager_block_ids: np.ndarray,
|
||||
blocks_per_kv_block: int,
|
||||
kernel_block_arange: np.ndarray,
|
||||
) -> np.ndarray:
|
||||
"""Convert kv_manager_block_id IDs to kernel block IDs.
|
||||
|
||||
Example:
|
||||
@@ -203,12 +210,12 @@ class BlockTable:
|
||||
# kv_manager_block_id 1 → kernel block id [2, 3]
|
||||
# kv_manager_block_id 2 → kernel block id [4, 5]
|
||||
"""
|
||||
if not self.use_hybrid_blocks:
|
||||
if blocks_per_kv_block == 1:
|
||||
return kv_manager_block_ids
|
||||
|
||||
kernel_block_ids = (
|
||||
kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block
|
||||
+ self._kernel_block_arange
|
||||
kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
|
||||
+ kernel_block_arange
|
||||
)
|
||||
|
||||
return kernel_block_ids.reshape(-1)
|
||||
|
||||
Reference in New Issue
Block a user