[Misc] Move GPUModelRunner.prepare_kernel_block_sizes to utils (#35400)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -38,7 +38,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
from vllm.v1.worker.utils import AttentionGroup, select_common_block_size
|
||||
|
||||
BLOCK_SIZE = 16
|
||||
NUM_BLOCKS = 10
|
||||
@@ -209,7 +209,7 @@ def test_select_common_block_size_prefers_manager_block_size():
|
||||
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
|
||||
]
|
||||
|
||||
selected_size = GPUModelRunner.select_common_block_size(128, attn_groups)
|
||||
selected_size = select_common_block_size(128, attn_groups)
|
||||
assert selected_size == 128
|
||||
|
||||
|
||||
@@ -221,7 +221,7 @@ def test_select_common_block_size_uses_largest_shared_int():
|
||||
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
|
||||
]
|
||||
|
||||
selected_size = GPUModelRunner.select_common_block_size(256, attn_groups)
|
||||
selected_size = select_common_block_size(256, attn_groups)
|
||||
assert selected_size == 64
|
||||
|
||||
|
||||
@@ -234,7 +234,7 @@ def test_select_common_block_size_no_valid_option():
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
GPUModelRunner.select_common_block_size(48, attn_groups)
|
||||
select_common_block_size(48, attn_groups)
|
||||
|
||||
|
||||
def test_update_states_new_request(model_runner, dist_init):
|
||||
|
||||
@@ -115,7 +115,6 @@ from vllm.v1.attention.backend import (
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadataBuilder
|
||||
@@ -189,6 +188,7 @@ from .utils import (
|
||||
AttentionGroup,
|
||||
add_kv_sharing_layers_to_kv_cache_groups,
|
||||
bind_kv_cache,
|
||||
prepare_kernel_block_sizes,
|
||||
sanity_check_mm_encoder_outputs,
|
||||
)
|
||||
|
||||
@@ -5678,78 +5678,6 @@ class GPUModelRunner(
|
||||
return
|
||||
self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds) # type: ignore[assignment]
|
||||
|
||||
@staticmethod
|
||||
def select_common_block_size(
|
||||
kv_manager_block_size: int, attn_groups: list[AttentionGroup]
|
||||
) -> int:
|
||||
"""
|
||||
Select a block size that is supported by all backends and is a factor of
|
||||
kv_manager_block_size.
|
||||
|
||||
If kv_manager_block_size is supported by all backends, return it directly.
|
||||
Otherwise, return the max supported size.
|
||||
|
||||
Args:
|
||||
kv_manager_block_size: Block size of KV cache
|
||||
attn_groups: List of attention groups
|
||||
|
||||
Returns:
|
||||
The selected block size
|
||||
|
||||
Raises:
|
||||
ValueError: If no valid block size found
|
||||
"""
|
||||
|
||||
def block_size_is_supported(
|
||||
backends: list[type[AttentionBackend]], block_size: int
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the block size is supported by all backends.
|
||||
"""
|
||||
for backend in backends:
|
||||
is_supported = False
|
||||
for supported_size in backend.get_supported_kernel_block_sizes():
|
||||
if isinstance(supported_size, int):
|
||||
if block_size == supported_size:
|
||||
is_supported = True
|
||||
elif isinstance(supported_size, MultipleOf):
|
||||
if block_size % supported_size.base == 0:
|
||||
is_supported = True
|
||||
else:
|
||||
raise ValueError(f"Unknown supported size: {supported_size}")
|
||||
if not is_supported:
|
||||
return False
|
||||
return True
|
||||
|
||||
backends = [group.backend for group in attn_groups]
|
||||
|
||||
# Case 1: if the block_size of kv cache manager is supported by all backends,
|
||||
# return it directly
|
||||
if block_size_is_supported(backends, kv_manager_block_size):
|
||||
return kv_manager_block_size
|
||||
|
||||
# Case 2: otherwise, the block_size must be an `int`-format supported size of
|
||||
# at least one backend. Iterate over all `int`-format supported sizes in
|
||||
# descending order and return the first one that is supported by all backends.
|
||||
# Simple proof:
|
||||
# If the supported size b is in MultipleOf(x_i) format for all attention
|
||||
# backends i, and b a factor of kv_manager_block_size, then
|
||||
# kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will
|
||||
# return kv_manager_block_size in case 1.
|
||||
all_int_supported_sizes = set(
|
||||
supported_size
|
||||
for backend in backends
|
||||
for supported_size in backend.get_supported_kernel_block_sizes()
|
||||
if isinstance(supported_size, int)
|
||||
)
|
||||
|
||||
for supported_size in sorted(all_int_supported_sizes, reverse=True):
|
||||
if kv_manager_block_size % supported_size != 0:
|
||||
continue
|
||||
if block_size_is_supported(backends, supported_size):
|
||||
return supported_size
|
||||
raise ValueError(f"No common block size for {kv_manager_block_size}. ")
|
||||
|
||||
def may_reinitialize_input_batch(
|
||||
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
|
||||
) -> None:
|
||||
@@ -5846,49 +5774,6 @@ class GPUModelRunner(
|
||||
for attn_groups in self.attn_groups:
|
||||
yield from attn_groups
|
||||
|
||||
def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[int]:
|
||||
"""
|
||||
Generate kernel_block_sizes that matches each block_size.
|
||||
|
||||
For attention backends that support virtual block splitting,
|
||||
use the supported block sizes from the backend.
|
||||
For other backends (like Mamba), use the same block size (no splitting).
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache configuration.
|
||||
|
||||
Returns:
|
||||
list[int]: List of kernel block sizes for each cache group.
|
||||
"""
|
||||
kernel_block_sizes = []
|
||||
for kv_cache_gid, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
|
||||
kv_cache_spec = kv_cache_group.kv_cache_spec
|
||||
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
|
||||
# All layers in the UniformTypeKVCacheSpecs have the same type,
|
||||
# Pick an arbitrary one to dispatch.
|
||||
kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values()))
|
||||
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
|
||||
continue
|
||||
elif isinstance(kv_cache_spec, AttentionSpec):
|
||||
# This is an attention backend that supports virtual
|
||||
# block splitting. Get the supported block sizes from
|
||||
# all backends in the group.
|
||||
attn_groups = self.attn_groups[kv_cache_gid]
|
||||
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
|
||||
selected_kernel_size = self.select_common_block_size(
|
||||
kv_manager_block_size, attn_groups
|
||||
)
|
||||
kernel_block_sizes.append(selected_kernel_size)
|
||||
elif isinstance(kv_cache_spec, MambaSpec):
|
||||
# This is likely Mamba or other non-attention cache,
|
||||
# no splitting.
|
||||
kernel_block_sizes.append(kv_cache_spec.block_size)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"unknown kv cache spec {kv_cache_group.kv_cache_spec}"
|
||||
)
|
||||
return kernel_block_sizes
|
||||
|
||||
def _reshape_kv_cache_tensors(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
@@ -6120,7 +6005,9 @@ class GPUModelRunner(
|
||||
# backends for that group only supports block_size 64, we will return
|
||||
# kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
|
||||
# tokens each.
|
||||
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
|
||||
kernel_block_sizes = prepare_kernel_block_sizes(
|
||||
kv_cache_config, self.attn_groups
|
||||
)
|
||||
|
||||
# create metadata builders
|
||||
self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes)
|
||||
|
||||
@@ -13,8 +13,20 @@ from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.mem_utils import MemorySnapshot, format_gib
|
||||
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadataBuilder
|
||||
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionMetadataBuilder,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
AttentionSpec,
|
||||
EncoderOnlyAttentionSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheGroupSpec,
|
||||
KVCacheSpec,
|
||||
MambaSpec,
|
||||
UniformTypeKVCacheSpecs,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -59,6 +71,119 @@ class AttentionGroup:
|
||||
return self.metadata_builders[ubatch_id]
|
||||
|
||||
|
||||
def select_common_block_size(
|
||||
kv_manager_block_size: int, attn_groups: list[AttentionGroup]
|
||||
) -> int:
|
||||
"""
|
||||
Select a block size that is supported by all backends and is a factor of
|
||||
kv_manager_block_size.
|
||||
|
||||
If kv_manager_block_size is supported by all backends, return it directly.
|
||||
Otherwise, return the max supported size.
|
||||
|
||||
Args:
|
||||
kv_manager_block_size: Block size of KV cache.
|
||||
attn_groups: List of attention groups.
|
||||
|
||||
Returns:
|
||||
The selected block size.
|
||||
|
||||
Raises:
|
||||
ValueError: If no valid block size found.
|
||||
"""
|
||||
|
||||
def block_size_is_supported(
|
||||
backends: list[type[AttentionBackend]], block_size: int
|
||||
) -> bool:
|
||||
"""Check if the block size is supported by all backends."""
|
||||
for backend in backends:
|
||||
is_supported = False
|
||||
for supported_size in backend.get_supported_kernel_block_sizes():
|
||||
if isinstance(supported_size, int):
|
||||
if block_size == supported_size:
|
||||
is_supported = True
|
||||
elif isinstance(supported_size, MultipleOf):
|
||||
if block_size % supported_size.base == 0:
|
||||
is_supported = True
|
||||
else:
|
||||
raise ValueError(f"Unknown supported size: {supported_size}")
|
||||
if not is_supported:
|
||||
return False
|
||||
return True
|
||||
|
||||
backends = [group.backend for group in attn_groups]
|
||||
|
||||
# Case 1: if the block_size of kv cache manager is supported by all backends,
|
||||
# return it directly.
|
||||
if block_size_is_supported(backends, kv_manager_block_size):
|
||||
return kv_manager_block_size
|
||||
|
||||
# Case 2: otherwise, the block_size must be an `int`-format supported size of
|
||||
# at least one backend. Iterate over all `int`-format supported sizes in
|
||||
# descending order and return the first one that is supported by all backends.
|
||||
# Simple proof:
|
||||
# If the supported size b is in MultipleOf(x_i) format for all attention
|
||||
# backends i, and b a factor of kv_manager_block_size, then
|
||||
# kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will
|
||||
# return kv_manager_block_size in case 1.
|
||||
all_int_supported_sizes = set(
|
||||
supported_size
|
||||
for backend in backends
|
||||
for supported_size in backend.get_supported_kernel_block_sizes()
|
||||
if isinstance(supported_size, int)
|
||||
)
|
||||
|
||||
for supported_size in sorted(all_int_supported_sizes, reverse=True):
|
||||
if kv_manager_block_size % supported_size != 0:
|
||||
continue
|
||||
if block_size_is_supported(backends, supported_size):
|
||||
return supported_size
|
||||
raise ValueError(f"No common block size for {kv_manager_block_size}. ")
|
||||
|
||||
|
||||
def prepare_kernel_block_sizes(
|
||||
kv_cache_config: KVCacheConfig, attn_groups: list[list[AttentionGroup]]
|
||||
) -> list[int]:
|
||||
"""
|
||||
Generate kernel_block_sizes that matches each block_size.
|
||||
|
||||
For attention backends that support virtual block splitting,
|
||||
use the supported block sizes from the backend.
|
||||
For other backends (like Mamba), use the same block size (no splitting).
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache configuration.
|
||||
attn_groups: Attention groups indexed by KV cache group id.
|
||||
|
||||
Returns:
|
||||
List of kernel block sizes for each cache group.
|
||||
"""
|
||||
kernel_block_sizes = []
|
||||
for kv_cache_gid, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
|
||||
kv_cache_spec = kv_cache_group.kv_cache_spec
|
||||
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
|
||||
# All layers in the UniformTypeKVCacheSpecs have the same type,
|
||||
# pick an arbitrary one to dispatch.
|
||||
kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values()))
|
||||
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
|
||||
continue
|
||||
if isinstance(kv_cache_spec, AttentionSpec):
|
||||
# This is an attention backend that supports virtual block splitting.
|
||||
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
|
||||
selected_kernel_size = select_common_block_size(
|
||||
kv_manager_block_size, attn_groups[kv_cache_gid]
|
||||
)
|
||||
kernel_block_sizes.append(selected_kernel_size)
|
||||
elif isinstance(kv_cache_spec, MambaSpec):
|
||||
# This is likely Mamba or other non-attention cache, no splitting.
|
||||
kernel_block_sizes.append(kv_cache_spec.block_size)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"unknown kv cache spec {kv_cache_group.kv_cache_spec}"
|
||||
)
|
||||
return kernel_block_sizes
|
||||
|
||||
|
||||
def sanity_check_mm_encoder_outputs(
|
||||
mm_embeddings: MultiModalEmbeddings,
|
||||
expected_num_items: int,
|
||||
|
||||
Reference in New Issue
Block a user