diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 0fbd6605a..d7695027a 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -851,6 +851,7 @@ def test_hybrid_attention_mamba_tensor_shapes(): vllm_ctx = vllm_config.compilation_config.static_forward_context runner = GPUModelRunner(vllm_config, DEVICE) + current_platform.update_block_size_for_backend(vllm_config) kv_cache_spec = runner.get_kv_cache_spec() available_memory = 5 * GiB_bytes @@ -1306,6 +1307,7 @@ def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks(): assert fwd_context is not None runner = GPUModelRunner(vllm_config, DEVICE) + current_platform.update_block_size_for_backend(vllm_config) kv_cache_spec = runner.get_kv_cache_spec() available_memory = 5 * GiB_bytes diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 1cadb4318..49c8868e7 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -38,6 +38,8 @@ class CacheConfig: Accepts None (meaning "use default"). After construction, always int.""" user_specified_block_size: bool = field(default=False, init=False) """Whether block_size was explicitly provided. Derived automatically.""" + user_specified_mamba_block_size: bool = field(default=False, init=False) + """Whether mamba_block_size was explicitly provided. Derived automatically.""" gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1) """The fraction of GPU memory to be used for the model executor, which can range from 0 to 1. 
For example, a value of 0.5 would imply 50% GPU memory @@ -182,6 +184,7 @@ class CacheConfig: "cpu_kvcache_space_bytes", "mamba_page_size_padded", "user_specified_block_size", + "user_specified_mamba_block_size", "_block_size_resolved", # Post-init/derived counters "num_gpu_blocks", @@ -214,6 +217,8 @@ class CacheConfig: object.__setattr__(self, "block_size", self.DEFAULT_BLOCK_SIZE) else: object.__setattr__(self, "user_specified_block_size", True) + if self.mamba_block_size is not None: + object.__setattr__(self, "user_specified_mamba_block_size", True) return self @field_validator("calculate_kv_scales", mode="after") diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index a5644a414..03b147e5c 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -1,15 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from copy import deepcopy -from math import lcm from typing import TYPE_CHECKING from vllm.logger import init_logger -from vllm.model_executor.models import ModelRegistry -from vllm.utils.math_utils import cdiv, round_up -from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE -from vllm.v1.attention.backends.registry import AttentionBackendEnum -from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec +from vllm.utils.math_utils import round_up if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig @@ -104,11 +99,11 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): @classmethod def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: """ - Ensure that page size of attention layers is greater than or - equal to the mamba layers. If not, automatically set the attention - block size to ensure that it is. If the attention page size is - strictly greater than the mamba page size, we pad the mamba page size - to make them equal. 
+ Perform early validation and setup for hybrid attention/mamba models. + + Block size alignment with mamba page sizes is handled later by + Platform.update_block_size_for_backend(), which runs after model + layers are constructed and the attention backend is known. Args: vllm_config: vLLM Config @@ -118,6 +113,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): # Disable calculate_kv_scales for hybrid models: uninitialized # recurrent state corrupts scales during the calibration pass. # See issue: https://github.com/vllm-project/vllm/issues/37554 + if cache_config.calculate_kv_scales: logger.warning( "Disabling calculate_kv_scales for hybrid model '%s'. " @@ -129,140 +125,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): ) cache_config.calculate_kv_scales = False - # Save the user input before it gets modified by MambaModelConfig - mamba_block_size = cache_config.mamba_block_size # Enable FULL_AND_PIECEWISE by default MambaModelConfig.verify_and_update_config(vllm_config) - attention_config = vllm_config.attention_config - cache_config = vllm_config.cache_config - model_config = vllm_config.model_config - parallel_config = vllm_config.parallel_config - - if cache_config.cache_dtype == "auto": - kv_cache_dtype = model_config.dtype - else: - kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - - # get attention page size (for 1 token) - # Attention backend constraints: - # - FlashAttention (FA) requires block size to be multiple of 16 - # - MLA (Multi-head Latent Attention) requires larger alignment: - # * CUTLASS_MLA backend: kernel_block_size 128 alignment - # * Other MLA backends: kernel_block_size 64 alignment - if model_config.use_mla: - use_cutlass_mla = ( - attention_config.backend == AttentionBackendEnum.CUTLASS_MLA - ) - kernel_block_alignment_size = 128 if use_cutlass_mla else 64 - attn_page_size_1_token = MLAAttentionSpec( - block_size=1, - num_kv_heads=model_config.get_num_kv_heads(parallel_config), - 
head_size=model_config.get_head_size(), - dtype=kv_cache_dtype, - ).page_size_bytes - else: - kernel_block_alignment_size = 16 - attn_page_size_1_token = FullAttentionSpec( - block_size=1, - num_kv_heads=model_config.get_num_kv_heads(parallel_config), - head_size=model_config.get_head_size(), - dtype=kv_cache_dtype, - ).page_size_bytes - - model_cls, _ = ModelRegistry.resolve_model_cls( - model_config.architecture, - model_config=model_config, - ) - - # get mamba page size - mamba_page_size = MambaSpec( - shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), - dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config), - block_size=-1, # block_size doesn't matter for mamba page size - ).page_size_bytes - - # Model may be marked as is_hybrid - # but mamba is skipped via config, - # return directly - if mamba_page_size == 0: - return - - if cache_config.mamba_cache_mode == "all": - # With prefix caching, select attention block size to - # optimize for mamba kernel performance - - # Mamba2 SSD kernel uses a chunk_size, e.g. 256 - # Align the block to the kernel: use lowest multiple of chunk_size - # of attention tokens that would fit mamba_page_size: - # e.g. for mamba page size = 788kB - # attn_1_token = 2kB -> fits ~394 tokens - # then round up to a multiple of 256 -> 512 tokens - # End result: - # attn_block_size = 512 - # mamba_block_size = 512 (aligned to a multiple of chunk_size) - # TODO(tdoublep): this constraint can be relaxed fairly - # easily by changing the way we layout chunks in the - # mamba2 kernels. 
- - base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size() - attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token) - chunk_size = lcm(base_chunk_size, kernel_block_alignment_size) - attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size) - cache_config.mamba_block_size = attn_block_size - else: - # Without prefix caching, select minimum valid attention block size - # to minimize mamba state padding - - # Calculate minimum attention block size that satisfies both: - # 1. Backend alignment requirements (kernel_block_alignment_size) - # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size) - attn_block_size = kernel_block_alignment_size * cdiv( - mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token - ) - - # override attention block size if it is too small, - # even if the user has explicitly set it - if cache_config.block_size < attn_block_size: - cache_config.block_size = attn_block_size - logger.info( - "Setting attention block size to %d tokens " - "to ensure that attention page size is >= mamba page size.", - attn_block_size, - ) - - # By default, mamba block size will be set to max_model_len. - # When enabling prefix caching and using align mamba cache - # mode, we align mamba block size to the block size as the - # basic granularity for prefix caching. 
- if cache_config.mamba_cache_mode == "align": - cache_config.mamba_block_size = cache_config.block_size - - # compute new attention page size - attn_page_size = cache_config.block_size * attn_page_size_1_token - - assert attn_page_size >= mamba_page_size - - if attn_page_size == mamba_page_size: - # don't need to pad mamba page size - return - - # pad mamba page size to exactly match attention - if ( - cache_config.mamba_page_size_padded is None - or cache_config.mamba_page_size_padded != attn_page_size - ): - cache_config.mamba_page_size_padded = attn_page_size - mamba_padding_pct = ( - 100 * (attn_page_size - mamba_page_size) / mamba_page_size - ) - logger.info( - "Padding mamba page size by %.2f%% to ensure " - "that mamba page size and attention page size are " - "exactly equal.", - mamba_padding_pct, - ) - class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig): @staticmethod diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 281e91999..fae37442e 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils.argparse_utils import FlexibleArgumentParser + from vllm.v1.attention.backend import AttentionBackend from vllm.v1.attention.selector import AttentionSelectorConfig else: FlexibleArgumentParser = object @@ -424,29 +425,11 @@ class Platform: pass @classmethod - def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: - """ - Ensure block_size is compatible with the attention backend. - """ - from vllm.config.cache import CacheConfig - - cache_config = vllm_config.cache_config - if cache_config.user_specified_block_size: - # User specified --block-size; keep it. - return - - model_config = vllm_config.model_config - # model_config may be None during testing. - # Skip hybrid models — their block_size is managed by - # HybridAttentionMambaModelConfig. 
@classmethod
def _find_non_ssm_backend(
    cls, vllm_config: "VllmConfig"
) -> "type[AttentionBackend] | None":
    """Return the backend class of the first non-SSM attention layer.

    Iterates over the ``AttentionLayerBase`` layers registered in
    ``vllm_config`` and returns the first backend whose ``is_ssm()`` is
    false.

    Args:
        vllm_config: vLLM config whose layers are inspected.

    Returns:
        The first non-SSM backend class, or ``None`` when there are no
        attention layers or every layer reports an SSM backend.
    """
    from vllm.config.vllm import get_layers_from_vllm_config
    from vllm.model_executor.layers.attention_layer_base import (
        AttentionLayerBase,
    )

    # NOTE(review): the assignment target of this call falls in a gap
    # between two diff hunks; `attn_layers` is taken from the name used by
    # the surrounding (pre- and post-patch) code. Confirm against the file.
    attn_layers = get_layers_from_vllm_config(
        vllm_config,
        AttentionLayerBase,  # type: ignore[type-abstract]
    )
    for layer in attn_layers.values():
        backend_cls = layer.get_attn_backend()
        if not backend_cls.is_ssm():
            return backend_cls
    return None


@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
    """Make ``cache_config.block_size`` compatible with the attention backend.

    Runs in two phases:

    1. Unless the user explicitly specified a block size, adopt the block
       size preferred by the first non-SSM attention backend.
    2. For hybrid models, align the block size and mamba sizes via
       ``_align_hybrid_block_size`` (this may override user settings).

    Does nothing when ``model_config`` is ``None`` or when no non-SSM
    attention backend is found.

    Args:
        vllm_config: vLLM config; its ``cache_config`` is updated in place.
    """
    from vllm.config.cache import CacheConfig
    from vllm.config.vllm import set_current_vllm_config

    cache_config = vllm_config.cache_config
    model_config = vllm_config.model_config

    # model_config may be None during testing. Use an identity check so the
    # guard does not depend on the truthiness of the config object.
    if model_config is None:
        return

    backend_cls = cls._find_non_ssm_backend(vllm_config)
    if backend_cls is None:
        return

    # Phase 1: Pick block size from backend (skip if user set --block-size)
    if not cache_config.user_specified_block_size:
        with set_current_vllm_config(vllm_config):
            preferred = backend_cls.get_preferred_block_size(
                CacheConfig.DEFAULT_BLOCK_SIZE
            )
        if preferred != CacheConfig.DEFAULT_BLOCK_SIZE:
            logger.info(
                "Setting kv cache block size to %d for %s backend.",
                preferred,
                backend_cls.get_name(),
            )
        cache_config.block_size = preferred

    # Phase 2: Align block/mamba sizes for hybrid models
    # (may override user settings).
    if model_config.is_hybrid:
        cls._align_hybrid_block_size(vllm_config, backend_cls)


@classmethod
def _align_hybrid_block_size(
    cls,
    vllm_config: "VllmConfig",
    backend_cls: "type[AttentionBackend]",
) -> None:
    """Align attention and mamba page sizes for hybrid attention/mamba models.

    Ensures the attention page size is >= the mamba page size, raising
    ``cache_config.block_size`` if needed, and records the attention page
    size in ``cache_config.mamba_page_size_padded`` when the two differ.
    Depending on ``cache_config.mamba_cache_mode`` it also sets
    ``cache_config.mamba_block_size``. Returns without changes when the
    model's mamba page size is 0.

    Args:
        vllm_config: vLLM config; its ``cache_config`` is updated in place.
        backend_cls: Non-SSM attention backend whose supported kernel block
            sizes define the block alignment.

    Raises:
        ValueError: If ``mamba_cache_mode`` is ``"all"`` and neither a
            user-specified ``mamba_block_size`` nor a model chunk size is
            available.
    """
    from math import lcm

    from vllm.config.vllm import set_current_vllm_config
    from vllm.model_executor.models import ModelRegistry
    from vllm.utils.math_utils import cdiv
    from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
    from vllm.v1.attention.backend import MultipleOf
    from vllm.v1.kv_cache_interface import (
        FullAttentionSpec,
        MambaSpec,
        MLAAttentionSpec,
    )

    cache_config = vllm_config.cache_config
    model_config = vllm_config.model_config
    parallel_config = vllm_config.parallel_config

    if cache_config.cache_dtype == "auto":
        kv_cache_dtype = model_config.dtype
    else:
        kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

    # Compute attention page size for 1 token. Both spec classes take the
    # same arguments; only the class depends on whether the model uses MLA.
    attn_spec_cls = MLAAttentionSpec if model_config.use_mla else FullAttentionSpec
    attn_page_size_1_token = attn_spec_cls(
        block_size=1,
        num_kv_heads=model_config.get_num_kv_heads(parallel_config),
        head_size=model_config.get_head_size(),
        dtype=kv_cache_dtype,
    ).page_size_bytes

    # Compute mamba page size
    model_cls, _ = ModelRegistry.resolve_model_cls(
        model_config.architecture,
        model_config=model_config,
    )
    mamba_page_size = MambaSpec(
        shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
        dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
        block_size=-1,
    ).page_size_bytes

    # Nothing to align when the mamba state occupies no space.
    if mamba_page_size == 0:
        return

    # Only honor mamba_block_size when the user provided it; any other value
    # is treated as unset.
    mamba_block_size = (
        cache_config.mamba_block_size
        if cache_config.user_specified_mamba_block_size
        else None
    )

    # Get kernel block alignment from the backend's supported sizes, but
    # never below the current block size. If the backend reports no sizes,
    # fall back to the current block size instead of failing inside min().
    with set_current_vllm_config(vllm_config):
        kernel_block_alignment_size = max(
            min(
                (
                    s.base if isinstance(s, MultipleOf) else s
                    for s in backend_cls.get_supported_kernel_block_sizes()
                ),
                default=cache_config.block_size,
            ),
            cache_config.block_size,
        )

    if cache_config.mamba_cache_mode == "all":
        # With prefix caching, align to mamba chunk size for kernel perf
        # TODO(tdoublep): this constraint can be relaxed fairly
        # easily by changing the way we layout chunks in the
        # mamba2 kernels.
        base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
        if base_chunk_size is None:
            # Explicit error instead of `assert`: assertions are stripped
            # under `python -O`, which would surface as a TypeError in lcm().
            raise ValueError(
                "mamba_cache_mode='all' requires a mamba chunk size: set "
                "mamba_block_size or use a model whose config defines one."
            )
        attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
        chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
        attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
        cache_config.mamba_block_size = attn_block_size
    else:
        # Without prefix caching, use minimum block size that satisfies
        # both backend alignment and mamba page size compatibility
        attn_block_size = kernel_block_alignment_size * cdiv(
            mamba_page_size,
            kernel_block_alignment_size * attn_page_size_1_token,
        )

    # Override the attention block size if it is too small (this applies
    # even to a user-specified block size).
    if cache_config.block_size < attn_block_size:
        cache_config.block_size = attn_block_size
        logger.info(
            "Setting attention block size to %d tokens "
            "to ensure that attention page size is >= mamba page size.",
            attn_block_size,
        )

    if cache_config.mamba_cache_mode == "align":
        cache_config.mamba_block_size = cache_config.block_size

    # Pad mamba page size to exactly match attention page size
    attn_page_size = cache_config.block_size * attn_page_size_1_token
    # Internal invariant: block_size >= attn_block_size after the override
    # above, and attn_block_size covers at least one mamba page.
    assert attn_page_size >= mamba_page_size

    if attn_page_size == mamba_page_size:
        return

    # `None != attn_page_size` is already true, so one comparison covers both
    # the unset case and the stale-value case.
    if cache_config.mamba_page_size_padded != attn_page_size:
        cache_config.mamba_page_size_padded = attn_page_size
        mamba_padding_pct = (
            100 * (attn_page_size - mamba_page_size) / mamba_page_size
        )
        logger.info(
            "Padding mamba page size by %.2f%% to ensure "
            "that mamba page size and attention page size are "
            "exactly equal.",
            mamba_padding_pct,
        )
- pass - @classmethod def support_hybrid_kv_cache(cls) -> bool: return True diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index cd49ea30e..9001b23f3 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -311,6 +311,10 @@ class AttentionBackend(ABC): def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": return None + @classmethod + def is_ssm(cls) -> bool: + return False + class AttentionMetadata: pass diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 245995be2..5e63fa592 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -43,6 +43,7 @@ from vllm.config import ( from vllm.config.cache import CacheDType from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability from vllm.utils.math_utils import cdiv, round_up from vllm.v1.attention.backend import ( @@ -90,6 +91,12 @@ class FlashAttentionBackend(AttentionBackend): forward_includes_kv_cache_update: bool = False + @classmethod + def get_preferred_block_size(cls, default_block_size: int) -> int: + if current_platform.is_xpu(): + return max(default_block_size, 64) + return super().get_preferred_block_size(default_block_size) + @staticmethod def get_name() -> str: return "FLASH_ATTN" diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 574cc87e7..f65d9a4b3 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -31,6 +31,10 @@ class GDNAttentionBackend(AttentionBackend): def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]: return GDNAttentionMetadataBuilder + @classmethod + def is_ssm(cls) -> bool: + return True + @dataclass class GDNAttentionMetadata: diff --git a/vllm/v1/attention/backends/linear_attn.py 
b/vllm/v1/attention/backends/linear_attn.py index fe27e7a38..b2ca15198 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -27,6 +27,10 @@ class LinearAttentionBackend(AttentionBackend): def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]: return LinearAttentionMetadataBuilder + @classmethod + def is_ssm(cls) -> bool: + return True + @dataclass class LinearAttentionMetadata: diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 890340620..925fceb02 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -20,6 +20,10 @@ class Mamba1AttentionBackend(AttentionBackend): def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]: return Mamba1AttentionMetadataBuilder + @classmethod + def is_ssm(cls) -> bool: + return True + @dataclass class Mamba1AttentionMetadata(BaseMambaAttentionMetadata): diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 5e8abbab5..fa7d4bd2e 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -96,6 +96,10 @@ class Mamba2AttentionBackend(AttentionBackend): def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]: return Mamba2AttentionMetadataBuilder + @classmethod + def is_ssm(cls) -> bool: + return True + @dataclass class Mamba2AttentionMetadata(BaseMambaAttentionMetadata): diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index c6a8e6eea..9c85ec5ef 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -18,6 +18,10 @@ class ShortConvAttentionBackend(AttentionBackend): def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]: return ShortConvAttentionMetadataBuilder + @classmethod + def is_ssm(cls) -> bool: + return True + @dataclass class 
ShortConvAttentionMetadata(BaseMambaAttentionMetadata):