Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -14,7 +14,6 @@ logger = init_logger(__name__)
class BlockTable:
def __init__(
self,
block_size: int,
@@ -31,13 +30,14 @@ class BlockTable:
self.pin_memory = pin_memory
self.device = device
self.block_table = self._make_buffer(max_num_reqs,
max_num_blocks_per_req,
dtype=torch.int32)
self.block_table = self._make_buffer(
max_num_reqs, max_num_blocks_per_req, dtype=torch.int32
)
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
self.slot_mapping = self._make_buffer(self.max_num_batched_tokens,
dtype=torch.int64)
self.slot_mapping = self._make_buffer(
self.max_num_batched_tokens, dtype=torch.int64
)
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
@@ -56,7 +56,7 @@ class BlockTable:
num_blocks = len(block_ids)
start = self.num_blocks_per_row[row_idx]
self.num_blocks_per_row[row_idx] += num_blocks
self.block_table.np[row_idx, start:start + num_blocks] = block_ids
self.block_table.np[row_idx, start : start + num_blocks] = block_ids
def add_row(self, block_ids: list[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0
@@ -73,8 +73,9 @@ class BlockTable:
self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src]
self.block_table.np[src_tgt] = self.block_table.np[tgt_src]
def compute_slot_mapping(self, req_indices: np.ndarray,
positions: np.ndarray) -> None:
def compute_slot_mapping(
self, req_indices: np.ndarray, positions: np.ndarray
) -> None:
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
@@ -89,8 +90,10 @@ class BlockTable:
# Use a "virtual block" which equals to world_size * block_size
# for block_table_indices calculation.
virtual_block_size = self.block_size * self.dcp_world_size
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions // virtual_block_size)
block_table_indices = (
req_indices * self.max_num_blocks_per_req
+ positions // virtual_block_size
)
block_numbers = self.block_table.np.ravel()[block_table_indices]
# Use virtual_block_size for mask calculation, which marks local
# tokens.
@@ -101,16 +104,20 @@ class BlockTable:
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
self.slot_mapping.np[:req_indices.shape[0]] = np.where(
mask, slot_mapping, -1)
self.slot_mapping.np[: req_indices.shape[0]] = np.where(
mask, slot_mapping, -1
)
else:
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions // self.block_size)
block_table_indices = (
req_indices * self.max_num_blocks_per_req + positions // self.block_size
)
block_numbers = self.block_table.np.ravel()[block_table_indices]
block_offsets = positions % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping.np[:req_indices.shape[0]])
np.add(
block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping.np[: req_indices.shape[0]],
)
def commit_block_table(self, num_reqs: int) -> None:
self.block_table.copy_to_gpu(num_reqs)
@@ -134,25 +141,27 @@ class BlockTable:
"""Returns the numpy array of the block table."""
return self.block_table.np
def _make_buffer(self, *size: Union[int, torch.SymInt],
dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(*size,
dtype=dtype,
device=self.device,
pin_memory=self.pin_memory)
def _make_buffer(
self, *size: Union[int, torch.SymInt], dtype: torch.dtype
) -> CpuGpuBuffer:
return CpuGpuBuffer(
*size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
)
class MultiGroupBlockTable:
"""The BlockTables for each KV cache group."""
def __init__(self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
block_sizes: list[int],
num_speculative_tokens: int = 0) -> None:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
block_sizes: list[int],
num_speculative_tokens: int = 0,
) -> None:
# Note(hc): each dcp rank only store
# (max_model_len//dcp_world_size) tokens in kvcache,
# so the block_size which used for calc max_num_blocks_per_req
@@ -165,14 +174,20 @@ class MultiGroupBlockTable:
self.block_tables = [
BlockTable(
block_size, max_num_reqs,
max(cdiv(max_model_len, block_size * dcp_world_size),
1 + num_speculative_tokens), max_num_batched_tokens,
pin_memory, device) for block_size in block_sizes
block_size,
max_num_reqs,
max(
cdiv(max_model_len, block_size * dcp_world_size),
1 + num_speculative_tokens,
),
max_num_batched_tokens,
pin_memory,
device,
)
for block_size in block_sizes
]
def append_row(self, block_ids: tuple[list[int], ...],
row_idx: int) -> None:
def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.append_row(block_ids[i], row_idx)
@@ -188,8 +203,9 @@ class MultiGroupBlockTable:
for block_table in self.block_tables:
block_table.swap_row(src, tgt)
def compute_slot_mapping(self, req_indices: np.ndarray,
positions: np.ndarray) -> None:
def compute_slot_mapping(
self, req_indices: np.ndarray, positions: np.ndarray
) -> None:
for block_table in self.block_tables:
block_table.compute_slot_mapping(req_indices, positions)