diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index 7ac1951fe..347205755 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -6,7 +6,12 @@ from unittest.mock import patch import pytest import torch -from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config +from vllm.config import ( + AttentionConfig, + CacheConfig, + VllmConfig, + set_current_vllm_config, +) from vllm.platforms import current_platform from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cuda import CudaPlatform @@ -84,12 +89,15 @@ def test_backend_selection( """Test attention backend selection with valid device-backend pairs.""" # Create AttentionConfig with the specified backend attention_config = AttentionConfig(backend=AttentionBackendEnum[name]) - vllm_config = VllmConfig(attention_config=attention_config) + cache_config = CacheConfig(block_size=block_size) + vllm_config = VllmConfig( + attention_config=attention_config, cache_config=cache_config + ) with set_current_vllm_config(vllm_config): if device == "cpu": with patch("vllm.platforms.current_platform", CpuPlatform()): - backend = get_attn_backend(16, torch.float16, None, block_size) + backend = get_attn_backend(16, torch.float16, None) assert backend.get_name() == "CPU_ATTN" elif device == "hip": @@ -104,20 +112,16 @@ def test_backend_selection( if name == "TRITON_MLA" and block_size == 1: # TRITON_MLA doesn't support block_size == 1 with pytest.raises(ValueError): - get_attn_backend( - 576, torch.float16, None, block_size, use_mla=use_mla - ) + get_attn_backend(576, torch.float16, None, use_mla=use_mla) else: # Valid backend-block_size combination backend = get_attn_backend( - 576, torch.float16, None, block_size, use_mla=use_mla + 576, torch.float16, None, use_mla=use_mla ) expected = name assert backend.get_name() == expected else: - backend = get_attn_backend( - 32, torch.float16, None, block_size, use_mla=use_mla - ) + backend = get_attn_backend(32, torch.float16, None, use_mla=use_mla) expected = "ROCM_ATTN" assert backend.get_name() == expected @@ -141,7 +145,7 @@ def test_backend_selection( if capability[0] != 10: pytest.skip("CUTLASS MLA is not supported on this platform") backend = get_attn_backend( - 576, torch.float16, None, block_size, use_mla=use_mla + 576, torch.float16, None, use_mla=use_mla ) expected = "CUTLASS_MLA" assert backend.get_name() == expected @@ -156,7 +160,7 @@ def test_backend_selection( "FlashInfer MLA only supports block_size 32 or 64" ) backend = get_attn_backend( - 576, torch.float16, None, block_size, use_mla=use_mla + 576, torch.float16, None, use_mla=use_mla ) expected = "FLASHINFER_MLA" assert backend.get_name() == expected @@ -175,7 +179,6 @@ def test_backend_selection( 576, torch.float16, None, - block_size, use_mla=use_mla, ) expected = name @@ -190,27 +193,23 @@ def test_backend_selection( "FlashAttention MLA not supported on this platform" ) backend = get_attn_backend( - 576, torch.float16, None, block_size, use_mla=use_mla + 576, torch.float16, None, use_mla=use_mla ) expected = "FLASH_ATTN_MLA" assert backend.get_name() == expected else: # TRITON_MLA or other fallback backend = get_attn_backend( - 576, torch.float16, None, block_size, use_mla=use_mla + 576, torch.float16, None, use_mla=use_mla ) expected = "TRITON_MLA" assert backend.get_name() == expected elif name == "FLASHINFER": - backend = get_attn_backend( - 64, torch.float16, None, block_size, use_mla=use_mla - ) + backend = get_attn_backend(64, torch.float16, None, use_mla=use_mla) expected = "FLASHINFER" assert backend.get_name() == expected elif name == "FLASH_ATTN": - backend = get_attn_backend( - 32, torch.float16, None, block_size, use_mla=use_mla - ) + backend = get_attn_backend(32, torch.float16, None, use_mla=use_mla) expected = "FLASH_ATTN" assert backend.get_name() == expected @@ -224,12 +223,12 @@ def test_fp32_fallback(device: str): with set_current_vllm_config(vllm_config): if device == "cpu": with patch("vllm.platforms.current_platform", CpuPlatform()): - backend = get_attn_backend(16, torch.float32, None, 16) + backend = get_attn_backend(16, torch.float32, None) assert backend.get_name() == "CPU_ATTN" elif device == "cuda": with patch("vllm.platforms.current_platform", CudaPlatform()): - backend = get_attn_backend(16, torch.float32, None, 16) + backend = get_attn_backend(16, torch.float32, None) assert backend.get_name() == "FLEX_ATTENTION" @@ -241,35 +240,40 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): ) attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN) - vllm_config = VllmConfig(attention_config=attention_config) + cache_config = CacheConfig(block_size=16) + vllm_config = VllmConfig( + attention_config=attention_config, cache_config=cache_config + ) with set_current_vllm_config(vllm_config): # Unsupported CUDA arch monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5)) - backend = get_attn_backend(16, torch.float16, None, 16) + backend = get_attn_backend(16, torch.float16, None) assert backend.get_name() != "FLASH_ATTN" # Reset the monkeypatch for subsequent tests monkeypatch.undo() # Unsupported data type - backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16) + backend = get_attn_backend(16, torch.float8_e4m3fn, None) assert backend.get_name() != "FLASH_ATTN" # Unsupported kv cache data type - backend = get_attn_backend(16, torch.float16, "fp8", 16) + backend = get_attn_backend(16, torch.float16, "fp8") assert backend.get_name() != "FLASH_ATTN" # Unsupported block size - backend = get_attn_backend(16, torch.float16, None, 8) + vllm_config.cache_config.block_size = 8 + backend = get_attn_backend(16, torch.float16, None) assert backend.get_name() != "FLASH_ATTN" # flash-attn is not installed import sys + vllm_config.cache_config.block_size = 16 original_module = sys.modules.get("vllm_flash_attn") monkeypatch.setitem(sys.modules, "vllm_flash_attn", None) - backend = get_attn_backend(16, torch.float16, None, 16) + backend = get_attn_backend(16, torch.float16, None) assert backend.get_name() != "FLASH_ATTN" # Restore the original module if it existed @@ -279,7 +283,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): monkeypatch.delitem(sys.modules, "vllm_flash_attn", raising=False) # Unsupported head size - backend = get_attn_backend(17, torch.float16, None, 16) + backend = get_attn_backend(17, torch.float16, None) assert backend.get_name() != "FLASH_ATTN" @@ -320,7 +324,7 @@ def test_auto_backend_selection_behavior(): set_current_vllm_config(vllm_config_auto), patch("vllm.platforms.current_platform", CpuPlatform()), ): - backend_auto = get_attn_backend(16, torch.float16, None, 16) + backend_auto = get_attn_backend(16, torch.float16, None) _cached_get_attn_backend.cache_clear() @@ -328,7 +332,7 @@ def test_auto_backend_selection_behavior(): set_current_vllm_config(vllm_config_none), patch("vllm.platforms.current_platform", CpuPlatform()), ): - backend_none = get_attn_backend(16, torch.float16, None, 16) + backend_none = get_attn_backend(16, torch.float16, None) # Both should select the same backend assert backend_auto.get_name() == backend_none.get_name() @@ -358,7 +362,10 @@ def test_per_head_quant_scales_backend_selection( backend=AttentionBackendEnum[backend_name], flash_attn_version=flash_attn_version, ) - vllm_config = VllmConfig(attention_config=attention_config) + cache_config = CacheConfig(block_size=64) + vllm_config = VllmConfig( + attention_config=attention_config, cache_config=cache_config + ) with ( set_current_vllm_config(vllm_config), @@ -376,7 +383,6 @@ def test_per_head_quant_scales_backend_selection( head_size=128, dtype=torch.float16, kv_cache_dtype="fp8", - block_size=64, use_per_head_quant_scales=True, ) assert backend.get_name() == backend_name @@ -386,7 +392,6 @@ def test_per_head_quant_scales_backend_selection( head_size=128, dtype=torch.float16, kv_cache_dtype="fp8", - block_size=64, use_per_head_quant_scales=True, ) assert backend_name in str(exc_info.value) diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index b53536814..5afcab9f3 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -13,6 +13,7 @@ import torch.nn as nn from PIL import Image from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config +from vllm.config.cache import CacheConfig from vllm.config.multimodal import ( AudioDummyOptions, BaseDummyOptions, @@ -131,7 +132,9 @@ def initialize_dummy_model( ): temp_file = tempfile.mkstemp()[1] current_device = torch.get_default_device() - vllm_config = VllmConfig(model_config=model_config) + vllm_config = VllmConfig( + model_config=model_config, cache_config=CacheConfig(block_size=16) + ) with set_current_vllm_config(vllm_config=vllm_config): init_distributed_environment( world_size=1, diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 963ab6f1d..6ac68e055 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -80,7 +80,7 @@ def _create_proposer( device = current_platform.device_type vllm_config = VllmConfig( model_config=model_config, - cache_config=CacheConfig(), + cache_config=CacheConfig(block_size=16), speculative_config=speculative_config, device_config=DeviceConfig(device=device), parallel_config=ParallelConfig(), diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 71603d8c8..3796265ff 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -2,16 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import field -from typing import Literal +from typing import ClassVar, Literal -from pydantic import Field, SkipValidation, field_validator +from pydantic import Field, SkipValidation, field_validator, model_validator from vllm.config.utils import config from vllm.logger import init_logger logger = init_logger(__name__) -BlockSize = Literal[1, 8, 16, 32, 64, 128, 256] CacheDType = Literal[ "auto", "bfloat16", @@ -31,12 +30,13 @@ KVOffloadingBackend = Literal["native", "lmcache"] class CacheConfig: """Configuration for the KV cache.""" - block_size: SkipValidation[BlockSize] = None # type: ignore[assignment] - """Size of a contiguous cache block in number of tokens. + DEFAULT_BLOCK_SIZE: ClassVar[int] = 16 - This config has no static default. If left unspecified by the user, it will - be set in `Platform.check_and_update_config()` based on the current - platform.""" + block_size: SkipValidation[int] = None # type: ignore[assignment] + """Size of a contiguous cache block in number of tokens. + Accepts None (meaning "use default"). After construction, always int.""" + user_specified_block_size: bool = field(default=False, init=False) + """Whether block_size was explicitly provided. Derived automatically.""" gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1) """The fraction of GPU memory to be used for the model executor, which can range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory @@ -169,6 +169,8 @@ class CacheConfig: "prefix_caching_hash_algo", "cpu_kvcache_space_bytes", "mamba_page_size_padded", + "user_specified_block_size", + "_block_size_resolved", # Post-init/derived counters "num_gpu_blocks", "num_cpu_blocks", @@ -186,6 +188,22 @@ class CacheConfig: # metrics info return {key: str(value) for key, value in self.__dict__.items()} + _block_size_resolved: bool = field(default=False, init=False) + """Guard against pydantic re-running _apply_block_size_default.""" + + @model_validator(mode="after") + def _apply_block_size_default(self) -> "CacheConfig": + # Pydantic re-runs validators when CacheConfig is nested inside + # another pydantic model (e.g. VllmConfig). Guard against that. + if self._block_size_resolved: + return self + object.__setattr__(self, "_block_size_resolved", True) + if self.block_size is None: + object.__setattr__(self, "block_size", self.DEFAULT_BLOCK_SIZE) + else: + object.__setattr__(self, "user_specified_block_size", True) + return self + @field_validator("cache_dtype", mode="after") @classmethod def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType: diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index bf8620b73..682feff11 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1026,32 +1026,6 @@ class VllmConfig: ) current_platform.check_and_update_config(self) - # If DCP, ensure the block size is right. - if self.parallel_config.decode_context_parallel_size > 1: - if self.parallel_config.dcp_kv_cache_interleave_size > 1 and ( - self.parallel_config.cp_kv_cache_interleave_size - != self.parallel_config.dcp_kv_cache_interleave_size - ): - self.parallel_config.cp_kv_cache_interleave_size = ( - self.parallel_config.dcp_kv_cache_interleave_size - ) - logger.warning_once( - "cp_kv_cache_interleave_size is overridden by dcp_kv_cache" - "_interleave_size. And dcp-kv-cache-interleave-size will be " - "deprecated when PCP is fully supported." - ) - assert ( - self.parallel_config.cp_kv_cache_interleave_size - <= self.cache_config.block_size - and self.cache_config.block_size - % self.parallel_config.cp_kv_cache_interleave_size - == 0 - ), ( - f"Block_size({self.cache_config.block_size}) should be greater " - "than or equal to and divisible by cp_kv_cache_interleave_size " - f"({self.parallel_config.cp_kv_cache_interleave_size})." - ) - # Do this after all the updates to compilation_config.mode effective_dp_size = ( self.parallel_config.data_parallel_size @@ -1219,26 +1193,6 @@ class VllmConfig: # Default to enable HMA if not explicitly disabled by user or logic above. self.scheduler_config.disable_hybrid_kv_cache_manager = False - if self.cache_config.mamba_cache_mode == "align": - assert ( - self.cache_config.block_size - <= self.scheduler_config.max_num_batched_tokens - ), ( - "In Mamba cache align mode, block_size " - f"({self.cache_config.block_size}) must be <= " - "max_num_batched_tokens " - f"({self.scheduler_config.max_num_batched_tokens})." - ) - if self.scheduler_config.long_prefill_token_threshold > 0: - assert ( - self.scheduler_config.long_prefill_token_threshold - >= self.cache_config.block_size - ) - assert not self.scheduler_config.disable_chunked_mm_input, ( - "Chunked MM input is required because we need the flexibility to " - "schedule a multiple of block_size tokens even if they are in the " - "middle of a mm input" - ) if self.compilation_config.debug_dump_path: self.compilation_config.debug_dump_path = ( self.compilation_config.debug_dump_path.absolute().expanduser() @@ -1673,6 +1627,53 @@ class VllmConfig: f"compilation_config={self.compilation_config!r}" ) + def validate_block_size(self) -> None: + """Validate block_size against DCP and mamba constraints. + + Called after Platform.update_block_size_for_backend() has + finalised block_size. + """ + block_size = self.cache_config.block_size + + # DCP interleave-size compatibility + if self.parallel_config.decode_context_parallel_size > 1: + if self.parallel_config.dcp_kv_cache_interleave_size > 1 and ( + self.parallel_config.cp_kv_cache_interleave_size + != self.parallel_config.dcp_kv_cache_interleave_size + ): + self.parallel_config.cp_kv_cache_interleave_size = ( + self.parallel_config.dcp_kv_cache_interleave_size + ) + logger.warning_once( + "cp_kv_cache_interleave_size is overridden by dcp_kv_cache" + "_interleave_size. And dcp-kv-cache-interleave-size will be " + "deprecated when PCP is fully supported." + ) + assert ( + self.parallel_config.cp_kv_cache_interleave_size <= block_size + and block_size % self.parallel_config.cp_kv_cache_interleave_size == 0 + ), ( + f"Block_size({block_size}) should be greater " + "than or equal to and divisible by cp_kv_cache_interleave_size " + f"({self.parallel_config.cp_kv_cache_interleave_size})." + ) + + # Mamba cache align-mode constraints + if self.cache_config.mamba_cache_mode == "align": + assert block_size <= self.scheduler_config.max_num_batched_tokens, ( + "In Mamba cache align mode, block_size " + f"({block_size}) must be <= " + "max_num_batched_tokens " + f"({self.scheduler_config.max_num_batched_tokens})." + ) + if self.scheduler_config.long_prefill_token_threshold > 0: + assert self.scheduler_config.long_prefill_token_threshold >= block_size + assert not self.scheduler_config.disable_chunked_mm_input, ( + "Chunked MM input is required because we need the flexibility " + "to schedule a multiple of block_size tokens even if they are " + "in the middle of a mm input" + ) + @model_validator(mode="after") def validate_mamba_block_size(self) -> "VllmConfig": if self.model_config is None: diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index eb93ea324..6e0366c52 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -500,7 +500,6 @@ def get_current_attn_backend(vllm_config: VllmConfig): head_size=vllm_config.model_config.get_head_size(), dtype=vllm_config.model_config.dtype, kv_cache_dtype=vllm_config.cache_config.cache_dtype, - block_size=vllm_config.cache_config.block_size, use_mla=vllm_config.model_config.use_mla, ) return backend diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py index 2494857c6..800b24c0a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py @@ -726,7 +726,6 @@ class MoRIIOConnectorWorker: self.model_config.get_head_size(), self.model_config.dtype, self.cache_config.cache_dtype, - self.block_size, use_mla=self.use_mla, ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index dc1735a01..c31e17299 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -62,7 +62,6 @@ from vllm.config import ( get_attr_docs, ) from vllm.config.cache import ( - BlockSize, CacheDType, KVOffloadingBackend, MambaCacheMode, @@ -440,7 +439,7 @@ class EngineArgs: max_parallel_loading_workers: int | None = ( ParallelConfig.max_parallel_loading_workers ) - block_size: BlockSize = CacheConfig.block_size + block_size: int | None = None enable_prefix_caching: bool | None = None prefix_caching_hash_algo: PrefixCachingHashAlgo = ( CacheConfig.prefix_caching_hash_algo @@ -1521,7 +1520,7 @@ class EngineArgs: ) cache_config = CacheConfig( - block_size=self.block_size, + block_size=self.block_size, # type: ignore[arg-type] gpu_memory_utilization=self.gpu_memory_utilization, kv_cache_memory_bytes=self.kv_cache_memory_bytes, cache_dtype=resolved_cache_dtype, # type: ignore[arg-type] diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 38f10998e..1ab22d408 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -221,11 +221,9 @@ class Attention(nn.Module, AttentionLayerBase): vllm_config = get_current_vllm_config() if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype - block_size = cache_config.block_size calculate_kv_scales = cache_config.calculate_kv_scales else: kv_cache_dtype = "auto" - block_size = 16 calculate_kv_scales = False # llm-compressor mdls need to set cache_dtype to "fp8" manually. @@ -275,7 +273,6 @@ class Attention(nn.Module, AttentionLayerBase): head_size, dtype, kv_cache_dtype, - block_size, use_mla=False, has_sink=self.has_sink, use_mm_prefix=self.use_mm_prefix, diff --git a/vllm/model_executor/layers/attention/chunked_local_attention.py b/vllm/model_executor/layers/attention/chunked_local_attention.py index e33733c0c..b747304ac 100644 --- a/vllm/model_executor/layers/attention/chunked_local_attention.py +++ b/vllm/model_executor/layers/attention/chunked_local_attention.py @@ -30,9 +30,8 @@ from vllm.v1.kv_cache_interface import ( def create_chunked_local_attention_backend( underlying_attn_backend: AttentionBackend, attention_chunk_size: int, - block_size: int, ) -> type[AttentionBackend]: - prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_" + prefix = f"ChunkedLocalAttention_{attention_chunk_size}_" underlying_builder = underlying_attn_backend.get_builder_cls() assert issubclass(underlying_builder, AttentionMetadataBuilder) @@ -55,7 +54,9 @@ def create_chunked_local_attention_backend( fast_build: bool = False, ): cm, make_virtual_batches_block_table = make_local_attention_virtual_batches( - attention_chunk_size, common_attn_metadata, block_size + attention_chunk_size, + common_attn_metadata, + self.kv_cache_spec.block_size, ) metadata = super().build(common_prefix_len, cm, fast_build) metadata.make_virtual_batches_block_table = make_virtual_batches_block_table @@ -94,16 +95,12 @@ class ChunkedLocalAttention(Attention): dtype = torch.get_default_dtype() if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype - block_size = cache_config.block_size else: kv_cache_dtype = "auto" - block_size = 16 - underlying_attn_backend = get_attn_backend( - head_size, dtype, kv_cache_dtype, block_size - ) + underlying_attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype) attn_backend = create_chunked_local_attention_backend( - underlying_attn_backend, attention_chunk_size, block_size + underlying_attn_backend, attention_chunk_size ) super().__init__( diff --git a/vllm/model_executor/layers/attention/cross_attention.py b/vllm/model_executor/layers/attention/cross_attention.py index 9333b35e6..5bd8e163f 100644 --- a/vllm/model_executor/layers/attention/cross_attention.py +++ b/vllm/model_executor/layers/attention/cross_attention.py @@ -188,10 +188,8 @@ class CrossAttention(Attention): if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype - block_size = cache_config.block_size else: kv_cache_dtype = "auto" - block_size = 16 if attn_type is not None: assert attn_type == AttentionType.ENCODER_DECODER, ( @@ -202,7 +200,6 @@ class CrossAttention(Attention): head_size, dtype, kv_cache_dtype, - block_size, attn_type=AttentionType.ENCODER_DECODER, ) attn_backend = create_cross_attention_backend(underlying_attn_backend) diff --git a/vllm/model_executor/layers/attention/encoder_only_attention.py b/vllm/model_executor/layers/attention/encoder_only_attention.py index 941911028..0897ee45b 100644 --- a/vllm/model_executor/layers/attention/encoder_only_attention.py +++ b/vllm/model_executor/layers/attention/encoder_only_attention.py @@ -66,16 +66,13 @@ class EncoderOnlyAttention(Attention): if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype - block_size = cache_config.block_size else: kv_cache_dtype = "auto" - block_size = 16 underlying_attn_backend = get_attn_backend( head_size, dtype, kv_cache_dtype, - block_size, attn_type=AttentionType.ENCODER_ONLY, ) diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 97ae3ef1b..b1dc1a860 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -323,11 +323,9 @@ class MLAAttention(nn.Module, AttentionLayerBase): if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype - block_size = cache_config.block_size calculate_kv_scales = cache_config.calculate_kv_scales else: kv_cache_dtype = "auto" - block_size = 16 calculate_kv_scales = False self.quant_config = quant_config @@ -336,7 +334,6 @@ class MLAAttention(nn.Module, AttentionLayerBase): self.head_size, dtype, kv_cache_dtype, - block_size, use_mla=True, use_sparse=use_sparse, num_heads=self.num_heads, @@ -449,17 +446,24 @@ class MLAAttention(nn.Module, AttentionLayerBase): ) # Attributes for forward_impl method - self.chunked_prefill_workspace_size = ( - MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( - get_current_vllm_config() - ) - ) + self._vllm_config = get_current_vllm_config() + self._chunked_prefill_workspace_size: int | None = None self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8( static=True, group_shape=GroupShape.PER_TENSOR, compile_native=True, ) + @property + def chunked_prefill_workspace_size(self) -> int: + if self._chunked_prefill_workspace_size is None: + self._chunked_prefill_workspace_size = ( + MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( + self._vllm_config + ) + ) + return self._chunked_prefill_workspace_size + def forward( self, q: torch.Tensor, diff --git a/vllm/model_executor/layers/attention/static_sink_attention.py b/vllm/model_executor/layers/attention/static_sink_attention.py index 49d83823b..fe8dc7e34 100644 --- a/vllm/model_executor/layers/attention/static_sink_attention.py +++ b/vllm/model_executor/layers/attention/static_sink_attention.py @@ -126,17 +126,13 @@ class StaticSinkAttention(Attention, CustomOp): if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype - block_size = cache_config.block_size else: kv_cache_dtype = "auto" - block_size = 16 if attn_backend is not None: underlying_attn_backend = attn_backend else: - underlying_attn_backend = get_attn_backend( - head_size, dtype, kv_cache_dtype, block_size - ) + underlying_attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype) attn_backend = create_static_sink_attention_backend( underlying_attn_backend, # type: ignore[arg-type] sink_len=sink_len, @@ -153,7 +149,6 @@ class StaticSinkAttention(Attention, CustomOp): CustomOp.__init__(self) self.sink_len = sink_len - self.block_size = block_size self.sink_populated = False self.sink_key = None self.sink_value = None @@ -212,12 +207,12 @@ class StaticSinkAttention(Attention, CustomOp): def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: # Block size may get updated after model loading, refresh it - block_size = vllm_config.cache_config.block_size + self.block_size = vllm_config.cache_config.block_size # Should not be called for enc-dec or encoder-only attention. assert self.attn_type == AttentionType.DECODER return SinkFullAttentionSpec( - block_size=block_size, + block_size=self.block_size, num_kv_heads=self.num_kv_heads, head_size=self.head_size, head_size_v=self.head_size_v, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 0e35bedbc..b76168281 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -217,10 +217,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token ) - # override attention block size if either (a) the - # user has not set it or (b) the user has set it - # too small. - if cache_config.block_size is None or cache_config.block_size < attn_block_size: + # override attention block size if it is too small, + # even if the user has explicitly set it + if cache_config.block_size < attn_block_size: cache_config.block_size = attn_block_size logger.info( "Setting attention block size to %d tokens " diff --git a/vllm/model_executor/models/whisper_causal.py b/vllm/model_executor/models/whisper_causal.py index 4bffd7d7b..6774ea11d 100644 --- a/vllm/model_executor/models/whisper_causal.py +++ b/vllm/model_executor/models/whisper_causal.py @@ -290,16 +290,13 @@ class WhisperCausalAttentionWithBlockPooling(Attention): if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype - block_size = cache_config.block_size else: kv_cache_dtype = "auto" - block_size = 16 underlying_attn_backend = get_attn_backend( head_size, dtype, kv_cache_dtype, - block_size, attn_type=attn_type, ) attn_backend = create_whisper_attention_backend_with_block_pooling( diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 421cf8797..a35cc0be4 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -185,7 +185,7 @@ class CpuPlatform(Platform): cache_config = vllm_config.cache_config - if cache_config.block_size is None: + if not cache_config.user_specified_block_size: cache_config.block_size = 128 if cache_config.block_size % 32 != 0: @@ -361,6 +361,12 @@ class CpuPlatform(Platform): vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS, ) + @classmethod + def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: + # TODO: CPU still sets block_size in check_and_update_config. + # Move that logic here so block_size is chosen by the backend. + pass + @classmethod def get_allowed_cpu_core_node_list(cls) -> tuple[list[int], list[LogicalCPUInfo]]: assert platform.system() == "Linux" diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 651cf86b1..2025c41ab 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -166,122 +166,12 @@ class CudaPlatformBase(Platform): @classmethod def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: - from vllm.v1.attention.backends.registry import AttentionBackendEnum - parallel_config = vllm_config.parallel_config model_config = vllm_config.model_config if parallel_config.worker_cls == "auto": parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" - cache_config = vllm_config.cache_config - if cache_config and cache_config.block_size is None: - cache_config.block_size = 16 - - # TODO(lucas): handle this more gracefully - # Note: model_config may be None during testing - # Note: block_size is initialized in - # HybridAttentionMambaModelConfig.verify_and_update_config - # for models with both attention and mamba, - # and doesn't need to be reinitialized here - if ( - model_config is not None - and model_config.use_mla - and cache_config.block_size is not None - ): - use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk") - # If `--attention-config.backend` is not set and we are using MLA, - # then we default to FlashMLA backend for non-blackwell GPUs, - # else we default to CutlassMLA. For each case, we force the - # required block_size. - use_flashmla = False - use_cutlass_mla = False - use_flashinfer_mla = False - use_flashmla_sparse = False - use_flashinfer_mla_sparse = False - - from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported - - if vllm_config.attention_config.backend is None: - # Default case - hf_text_config = model_config.hf_text_config - qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1) - if ( - cls.is_device_capability_family(100) - and not use_sparse - and qk_nope_head_dim == 128 - ): - # Blackwell => Force FlashInfer MLA (unless sparse, i.e. DSv3.2) - # and only if qk_nope_head_dim == 128 (kernel constraint) - use_flashinfer_mla = True - # Set the backend in AttentionConfig so it's used during - # backend selection - vllm_config.attention_config.backend = ( - AttentionBackendEnum.FLASHINFER_MLA - ) - elif cls.is_device_capability_family(100) and not use_sparse: - # Fall back to CUTLASS_MLA as 2nd priority on Blackwell - use_cutlass_mla = True - elif is_flashmla_dense_supported()[0]: - # Non-Blackwell with FlashMLA support - use_flashmla = True - else: - # Fallback: will use Triton MLA or other compatible backend - pass - else: - # Forced case - backend = vllm_config.attention_config.backend - use_flashmla = backend == AttentionBackendEnum.FLASHMLA - use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA - use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA - use_flashmla_sparse = backend == AttentionBackendEnum.FLASHMLA_SPARSE - use_flashinfer_mla_sparse = ( - backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE - ) - - if ( - use_flashmla - and is_flashmla_dense_supported()[0] - and cache_config.block_size % 64 != 0 - ): - cache_config.block_size = 64 - logger.info("Forcing kv cache block size to 64 for FlashMLA backend.") - - if use_cutlass_mla and cache_config.block_size % 128 != 0: - cache_config.block_size = 128 - logger.info( - "Forcing kv cache block size to 128 for CUTLASS_MLA backend." - ) - - if ( - use_flashinfer_mla - and cache_config.block_size != 32 - and cache_config.block_size % 64 != 0 - ): - cache_config.block_size = 64 - logger.info( - "Forcing kv cache block size to 64 for FlashInferMLA backend." - ) - - if use_sparse: - if not (use_flashmla_sparse or use_flashinfer_mla_sparse): - use_flashmla_sparse = True - - if use_flashmla_sparse and cache_config.block_size != 64: - cache_config.block_size = 64 - logger.info( - "Forcing kv cache block size to 64 for FlashMLASparse backend." - ) - elif use_flashinfer_mla_sparse and cache_config.block_size not in ( - 32, - 64, - ): - cache_config.block_size = 64 - logger.info( - "Forcing kv cache block size to 64 for FlashInferMLASparse " - "backend." - ) - scheduler_config = vllm_config.scheduler_config # Note: model_config may be None during testing if ( @@ -312,10 +202,10 @@ class CudaPlatformBase(Platform): num_heads: int | None = None, ) -> tuple[ list[tuple["AttentionBackendEnum", int]], - dict["AttentionBackendEnum", list[str]], + dict["AttentionBackendEnum", tuple[int, list[str]]], ]: valid_backends_priorities = [] - invalid_reasons = {} + invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {} backend_priorities = _get_backend_priorities( attn_selector_config.use_mla, @@ -332,7 +222,7 @@ class CudaPlatformBase(Platform): except ImportError: invalid_reasons_i = ["ImportError"] if invalid_reasons_i: - invalid_reasons[backend] = invalid_reasons_i + invalid_reasons[backend] = (priority, invalid_reasons_i) else: valid_backends_priorities.append((backend, priority)) @@ -341,14 +231,13 @@ class CudaPlatformBase(Platform): @classmethod def get_attn_backend_cls( cls, - selected_backend: "AttentionBackendEnum", + selected_backend: "AttentionBackendEnum | None", attn_selector_config: "AttentionSelectorConfig", num_heads: int | None = None, ) -> str: device_capability = cls.get_device_capability() assert device_capability is not None - attn_selector_config = attn_selector_config._replace(block_size=None) # First try checking just the selected backend, if there is one. if selected_backend is not None: try: @@ -370,7 +259,7 @@ class CudaPlatformBase(Platform): # No selected backend or the selected backend is invalid, # so we try finding a valid backend. - valid_backends_priorities, invalid_reasons = cls.get_valid_backends( + valid_backends_priorities, all_invalid_reasons = cls.get_valid_backends( device_capability=device_capability, attn_selector_config=attn_selector_config, num_heads=num_heads, @@ -379,7 +268,7 @@ class CudaPlatformBase(Platform): "{" + ", ".join( f"{backend.name}: [{', '.join(reasons)}]" - for backend, reasons in invalid_reasons.items() + for backend, (_, reasons) in all_invalid_reasons.items() ) + "}" ) @@ -402,6 +291,29 @@ class CudaPlatformBase(Platform): ) selected_index = sorted_indices[0] selected_backend = valid_backends_priorities[selected_index][0] + selected_priority = valid_backends_priorities[selected_index][1] + + # If the user specified --block-size (but not --attention-backend), + # check whether that constraint precluded any higher-priority backends. + if attn_selector_config.block_size is not None: + excluded = [ + backend + for backend, (priority, reasons) in all_invalid_reasons.items() + if priority < selected_priority + and reasons == ["block_size not supported"] + ] + if excluded: + names = ", ".join(b.name for b in excluded) + logger.warning( + "--block-size %d precluded higher-priority backend(s) " + "%s. Using %s instead, which may result in reduced " + "performance. Consider removing --block-size to " + "auto-select the optimal block size.", + attn_selector_config.block_size, + names, + selected_backend.name, + ) + logger.info_once( "Using %s attention backend out of potential backends: %s.", selected_backend.name, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 3b56001ed..774d9e071 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -420,6 +420,56 @@ class Platform: """ pass + @classmethod + def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: + """ + Ensure block_size is compatible with the attention backend. + """ + from vllm.config.cache import CacheConfig + + cache_config = vllm_config.cache_config + if cache_config.user_specified_block_size: + # User specified --block-size; keep it. + return + + model_config = vllm_config.model_config + # model_config may be None during testing. + # Skip hybrid models — their block_size is managed by + # HybridAttentionMambaModelConfig. + if model_config is None or model_config.is_hybrid: + cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE + return + + from vllm.config.vllm import ( + get_layers_from_vllm_config, + set_current_vllm_config, + ) + from vllm.model_executor.layers.attention_layer_base import ( + AttentionLayerBase, + ) + + attn_layers = get_layers_from_vllm_config( + vllm_config, + AttentionLayerBase, # type: ignore[type-abstract] + ) + if not attn_layers: + cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE + return + + first_layer = next(iter(attn_layers.values())) + backend_cls = first_layer.get_attn_backend() + with set_current_vllm_config(vllm_config): + preferred = backend_cls.get_preferred_block_size( + CacheConfig.DEFAULT_BLOCK_SIZE + ) + if preferred != CacheConfig.DEFAULT_BLOCK_SIZE: + logger.info( + "Setting kv cache block size to %d for %s backend.", + preferred, + backend_cls.get_name(), + ) + cache_config.block_size = preferred + @classmethod def verify_model_arch(cls, model_arch: str) -> None: """ diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index b4925d085..f1fd33318 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -687,7 +687,7 @@ class RocmPlatform(Platform): ) compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE - if cache_config and cache_config.block_size is None: + if cache_config and not cache_config.user_specified_block_size: if ( envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION and envs.VLLM_ROCM_USE_AITER # NOTE: This block has been deprecated @@ -707,6 +707,12 @@ class RocmPlatform(Platform): if parallel_config.worker_cls == "auto": parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" + @classmethod + def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: + # TODO: ROCm still sets block_size in check_and_update_config. + # Move that logic here so block_size is chosen by the backend. + pass + @classmethod def verify_model_arch(cls, model_arch: str) -> None: if model_arch in _ROCM_UNSUPPORTED_MODELS: diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index c06afcb69..893b5454f 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -162,7 +162,7 @@ class XPUPlatform(Platform): model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config # in V1(or with chunked prefill) block_size is 64 - if cache_config and cache_config.block_size is None: + if cache_config and not cache_config.user_specified_block_size: cache_config.block_size = 64 # lazy import to avoid circular import @@ -227,6 +227,12 @@ class XPUPlatform(Platform): # ref. https://openucx.readthedocs.io/en/master/faq.html os.environ["UCX_MEMTYPE_CACHE"] = "n" + @classmethod + def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: + # TODO: XPU still sets block_size in check_and_update_config. + # Move that logic here so block_size is chosen by the backend. + pass + @classmethod def support_hybrid_kv_cache(cls) -> bool: return True diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 3af817a2e..a5c145ee3 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, replace from enum import Enum -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar import numpy as np import torch @@ -144,15 +144,9 @@ class AttentionBackend(ABC): @classmethod def supports_block_size(cls, block_size: int | None) -> bool: - from vllm.config.cache import BlockSize - if block_size is None: return True - valid_sizes = get_args(BlockSize) - if block_size not in valid_sizes: - return False - supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes() if not supported_kernel_block_sizes: return True @@ -167,6 +161,17 @@ class AttentionBackend(ABC): return True return False + @classmethod + def get_preferred_block_size(cls, default_block_size: int) -> int: + supported_sizes = cls.get_supported_kernel_block_sizes() + if not supported_sizes: + return default_block_size + + if cls.supports_block_size(default_block_size): + return default_block_size + + return min(s.base if isinstance(s, MultipleOf) else s for s in supported_sizes) + @classmethod def is_mla(cls) -> bool: return False @@ -210,7 +215,7 @@ class AttentionBackend(ABC): head_size: int, dtype: torch.dtype, kv_cache_dtype: "CacheDType | None", - block_size: int, + block_size: int | None, use_mla: bool, has_sink: bool, use_sparse: bool, @@ -224,7 +229,7 @@ class AttentionBackend(ABC): head_size: int, dtype: torch.dtype, kv_cache_dtype: "CacheDType | None", - block_size: int, + block_size: int | None, use_mla: bool, has_sink: bool, use_sparse: bool, diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 33f896035..d2027f9a2 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -75,7 +75,7 @@ class FlashAttnMLABackend(MLACommonBackend): head_size: int, dtype: torch.dtype, kv_cache_dtype: CacheDType | None, - block_size: int, + block_size: int | None, use_mla: bool, has_sink: bool, use_sparse: bool, diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index 58d4bec7c..102d5706b 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -69,7 +69,7 @@ class FlashInferMLABackend(MLACommonBackend): head_size: int, dtype: torch.dtype, kv_cache_dtype: CacheDType | None, - block_size: int, + block_size: int | None, use_mla: bool, has_sink: bool, use_sparse: bool, diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py index 34683d3f6..4aa65e357 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py @@ -106,7 +106,7 @@ class FlashInferMLASparseBackend(AttentionBackend): head_size: int, dtype: torch.dtype, kv_cache_dtype: CacheDType | None, - block_size: int, + block_size: int | None, use_mla: bool, has_sink: bool, use_sparse: bool, diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 163b23b04..4720b2a03 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -80,7 +80,7 @@ class FlashMLABackend(MLACommonBackend): head_size: int, dtype: torch.dtype, kv_cache_dtype: CacheDType | None, - block_size: int, + block_size: int | None, use_mla: bool, has_sink: bool, use_sparse: bool, diff --git a/vllm/v1/attention/selector.py b/vllm/v1/attention/selector.py index 48a86655c..40cc10278 100644 --- a/vllm/v1/attention/selector.py +++ b/vllm/v1/attention/selector.py @@ -49,7 +49,6 @@ def get_attn_backend( head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, - block_size: int | None, use_mla: bool = False, has_sink: bool = False, use_sparse: bool = False, @@ -71,6 +70,12 @@ def get_attn_backend( vllm_config = get_current_vllm_config() + cache_config = vllm_config.cache_config + if cache_config is not None and cache_config.user_specified_block_size: + block_size = cache_config.block_size + else: + block_size = None + attn_selector_config = AttentionSelectorConfig( head_size=head_size, dtype=dtype, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 4bbaafed3..c68ac66ad 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -122,7 +122,11 @@ class EngineCore: num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches( vllm_config ) - + if kv_cache_config.kv_cache_groups: + vllm_config.cache_config.block_size = min( + g.kv_cache_spec.block_size for g in kv_cache_config.kv_cache_groups + ) + vllm_config.validate_block_size() vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index d2dfda9b8..95336034c 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -42,6 +42,7 @@ from vllm.distributed.parallel_state import ( ) from vllm.envs import enable_envs_cache from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.tracing import instrument, maybe_init_worker_tracer from vllm.utils.network_utils import ( get_distributed_init_method, @@ -617,6 +618,9 @@ class WorkerProc: ) self.worker.load_model() + # Set block size based on the attention backends + current_platform.update_block_size_for_backend(vllm_config) + # Initialize message queues after init_device() since multi-node setups # (nnodes_within_dp > 1) require distributed groups to be initialized self._init_message_queues(input_shm_handle, vllm_config) diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index 11a0a38df..2e35faae8 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -387,6 +387,11 @@ class RayDistributedExecutor(Executor): self.collective_rpc("init_device") self.collective_rpc("load_model") + def _update_block_size(worker): + current_platform.update_block_size_for_backend(worker.vllm_config) + + self.collective_rpc(_update_block_size) + for pp_rank in range(self.parallel_config.pipeline_parallel_size): self.pp_tp_workers.append([]) for tp_rank in range(self.parallel_config.tensor_parallel_size): diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index 3759c751c..a110596b7 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -12,6 +12,7 @@ import torch.distributed as dist import vllm.envs as envs from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.executor.abstract import Executor @@ -47,6 +48,7 @@ class UniProcExecutor(Executor): if not is_eep_new_worker: self.driver_worker.init_device() self.driver_worker.load_model() + current_platform.update_block_size_for_backend(self.vllm_config) def _distributed_args(self) -> tuple[str, int, int]: """Return (distributed_init_method, rank, local_rank).""" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 08dbd614f..1283bf490 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -32,6 +32,7 @@ from vllm.config import ( set_current_vllm_config, update_config, ) +from vllm.config.cache import CacheConfig from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group @@ -586,6 +587,11 @@ class GPUModelRunner( custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = ( tuple(logits_processors) if logits_processors is not None else () ) + placeholder_block_size = ( + self.cache_config.block_size or CacheConfig.DEFAULT_BLOCK_SIZE + ) + self._init_block_sizes = [placeholder_block_size] + self._init_kernel_block_sizes = [placeholder_block_size] self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, # We need to use the encoder length for encoder-decoder @@ -595,8 +601,8 @@ class GPUModelRunner( device=self.device, pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), - block_sizes=[self.cache_config.block_size], - kernel_block_sizes=[self.cache_config.block_size], + block_sizes=[placeholder_block_size], + kernel_block_sizes=[placeholder_block_size], is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=build_logitsprocs( self.vllm_config, @@ -6112,8 +6118,10 @@ class GPUModelRunner( ) -> None: """ Re-initialize the input batch if the block sizes are different from - `[self.cache_config.block_size]`. This usually happens when there - are multiple KV cache groups. + what it was originally created with. This happens when the final + block size (determined after model loading) differs from the + placeholder used during __init__, or when there are multiple + KV cache groups. Args: kv_cache_config: The KV cache configuration. @@ -6138,14 +6146,17 @@ class GPUModelRunner( ) + kv_cache_group.kv_cache_spec.num_speculative_blocks max_num_blocks.append(max_num_blocks_per_req) - if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [ - self.cache_config.block_size - ]: + if ( + block_sizes != self._init_block_sizes + or kernel_block_sizes != self._init_kernel_block_sizes + ): assert self.offload_config.uva.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 "for more details." ) + self._init_block_sizes = block_sizes + self._init_kernel_block_sizes = kernel_block_sizes self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, max_model_len=max_model_len, @@ -6162,6 +6173,15 @@ class GPUModelRunner( is_pooling_model=self.is_pooling_model, ) + assert self._init_block_sizes == block_sizes, ( + f"InputBatch block_sizes {self._init_block_sizes} != " + f"kv_cache block_sizes {block_sizes}" + ) + assert self._init_kernel_block_sizes == kernel_block_sizes, ( + f"InputBatch kernel_block_sizes {self._init_kernel_block_sizes} " + f"!= kv_cache kernel_block_sizes {kernel_block_sizes}" + ) + def _allocate_kv_cache_tensors( self, kv_cache_config: KVCacheConfig ) -> dict[str, torch.Tensor]: