Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -14,7 +14,6 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BlockTable:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_size: int,
|
||||
@@ -31,13 +30,14 @@ class BlockTable:
|
||||
self.pin_memory = pin_memory
|
||||
self.device = device
|
||||
|
||||
self.block_table = self._make_buffer(max_num_reqs,
|
||||
max_num_blocks_per_req,
|
||||
dtype=torch.int32)
|
||||
self.block_table = self._make_buffer(
|
||||
max_num_reqs, max_num_blocks_per_req, dtype=torch.int32
|
||||
)
|
||||
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
|
||||
self.slot_mapping = self._make_buffer(self.max_num_batched_tokens,
|
||||
dtype=torch.int64)
|
||||
self.slot_mapping = self._make_buffer(
|
||||
self.max_num_batched_tokens, dtype=torch.int64
|
||||
)
|
||||
try:
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
@@ -56,7 +56,7 @@ class BlockTable:
|
||||
num_blocks = len(block_ids)
|
||||
start = self.num_blocks_per_row[row_idx]
|
||||
self.num_blocks_per_row[row_idx] += num_blocks
|
||||
self.block_table.np[row_idx, start:start + num_blocks] = block_ids
|
||||
self.block_table.np[row_idx, start : start + num_blocks] = block_ids
|
||||
|
||||
def add_row(self, block_ids: list[int], row_idx: int) -> None:
|
||||
self.num_blocks_per_row[row_idx] = 0
|
||||
@@ -73,8 +73,9 @@ class BlockTable:
|
||||
self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src]
|
||||
self.block_table.np[src_tgt] = self.block_table.np[tgt_src]
|
||||
|
||||
def compute_slot_mapping(self, req_indices: np.ndarray,
|
||||
positions: np.ndarray) -> None:
|
||||
def compute_slot_mapping(
|
||||
self, req_indices: np.ndarray, positions: np.ndarray
|
||||
) -> None:
|
||||
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
|
||||
# where K is the max_num_blocks_per_req and the block size is 2.
|
||||
@@ -89,8 +90,10 @@ class BlockTable:
|
||||
# Use a "virtual block" whose size equals world_size * block_size
|
||||
# for block_table_indices calculation.
|
||||
virtual_block_size = self.block_size * self.dcp_world_size
|
||||
block_table_indices = (req_indices * self.max_num_blocks_per_req +
|
||||
positions // virtual_block_size)
|
||||
block_table_indices = (
|
||||
req_indices * self.max_num_blocks_per_req
|
||||
+ positions // virtual_block_size
|
||||
)
|
||||
block_numbers = self.block_table.np.ravel()[block_table_indices]
|
||||
# Use virtual_block_size for mask calculation, which marks local
|
||||
# tokens.
|
||||
@@ -101,16 +104,20 @@ class BlockTable:
|
||||
# Calculate slot_mapping
|
||||
slot_mapping = block_numbers * self.block_size + block_offsets
|
||||
# Write final slots, use -1 for not-local
|
||||
self.slot_mapping.np[:req_indices.shape[0]] = np.where(
|
||||
mask, slot_mapping, -1)
|
||||
self.slot_mapping.np[: req_indices.shape[0]] = np.where(
|
||||
mask, slot_mapping, -1
|
||||
)
|
||||
else:
|
||||
block_table_indices = (req_indices * self.max_num_blocks_per_req +
|
||||
positions // self.block_size)
|
||||
block_table_indices = (
|
||||
req_indices * self.max_num_blocks_per_req + positions // self.block_size
|
||||
)
|
||||
block_numbers = self.block_table.np.ravel()[block_table_indices]
|
||||
block_offsets = positions % self.block_size
|
||||
np.add(block_numbers * self.block_size,
|
||||
block_offsets,
|
||||
out=self.slot_mapping.np[:req_indices.shape[0]])
|
||||
np.add(
|
||||
block_numbers * self.block_size,
|
||||
block_offsets,
|
||||
out=self.slot_mapping.np[: req_indices.shape[0]],
|
||||
)
|
||||
|
||||
def commit_block_table(self, num_reqs: int) -> None:
|
||||
self.block_table.copy_to_gpu(num_reqs)
|
||||
@@ -134,25 +141,27 @@ class BlockTable:
|
||||
"""Returns the numpy array of the block table."""
|
||||
return self.block_table.np
|
||||
|
||||
def _make_buffer(self, *size: Union[int, torch.SymInt],
|
||||
dtype: torch.dtype) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(*size,
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory)
|
||||
def _make_buffer(
|
||||
self, *size: Union[int, torch.SymInt], dtype: torch.dtype
|
||||
) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(
|
||||
*size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
|
||||
)
|
||||
|
||||
|
||||
class MultiGroupBlockTable:
|
||||
"""The BlockTables for each KV cache group."""
|
||||
|
||||
def __init__(self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_batched_tokens: int,
|
||||
pin_memory: bool,
|
||||
device: torch.device,
|
||||
block_sizes: list[int],
|
||||
num_speculative_tokens: int = 0) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_batched_tokens: int,
|
||||
pin_memory: bool,
|
||||
device: torch.device,
|
||||
block_sizes: list[int],
|
||||
num_speculative_tokens: int = 0,
|
||||
) -> None:
|
||||
# Note(hc): each dcp rank only stores
|
||||
# (max_model_len//dcp_world_size) tokens in kvcache,
|
||||
# so the block_size which is used to calculate max_num_blocks_per_req
|
||||
@@ -165,14 +174,20 @@ class MultiGroupBlockTable:
|
||||
|
||||
self.block_tables = [
|
||||
BlockTable(
|
||||
block_size, max_num_reqs,
|
||||
max(cdiv(max_model_len, block_size * dcp_world_size),
|
||||
1 + num_speculative_tokens), max_num_batched_tokens,
|
||||
pin_memory, device) for block_size in block_sizes
|
||||
block_size,
|
||||
max_num_reqs,
|
||||
max(
|
||||
cdiv(max_model_len, block_size * dcp_world_size),
|
||||
1 + num_speculative_tokens,
|
||||
),
|
||||
max_num_batched_tokens,
|
||||
pin_memory,
|
||||
device,
|
||||
)
|
||||
for block_size in block_sizes
|
||||
]
|
||||
|
||||
def append_row(self, block_ids: tuple[list[int], ...],
|
||||
row_idx: int) -> None:
|
||||
def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
|
||||
for i, block_table in enumerate(self.block_tables):
|
||||
block_table.append_row(block_ids[i], row_idx)
|
||||
|
||||
@@ -188,8 +203,9 @@ class MultiGroupBlockTable:
|
||||
for block_table in self.block_tables:
|
||||
block_table.swap_row(src, tgt)
|
||||
|
||||
def compute_slot_mapping(self, req_indices: np.ndarray,
|
||||
positions: np.ndarray) -> None:
|
||||
def compute_slot_mapping(
|
||||
self, req_indices: np.ndarray, positions: np.ndarray
|
||||
) -> None:
|
||||
for block_table in self.block_tables:
|
||||
block_table.compute_slot_mapping(req_indices, positions)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user