Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00 (committed by GitHub)
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions
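
The change below is mechanical: yapf-era backslash continuations and aligned hanging indents give way to ruff format's Black-style parenthesized wrapping, and import sorting moves from isort into ruff's lint rules. A minimal standalone sketch of the pattern that repeats throughout the diff (the function and variable names here are illustrative, not taken from vLLM):

# Before (yapf era): backslash continuation.
def page_bytes_old(block_size, num_kv_heads, head_size, dtype_size):
    return 2 * block_size * num_kv_heads * head_size \
        * dtype_size

# After (ruff format): the expression is wrapped in parentheses, one operand per line.
def page_bytes_new(block_size, num_kv_heads, head_size, dtype_size):
    return (
        2
        * block_size
        * num_kv_heads
        * head_size
        * dtype_size
    )

assert page_bytes_old(16, 8, 128, 2) == page_bytes_new(16, 8, 128, 2) == 65536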


@@ -50,7 +50,8 @@ class KVCacheSpec:
         Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
         """
         assert all(spec == specs[0] for spec in specs[1:]), (
-            "All layers in the same KV cache group must be the same.")
+            "All layers in the same KV cache group must be the same."
+        )
         return copy.deepcopy(specs[0])
@@ -62,8 +63,13 @@ class AttentionSpec(KVCacheSpec):
     @property
     def page_size_bytes(self) -> int:
-        return 2 * self.block_size * self.num_kv_heads * self.head_size \
-            * get_dtype_size(self.dtype)
+        return (
+            2
+            * self.block_size
+            * self.num_kv_heads
+            * self.head_size
+            * get_dtype_size(self.dtype)
+        )
 @dataclass(frozen=True)
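
The reformatted return above is the per-block KV cache size: key and value planes (the factor of 2) times tokens per block, KV heads, head size, and bytes per element. A standalone sketch with illustrative numbers (not taken from the diff):

# Mirrors AttentionSpec.page_size_bytes with invented values.
block_size = 16      # tokens per KV cache block
num_kv_heads = 8
head_size = 128
dtype_size = 2       # bytes per element, e.g. fp16/bf16

page_size_bytes = 2 * block_size * num_kv_heads * head_size * dtype_size
print(page_size_bytes)  # 2 * 16 * 8 * 128 * 2 = 65536 bytes per block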
@@ -82,8 +88,7 @@ class FullAttentionSpec(AttentionSpec):
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         max_model_len = vllm_config.model_config.max_model_len
-        dcp_world_size = \
-            vllm_config.parallel_config.decode_context_parallel_size
+        dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
         # Note(hc): each dcp rank only need save
         # (max_model_len//dcp_world_size) tokens locally.
         if dcp_world_size > 1:
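
Per the Note(hc) comment, with decode context parallelism each rank only stores about max_model_len // dcp_world_size tokens of full-attention KV cache; the branch that applies this is cut off by the hunk. A rough standalone sketch with invented numbers:

from math import ceil

max_model_len = 32768
dcp_world_size = 4
block_size = 16
page_size_bytes = 65536  # per-block size from the sketch above

tokens_per_rank = max_model_len // dcp_world_size     # 8192 tokens held locally
blocks_per_rank = ceil(tokens_per_rank / block_size)  # 512 blocks
print(blocks_per_rank * page_size_bytes)              # 33554432 bytes (~32 MiB) per layer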
@@ -99,24 +104,30 @@ class FullAttentionSpec(AttentionSpec):
         else:
             raise ValueError(
                 "All attention layers in the same KV cache group must have the "
-                "same window size.")
+                "same window size."
+            )
     @classmethod
     def merge(cls, specs: list[Self]) -> Self:
         """
         Merge a list of FullAttentionSpec objects into a single
         FullAttentionSpec object.
         """
         assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
-            "All attention layers in the same KV cache group must be "
-            "FullAttentionSpec.")
+            "All attention layers in the same KV cache group must be FullAttentionSpec."
+        )
-        sliding_window = set(spec.sliding_window for spec in specs
-                             if spec.sliding_window is not None)
-        attention_chunk_size = set(spec.attention_chunk_size for spec in specs
-                                   if spec.attention_chunk_size is not None)
+        sliding_window = set(
+            spec.sliding_window for spec in specs if spec.sliding_window is not None
+        )
+        attention_chunk_size = set(
+            spec.attention_chunk_size
+            for spec in specs
+            if spec.attention_chunk_size is not None
+        )
         assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
-            "MLAAttentionSpec should be merged in MLAAttentionSpec.merge")
+            "MLAAttentionSpec should be merged in MLAAttentionSpec.merge"
+        )
         merged_spec = cls(
             block_size=specs[0].block_size,
             num_kv_heads=specs[0].num_kv_heads,
@@ -129,12 +140,14 @@ class FullAttentionSpec(AttentionSpec):
             for f in fields(AttentionSpec):
                 assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                     "All attention layers in the same KV cache group must have "
-                    "the same attention spec.")
-        assert (
-            (merged_spec.sliding_window is not None) +
-            (merged_spec.attention_chunk_size is not None) <= 1
-        ), ("Model with both sliding window layers and chunked local attention "
-            "layers is not supported.")
+                    "the same attention spec."
+                )
+        assert (merged_spec.sliding_window is not None) + (
+            merged_spec.attention_chunk_size is not None
+        ) <= 1, (
+            "Model with both sliding window layers and chunked local attention "
+            "layers is not supported."
+        )
         return merged_spec
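
The reshaped assertion relies on Python booleans being integers, so adding the two "is not None" checks counts how many of the mutually exclusive fields are set. A tiny standalone illustration:

# True == 1 and False == 0, so the sum counts how many optional fields are set.
sliding_window = 4096        # illustrative value
attention_chunk_size = None  # not set

num_set = (sliding_window is not None) + (attention_chunk_size is not None)
assert num_set <= 1, "sliding window and chunked local attention are mutually exclusive"
print(num_set)  # 1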
@@ -149,18 +162,23 @@ class MLAAttentionSpec(FullAttentionSpec):
             # See `vllm/v1/attention/backends/mla/flashmla_sparse.py`
             # for details.
             return self.block_size * 656
-        return self.block_size * self.num_kv_heads * self.head_size \
-            * get_dtype_size(self.dtype)
+        return (
+            self.block_size
+            * self.num_kv_heads
+            * self.head_size
+            * get_dtype_size(self.dtype)
+        )
     @classmethod
     def merge(cls, specs: list[Self]) -> Self:
         assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
-            "All attention layers in the same KV cache group must be "
-            "MLAAttentionSpec.")
+            "All attention layers in the same KV cache group must be MLAAttentionSpec."
+        )
         cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
         assert len(cache_dtype_str_set) == 1, (
             "All attention layers in the same KV cache group must use the same "
-            "quantization method.")
+            "quantization method."
+        )
         return cls(
             block_size=specs[0].block_size,
             num_kv_heads=specs[0].num_kv_heads,
@@ -176,15 +194,15 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         max_model_len = vllm_config.model_config.max_model_len
-        max_num_batched_tokens = (
-            vllm_config.scheduler_config.max_num_batched_tokens)
+        max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
         # During chunked prefill, we allocate KV cache for at most
         # `self.attention_chunk_size` computed tokens plus the newly scheduled
         # tokens. And we won't allocate KV cache for more than `max_model_len`
         # tokens.
-        num_tokens = min(self.attention_chunk_size + max_num_batched_tokens,
-                         max_model_len)
+        num_tokens = min(
+            self.attention_chunk_size + max_num_batched_tokens, max_model_len
+        )
         return cdiv(num_tokens, self.block_size) * self.page_size_bytes
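
In words: a chunked-local-attention layer never needs cache for more than one attention chunk plus one scheduler batch, capped at the model length. A standalone sketch with invented numbers:

from math import ceil

attention_chunk_size = 8192
max_num_batched_tokens = 2048
max_model_len = 32768
block_size = 16
page_size_bytes = 65536  # per-block size from the earlier sketch

num_tokens = min(attention_chunk_size + max_num_batched_tokens, max_model_len)
num_blocks = ceil(num_tokens / block_size)       # same as cdiv(num_tokens, block_size)
print(num_tokens, num_blocks * page_size_bytes)  # 10240 tokens, 41943040 bytes (~40 MiB)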
@@ -194,18 +212,19 @@ class SlidingWindowSpec(AttentionSpec):
     sliding_window: int
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
-        assert vllm_config.parallel_config.decode_context_parallel_size == 1, \
+        assert vllm_config.parallel_config.decode_context_parallel_size == 1, (
             "DCP not support sliding window."
+        )
         max_model_len = vllm_config.model_config.max_model_len
-        max_num_batched_tokens = (
-            vllm_config.scheduler_config.max_num_batched_tokens)
+        max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
         # During chunked prefill, we allocate KV cache for the last
         # `self.sliding_window-1` computed tokens plus the newly scheduled
         # tokens. And we won't allocate KV cache for more than `max_model_len`
         # tokens.
-        num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens,
-                         max_model_len)
+        num_tokens = min(
+            self.sliding_window - 1 + max_num_batched_tokens, max_model_len
+        )
         # +1 here because the sliding window may not start from the beginning
         # of the block. For example, if the block size is 4 and num_token
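
The truncated comment is motivating an extra block in the final count: the sliding window can straddle a block boundary, so one additional partially filled block may be needed. A standalone sketch with invented numbers (the extra block is applied as the comment describes; the exact code follows after the cut):

from math import ceil

sliding_window = 4096
max_num_batched_tokens = 2048
max_model_len = 32768
block_size = 16

num_tokens = min(sliding_window - 1 + max_num_batched_tokens, max_model_len)
num_blocks = ceil(num_tokens / block_size) + 1  # +1 for the unaligned window start
print(num_tokens, num_blocks)                   # 6143 tokens, 385 blocks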
@@ -226,7 +245,8 @@ class MambaSpec(KVCacheSpec):
     def page_size_bytes(self) -> int:
         page_size = sum(
             prod(shape) * get_dtype_size(dtype)
-            for (shape, dtype) in zip(self.shapes, self.dtypes))
+            for (shape, dtype) in zip(self.shapes, self.dtypes)
+        )
         if self.page_size_padded is not None:
             assert self.page_size_padded >= page_size
             return self.page_size_padded
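
A Mamba layer's page is simply the sum of its per-block state tensors, optionally padded up to a fixed size. A standalone arithmetic sketch with invented shapes and dtypes:

from math import prod

shapes = [(8, 64, 128), (8, 16)]  # e.g. conv state and SSM state shapes (invented)
dtype_sizes = [2, 4]              # bytes per element, e.g. bf16 and fp32
page_size_padded = 140000         # optional padding target (invented)

page_size = sum(prod(shape) * size for shape, size in zip(shapes, dtype_sizes))
print(page_size)                  # 8*64*128*2 + 8*16*4 = 131584
if page_size_padded is not None:
    assert page_size_padded >= page_size
    page_size = page_size_padded
print(page_size)                  # 140000 after padding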
@@ -239,7 +259,6 @@ class MambaSpec(KVCacheSpec):
 @dataclass(frozen=True)
 class EncoderOnlyAttentionSpec(AttentionSpec):
-
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         # Encoder-only layers do not need KV cache
         return 0
@@ -254,8 +273,7 @@ class CrossAttentionSpec(AttentionSpec):
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         # For cross-attention, we need to cache encoder states
         # Get encoder length (e.g., 1500 for Whisper).
-        max_encoder_len = vllm_config.scheduler_config.\
-            max_num_encoder_input_tokens
+        max_encoder_len = vllm_config.scheduler_config.max_num_encoder_input_tokens
         return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes
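
Using the Whisper figure from the comment, the cross-attention cache is sized once for the encoder output rather than growing with decoded tokens. Illustrative arithmetic (page size reused from the earlier sketch):

from math import ceil

max_encoder_len = 1500   # e.g. Whisper encoder output length, per the comment
block_size = 16
page_size_bytes = 65536

num_blocks = ceil(max_encoder_len / block_size)  # cdiv(1500, 16) = 94
print(num_blocks * page_size_bytes)              # 6160384 bytes (~5.9 MiB)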
@@ -267,18 +285,18 @@ class UniformTypeKVCacheSpecs(KVCacheSpec):
     sliding window attentions with different window sizes are not the same type
     and should not be merged into one UniformTypeKVCacheSpecs.
     """
     kv_cache_specs: dict[str, KVCacheSpec]
     @property
     def page_size_bytes(self) -> int:
-        return sum(spec.page_size_bytes
-                   for spec in self.kv_cache_specs.values())
+        return sum(spec.page_size_bytes for spec in self.kv_cache_specs.values())
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         max_num_pages = max(
-            cdiv(spec.max_memory_usage_bytes(vllm_config),
-                 spec.page_size_bytes)
-            for spec in self.kv_cache_specs.values())
+            cdiv(spec.max_memory_usage_bytes(vllm_config), spec.page_size_bytes)
+            for spec in self.kv_cache_specs.values()
+        )
         return max_num_pages * self.page_size_bytes
     @classmethod
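
UniformTypeKVCacheSpecs sizes a combined page as the sum of its members' pages and bounds memory by whichever member needs the most pages. A standalone sketch with invented per-layer numbers:

from math import ceil

# (page_size_bytes, max_memory_usage_bytes) per layer, values invented.
layer_specs = {
    "layers.0.attn": (65536, 33_554_432),
    "layers.1.attn": (65536, 16_777_216),
}

group_page_size = sum(page for page, _ in layer_specs.values())              # 131072
max_num_pages = max(ceil(mem / page) for page, mem in layer_specs.values())  # 512
print(group_page_size, max_num_pages * group_page_size)                      # 131072 67108864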
@@ -293,35 +311,38 @@ class UniformTypeKVCacheSpecs(KVCacheSpec):
         one_spec = next(iter(kv_cache_specs.values()))
         if isinstance(one_spec, FullAttentionSpec):
             return all(
-                isinstance(spec, FullAttentionSpec)
-                for spec in kv_cache_specs.values())
+                isinstance(spec, FullAttentionSpec) for spec in kv_cache_specs.values()
+            )
         elif isinstance(one_spec, CrossAttentionSpec):
             return all(
-                isinstance(spec, CrossAttentionSpec)
-                for spec in kv_cache_specs.values())
+                isinstance(spec, CrossAttentionSpec) for spec in kv_cache_specs.values()
+            )
         elif isinstance(one_spec, SlidingWindowSpec):
             return all(
                 isinstance(spec, SlidingWindowSpec)
                 and spec.sliding_window == one_spec.sliding_window
-                for spec in kv_cache_specs.values())
+                for spec in kv_cache_specs.values()
+            )
         elif isinstance(one_spec, ChunkedLocalAttentionSpec):
             return all(
                 isinstance(spec, ChunkedLocalAttentionSpec)
                 and spec.attention_chunk_size == one_spec.attention_chunk_size
-                for spec in kv_cache_specs.values())
+                for spec in kv_cache_specs.values()
+            )
         elif isinstance(one_spec, MambaSpec):
             return all(
-                isinstance(spec, MambaSpec) and spec.num_speculative_blocks ==
-                one_spec.num_speculative_blocks
-                for spec in kv_cache_specs.values())
+                isinstance(spec, MambaSpec)
+                and spec.num_speculative_blocks == one_spec.num_speculative_blocks
+                for spec in kv_cache_specs.values()
+            )
         else:
             # NOTE(Chen): Please add new branches for new KV cache spec types.
             raise NotImplementedError(
-                f"Unsupported KV cache spec type: {type(one_spec)}")
+                f"Unsupported KV cache spec type: {type(one_spec)}"
+            )
     @classmethod
-    def from_specs(cls, kv_cache_specs: dict[str,
-                                             KVCacheSpec]) -> Optional[Self]:
+    def from_specs(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> Optional[Self]:
         """
         Return a SameTypeKVCacheSpecs object if all layers have the same type
         of KV cache spec. Return None if not.
@@ -338,6 +359,7 @@ class KVCacheTensor:
"""
A class for specifying how the workers should initialize the KV cache.
"""
size: int # size of the KV cache tensor in bytes
shared_by: list[str] # layer names that share the same KV cache tensor
@@ -348,6 +370,7 @@ class KVCacheGroupSpec:
     Represents a group of model layers that share the same KV cache block table.
     These layers are regarded as one layer in the KV cache manager.
     """
+
     # The names of model layers in this group
     layer_names: list[str]
     # The KV cache spec of this manager layer
@@ -359,6 +382,7 @@ class KVCacheConfig:
"""
The KV cache configuration of a model.
"""
"""The number of KV cache blocks"""
num_blocks: int
"""How should model runner initialize the KV cache tensors for each layer"""