diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index c81a8fe09..8f7993647 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -13,7 +13,6 @@ import torch.nn as nn from PIL import Image from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config -from vllm.config.cache import CacheConfig from vllm.config.multimodal import ( AudioDummyOptions, BaseDummyOptions, @@ -132,9 +131,7 @@ def initialize_dummy_model( ): temp_file = tempfile.mkstemp()[1] current_device = torch.get_default_device() - vllm_config = VllmConfig( - model_config=model_config, cache_config=CacheConfig(block_size=16) - ) + vllm_config = VllmConfig(model_config=model_config) with set_current_vllm_config(vllm_config=vllm_config): init_distributed_environment( world_size=1, diff --git a/tests/models/utils.py b/tests/models/utils.py index 8c1fb63d6..4830f18dc 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -457,9 +457,6 @@ def dummy_hf_overrides( # Kimi uses `num_expert_group` instead of `n_group`. if n_group is None: n_group = getattr(text_config, "num_expert_group", None) - # InternS1Pro uses `router_n_groups` instead of `n_group`. - if n_group is None: - n_group = getattr(text_config, "router_n_groups", None) num_experts = n_group * 2 if n_group is not None else 2 # we use three layers for Gemma-3n to check @@ -489,14 +486,12 @@ def dummy_hf_overrides( # Only set MoE related config when the model has MoE layers. # Otherwise all models detected as MoE by _get_transformers_backend_cls. if model_arch_config.num_experts > 0: - orig_topk = getattr(text_config, "num_experts_per_tok", 2) - topk = min(orig_topk, 2) update_dict.update( { "num_experts": num_experts, - "num_experts_per_tok": topk, + "num_experts_per_tok": 2, # Kimi uses `num_experts_per_token`. - "num_experts_per_token": topk, + "num_experts_per_token": 2, "num_local_experts": num_experts, # Otherwise there will not be any expert layers "first_k_dense_replace": 0, diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 65e97b7ad..8b180168d 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -78,7 +78,7 @@ def _create_proposer( device = current_platform.device_type vllm_config = VllmConfig( model_config=model_config, - cache_config=CacheConfig(block_size=16), + cache_config=CacheConfig(), speculative_config=speculative_config, device_config=DeviceConfig(device=device), parallel_config=ParallelConfig(), diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 313a4577b..daceaa6c2 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -19,6 +19,7 @@ else: logger = init_logger(__name__) +BlockSize = Literal[1, 8, 16, 32, 64, 128, 256] CacheDType = Literal[ "auto", "bfloat16", @@ -38,11 +39,13 @@ KVOffloadingBackend = Literal["native", "lmcache"] class CacheConfig: """Configuration for the KV cache.""" - block_size: SkipValidation[int] = None # type: ignore[assignment] - """Size of a contiguous cache block in number of tokens. + block_size: SkipValidation[BlockSize] = None # type: ignore[assignment] + """Size of a contiguous cache block in number of tokens. On CUDA devices, + only block sizes up to 32 are supported. - This is None until the platform sets it. Always an int by the time - the engine starts.""" + This config has no static default. If left unspecified by the user, it will + be set in `Platform.check_and_update_config()` based on the current + platform.""" gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1) """The fraction of GPU memory to be used for the model executor, which can range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index fffe769e7..e951e6f2c 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -915,6 +915,32 @@ class VllmConfig: ) current_platform.check_and_update_config(self) + # If DCP, ensure the block size is right. + if self.parallel_config.decode_context_parallel_size > 1: + if self.parallel_config.dcp_kv_cache_interleave_size > 1 and ( + self.parallel_config.cp_kv_cache_interleave_size + != self.parallel_config.dcp_kv_cache_interleave_size + ): + self.parallel_config.cp_kv_cache_interleave_size = ( + self.parallel_config.dcp_kv_cache_interleave_size + ) + logger.warning_once( + "cp_kv_cache_interleave_size is overridden by dcp_kv_cache" + "_interleave_size. And dcp-kv-cache-interleave-size will be " + "deprecated when PCP is fully supported." + ) + assert ( + self.parallel_config.cp_kv_cache_interleave_size + <= self.cache_config.block_size + and self.cache_config.block_size + % self.parallel_config.cp_kv_cache_interleave_size + == 0 + ), ( + f"Block_size({self.cache_config.block_size}) should be greater " + "than or equal to and divisible by cp_kv_cache_interleave_size " + f"({self.parallel_config.cp_kv_cache_interleave_size})." + ) + # Do this after all the updates to compilation_config.mode effective_dp_size = ( self.parallel_config.data_parallel_size @@ -1082,6 +1108,26 @@ class VllmConfig: # Default to enable HMA if not explicitly disabled by user or logic above. self.scheduler_config.disable_hybrid_kv_cache_manager = False + if self.cache_config.mamba_cache_mode == "align": + assert ( + self.cache_config.block_size + <= self.scheduler_config.max_num_batched_tokens + ), ( + "In Mamba cache align mode, block_size " + f"({self.cache_config.block_size}) must be <= " + "max_num_batched_tokens " + f"({self.scheduler_config.max_num_batched_tokens})." + ) + if self.scheduler_config.long_prefill_token_threshold > 0: + assert ( + self.scheduler_config.long_prefill_token_threshold + >= self.cache_config.block_size + ) + assert not self.scheduler_config.disable_chunked_mm_input, ( + "Chunked MM input is required because we need the flexibility to " + "schedule a multiple of block_size tokens even if they are in the " + "middle of a mm input" + ) if self.compilation_config.debug_dump_path: self.compilation_config.debug_dump_path = ( self.compilation_config.debug_dump_path.absolute().expanduser() @@ -1442,57 +1488,6 @@ class VllmConfig: f"compilation_config={self.compilation_config!r}" ) - def validate_block_size(self) -> None: - """Validate block_size against DCP and mamba constraints. - - Called after Platform.update_block_size_for_backend() has - finalised block_size, so that the checks see the real value - rather than the initial None sentinel. - """ - block_size = self.cache_config.block_size - assert block_size is not None, ( - "validate_block_size called before block_size was set" - ) - - # DCP interleave-size compatibility - if self.parallel_config.decode_context_parallel_size > 1: - if self.parallel_config.dcp_kv_cache_interleave_size > 1 and ( - self.parallel_config.cp_kv_cache_interleave_size - != self.parallel_config.dcp_kv_cache_interleave_size - ): - self.parallel_config.cp_kv_cache_interleave_size = ( - self.parallel_config.dcp_kv_cache_interleave_size - ) - logger.warning_once( - "cp_kv_cache_interleave_size is overridden by dcp_kv_cache" - "_interleave_size. And dcp-kv-cache-interleave-size will be " - "deprecated when PCP is fully supported." - ) - assert ( - self.parallel_config.cp_kv_cache_interleave_size <= block_size - and block_size % self.parallel_config.cp_kv_cache_interleave_size == 0 - ), ( - f"Block_size({block_size}) should be greater " - "than or equal to and divisible by cp_kv_cache_interleave_size " - f"({self.parallel_config.cp_kv_cache_interleave_size})." - ) - - # Mamba cache align-mode constraints - if self.cache_config.mamba_cache_mode == "align": - assert block_size <= self.scheduler_config.max_num_batched_tokens, ( - "In Mamba cache align mode, block_size " - f"({block_size}) must be <= " - "max_num_batched_tokens " - f"({self.scheduler_config.max_num_batched_tokens})." - ) - if self.scheduler_config.long_prefill_token_threshold > 0: - assert self.scheduler_config.long_prefill_token_threshold >= block_size - assert not self.scheduler_config.disable_chunked_mm_input, ( - "Chunked MM input is required because we need the flexibility " - "to schedule a multiple of block_size tokens even if they are " - "in the middle of a mm input" - ) - @model_validator(mode="after") def validate_mamba_block_size(self) -> "VllmConfig": if self.model_config is None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 1d9a924bd..8ea96de49 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -59,6 +59,7 @@ from vllm.config import ( get_attr_docs, ) from vllm.config.cache import ( + BlockSize, CacheDType, KVOffloadingBackend, MambaCacheMode, @@ -430,7 +431,7 @@ class EngineArgs: max_parallel_loading_workers: int | None = ( ParallelConfig.max_parallel_loading_workers ) - block_size: int = None # type: ignore[assignment] + block_size: BlockSize = CacheConfig.block_size enable_prefix_caching: bool | None = None prefix_caching_hash_algo: PrefixCachingHashAlgo = ( CacheConfig.prefix_caching_hash_algo diff --git a/vllm/model_executor/layers/attention/chunked_local_attention.py b/vllm/model_executor/layers/attention/chunked_local_attention.py index 522981820..e33733c0c 100644 --- a/vllm/model_executor/layers/attention/chunked_local_attention.py +++ b/vllm/model_executor/layers/attention/chunked_local_attention.py @@ -30,8 +30,9 @@ from vllm.v1.kv_cache_interface import ( def create_chunked_local_attention_backend( underlying_attn_backend: AttentionBackend, attention_chunk_size: int, + block_size: int, ) -> type[AttentionBackend]: - prefix = f"ChunkedLocalAttention_{attention_chunk_size}_" + prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_" underlying_builder = underlying_attn_backend.get_builder_cls() assert issubclass(underlying_builder, AttentionMetadataBuilder) @@ -54,9 +55,7 @@ def create_chunked_local_attention_backend( fast_build: bool = False, ): cm, make_virtual_batches_block_table = make_local_attention_virtual_batches( - attention_chunk_size, - common_attn_metadata, - self.kv_cache_spec.block_size, + attention_chunk_size, common_attn_metadata, block_size ) metadata = super().build(common_prefix_len, cm, fast_build) metadata.make_virtual_batches_block_table = make_virtual_batches_block_table @@ -98,13 +97,13 @@ class ChunkedLocalAttention(Attention): block_size = cache_config.block_size else: kv_cache_dtype = "auto" - block_size = None + block_size = 16 underlying_attn_backend = get_attn_backend( head_size, dtype, kv_cache_dtype, block_size ) attn_backend = create_chunked_local_attention_backend( - underlying_attn_backend, attention_chunk_size + underlying_attn_backend, attention_chunk_size, block_size ) super().__init__( diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 4fe25b027..98ff02e9d 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -407,24 +407,17 @@ class MLAAttention(nn.Module, AttentionLayerBase): ) # Attributes for forward_impl method - self._vllm_config = get_current_vllm_config() - self._chunked_prefill_workspace_size: int | None = None + self.chunked_prefill_workspace_size = ( + MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( + get_current_vllm_config() + ) + ) self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8( static=True, group_shape=GroupShape.PER_TENSOR, compile_native=True, ) - @property - def chunked_prefill_workspace_size(self) -> int: - if self._chunked_prefill_workspace_size is None: - self._chunked_prefill_workspace_size = ( - MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( - self._vllm_config - ) - ) - return self._chunked_prefill_workspace_size - def forward( self, q: torch.Tensor, diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 921054f73..c2fcde4ab 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -163,12 +163,122 @@ class CudaPlatformBase(Platform): @classmethod def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: + from vllm.v1.attention.backends.registry import AttentionBackendEnum + parallel_config = vllm_config.parallel_config model_config = vllm_config.model_config if parallel_config.worker_cls == "auto": parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" + cache_config = vllm_config.cache_config + if cache_config and cache_config.block_size is None: + cache_config.block_size = 16 + + # TODO(lucas): handle this more gracefully + # Note: model_config may be None during testing + # Note: block_size is initialized in + # HybridAttentionMambaModelConfig.verify_and_update_config + # for models with both attention and mamba, + # and doesn't need to be reinitialized here + if ( + model_config is not None + and model_config.use_mla + and cache_config.block_size is not None + ): + use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk") + # If `--attention-config.backend` is not set and we are using MLA, + # then we default to FlashMLA backend for non-blackwell GPUs, + # else we default to CutlassMLA. For each case, we force the + # required block_size. + use_flashmla = False + use_cutlass_mla = False + use_flashinfer_mla = False + use_flashmla_sparse = False + use_flashinfer_mla_sparse = False + + from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported + + if vllm_config.attention_config.backend is None: + # Default case + hf_text_config = model_config.hf_text_config + qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1) + if ( + cls.is_device_capability_family(100) + and not use_sparse + and qk_nope_head_dim == 128 + ): + # Blackwell => Force FlashInfer MLA (unless sparse, i.e. DSv3.2) + # and only if qk_nope_head_dim == 128 (kernel constraint) + use_flashinfer_mla = True + # Set the backend in AttentionConfig so it's used during + # backend selection + vllm_config.attention_config.backend = ( + AttentionBackendEnum.FLASHINFER_MLA + ) + elif cls.is_device_capability_family(100) and not use_sparse: + # Fall back to CUTLASS_MLA as 2nd priority on Blackwell + use_cutlass_mla = True + elif is_flashmla_dense_supported()[0]: + # Non-Blackwell with FlashMLA support + use_flashmla = True + else: + # Fallback: will use Triton MLA or other compatible backend + pass + else: + # Forced case + backend = vllm_config.attention_config.backend + use_flashmla = backend == AttentionBackendEnum.FLASHMLA + use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA + use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA + use_flashmla_sparse = backend == AttentionBackendEnum.FLASHMLA_SPARSE + use_flashinfer_mla_sparse = ( + backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE + ) + + if ( + use_flashmla + and is_flashmla_dense_supported()[0] + and cache_config.block_size % 64 != 0 + ): + cache_config.block_size = 64 + logger.info("Forcing kv cache block size to 64 for FlashMLA backend.") + + if use_cutlass_mla and cache_config.block_size % 128 != 0: + cache_config.block_size = 128 + logger.info( + "Forcing kv cache block size to 128 for CUTLASS_MLA backend." + ) + + if ( + use_flashinfer_mla + and cache_config.block_size != 32 + and cache_config.block_size % 64 != 0 + ): + cache_config.block_size = 64 + logger.info( + "Forcing kv cache block size to 64 for FlashInferMLA backend." + ) + + if use_sparse: + if not (use_flashmla_sparse or use_flashinfer_mla_sparse): + use_flashmla_sparse = True + + if use_flashmla_sparse and cache_config.block_size != 64: + cache_config.block_size = 64 + logger.info( + "Forcing kv cache block size to 64 for FlashMLASparse backend." + ) + elif use_flashinfer_mla_sparse and cache_config.block_size not in ( + 32, + 64, + ): + cache_config.block_size = 64 + logger.info( + "Forcing kv cache block size to 64 for FlashInferMLASparse " + "backend." + ) + scheduler_config = vllm_config.scheduler_config # Note: model_config may be None during testing if ( @@ -183,49 +293,6 @@ class CudaPlatformBase(Platform): ) scheduler_config.disable_chunked_mm_input = True - @classmethod - def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: - cache_config = vllm_config.cache_config - if cache_config.block_size is not None: - # User specified --block-size; keep it. - return - - model_config = vllm_config.model_config - # model_config may be None during testing. - # Skip hybrid models — their block_size is managed by - # HybridAttentionMambaModelConfig. - if model_config is None or model_config.is_hybrid: - cache_config.block_size = 16 - return - - from vllm.config.vllm import ( - get_layers_from_vllm_config, - set_current_vllm_config, - ) - from vllm.model_executor.layers.attention_layer_base import ( - AttentionLayerBase, - ) - - attn_layers = get_layers_from_vllm_config( - vllm_config, - AttentionLayerBase, - ) - if not attn_layers: - cache_config.block_size = 16 - return - - first_layer = next(iter(attn_layers.values())) - backend_cls = first_layer.get_attn_backend() - with set_current_vllm_config(vllm_config): - preferred = backend_cls.get_preferred_block_size(16) - if preferred != 16: - logger.info( - "Setting kv cache block size to %d for %s backend.", - preferred, - backend_cls.get_name(), - ) - cache_config.block_size = preferred - @classmethod def get_current_memory_usage( cls, device: torch.types.Device | None = None @@ -242,10 +309,10 @@ class CudaPlatformBase(Platform): num_heads: int | None = None, ) -> tuple[ list[tuple["AttentionBackendEnum", int]], - dict["AttentionBackendEnum", tuple[int, list[str]]], + dict["AttentionBackendEnum", list[str]], ]: valid_backends_priorities = [] - invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {} + invalid_reasons = {} backend_priorities = _get_backend_priorities( attn_selector_config.use_mla, @@ -262,155 +329,84 @@ class CudaPlatformBase(Platform): except ImportError: invalid_reasons_i = ["ImportError"] if invalid_reasons_i: - invalid_reasons[backend] = (priority, invalid_reasons_i) + invalid_reasons[backend] = invalid_reasons_i else: valid_backends_priorities.append((backend, priority)) return valid_backends_priorities, invalid_reasons - @classmethod - def select_attention_backend( - cls, - selected_backend: "AttentionBackendEnum | None", - attn_selector_config: "AttentionSelectorConfig", - device_capability: "DeviceCapability", - raise_on_invalid: bool = True, - num_heads: int | None = None, - ) -> "AttentionBackendEnum | None": - """Select the best attention backend for the given configuration. - - Args: - selected_backend: User-specified backend, or None for auto-selection - attn_selector_config: Configuration for attention selection - device_capability: Device capability info - raise_on_invalid: If True, raise ValueError when no valid backend - num_heads: Number of attention heads per GPU, used for backend - priority ordering on Blackwell GPUs - - Returns: - The selected backend enum, or None if no valid backend found - and raise_on_invalid is False - """ - # First try checking just the selected backend, if there is one. - if selected_backend is not None: - try: - backend_class = selected_backend.get_class() - validation_errors = backend_class.validate_configuration( - device_capability=device_capability, - **attn_selector_config._asdict(), - ) - except ImportError: - validation_errors = ["ImportError"] - if validation_errors: - if raise_on_invalid: - raise ValueError( - f"Selected backend {selected_backend} is not valid for " - f"this configuration. Reason: {validation_errors}" - ) - return None - return selected_backend - - # No selected backend, so find the best valid one. - valid_backends_priorities, invalid_reasons = cls.get_valid_backends( - device_capability=device_capability, - attn_selector_config=attn_selector_config, - num_heads=num_heads, - ) - - if len(valid_backends_priorities) == 0: - if raise_on_invalid: - reasons_str = ( - "{" - + ", ".join( - f"{backend.name}: [{', '.join(reasons)}]" - for backend, (_, reasons) in invalid_reasons.items() - ) - + "}" - ) - config_str = attn_selector_config.__repr__() - raise ValueError( - f"No valid attention backend found for {cls.device_name} " - f"with {config_str}. Reasons: {reasons_str}." - ) - return None - - # Select the one with the highest priority (lowest index). - sorted_backends = sorted(valid_backends_priorities, key=lambda x: x[1]) - chosen_backend, chosen_priority = sorted_backends[0] - - # If the user specified --block-size (but not --attention-backend), - # check whether that constraint precluded any higher-priority backends. - if attn_selector_config.block_size is not None: - excluded = [ - backend - for backend, (priority, reasons) in invalid_reasons.items() - if priority < chosen_priority - and reasons == ["block_size not supported"] - ] - if excluded: - names = ", ".join(b.name for b in excluded) - logger.warning( - "--block-size %d excluded higher-priority backend(s) " - "%s. Using %s instead, which may result in reduced " - "performance. Consider removing --block-size to " - "auto-select the optimal block size.", - attn_selector_config.block_size, - names, - chosen_backend.name, - ) - - return chosen_backend - @classmethod def get_attn_backend_cls( cls, - selected_backend: "AttentionBackendEnum | None", + selected_backend: "AttentionBackendEnum", attn_selector_config: "AttentionSelectorConfig", num_heads: int | None = None, ) -> str: device_capability = cls.get_device_capability() assert device_capability is not None - chosen_backend = cls.select_attention_backend( - selected_backend=selected_backend, + attn_selector_config = attn_selector_config._replace(block_size=None) + # First try checking just the selected backend, if there is one. + if selected_backend is not None: + try: + backend_class = selected_backend.get_class() + invalid_reasons = backend_class.validate_configuration( + device_capability=device_capability, + **attn_selector_config._asdict(), + ) + except ImportError: + invalid_reasons = ["ImportError"] + if invalid_reasons: + raise ValueError( + f"Selected backend {selected_backend} is not valid for " + f"this configuration. Reason: {invalid_reasons}" + ) + else: + logger.info("Using %s backend.", selected_backend) + return selected_backend.get_path() + + # No selected backend or the selected backend is invalid, + # so we try finding a valid backend. + valid_backends_priorities, invalid_reasons = cls.get_valid_backends( + device_capability=device_capability, attn_selector_config=attn_selector_config, num_heads=num_heads, - device_capability=device_capability, - raise_on_invalid=True, ) - assert chosen_backend is not None # raise_on_invalid=True guarantees this - - # Log the selection - if selected_backend is not None: - logger.info("Using %s backend.", chosen_backend) - else: - # Get all valid backends for logging - valid_backends_priorities, invalid_reasons = cls.get_valid_backends( - device_capability=device_capability, - attn_selector_config=attn_selector_config, - num_heads=num_heads, + reasons_str = ( + "{" + + ", ".join( + f"{backend.name}: [{', '.join(reasons)}]" + for backend, reasons in invalid_reasons.items() ) - reasons_str = ( - "{" - + ", ".join( - f"{backend.name}: [{', '.join(reasons)}]" - for backend, (_, reasons) in invalid_reasons.items() - ) - + "}" - ) - config_str = attn_selector_config.__repr__() - logger.debug_once( - f"Some attention backends are not valid for {cls.device_name} with " - f"{config_str}. Reasons: {reasons_str}." - ) - logger.info_once( - "Using %s attention backend out of potential backends: %s", - chosen_backend.name, - tuple(backend.name for backend, _ in valid_backends_priorities), - scope="local", + + "}" + ) + config_str = attn_selector_config.__repr__() + logger.debug_once( + f"Some attention backends are not valid for {cls.device_name} with " + f"{config_str}. Reasons: {reasons_str}." + ) + if len(valid_backends_priorities) == 0: + raise ValueError( + f"No valid attention backend found for {cls.device_name} " + f"with {config_str}. Reasons: {reasons_str}." ) - return chosen_backend.get_path() + # We have found some valid backends. Select the one with the + # highest priority. + sorted_indices = sorted( + range(len(valid_backends_priorities)), + key=lambda i: valid_backends_priorities[i][1], + ) + selected_index = sorted_indices[0] + selected_backend = valid_backends_priorities[selected_index][0] + logger.info_once( + "Using %s attention backend out of potential backends: %s.", + selected_backend.name, + "[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]", + scope="local", + ) + + return selected_backend.get_path() @classmethod def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index ba44fa6d9..6794c05f5 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -406,13 +406,6 @@ class Platform: """ pass - @classmethod - def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: - """ - Ensure block_size is compatible with the attention backend. - """ - pass - @classmethod def verify_model_arch(cls, model_arch: str) -> None: """ diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index f31e2635a..9c004d772 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, replace from enum import Enum -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args import numpy as np import torch @@ -144,9 +144,15 @@ class AttentionBackend(ABC): @classmethod def supports_block_size(cls, block_size: int | None) -> bool: + from vllm.config.cache import BlockSize + if block_size is None: return True + valid_sizes = get_args(BlockSize) + if block_size not in valid_sizes: + return False + supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes() if not supported_kernel_block_sizes: return True @@ -161,17 +167,6 @@ class AttentionBackend(ABC): return True return False - @classmethod - def get_preferred_block_size(cls, default_block_size: int = 16) -> int: - supported_sizes = cls.get_supported_kernel_block_sizes() - if not supported_sizes: - return default_block_size - - if cls.supports_block_size(default_block_size): - return default_block_size - - return min(s.base if isinstance(s, MultipleOf) else s for s in supported_sizes) - @classmethod def is_mla(cls) -> bool: return False diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index b805abe8a..a258fe295 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -114,14 +114,7 @@ class EngineCore: num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches( vllm_config ) - if kv_cache_config.kv_cache_groups: - vllm_config.cache_config.block_size = min( - g.kv_cache_spec.block_size for g in kv_cache_config.kv_cache_groups - ) - elif vllm_config.cache_config.block_size is None: - # Attention-free models (encoder-only, SSM) — use default. - vllm_config.cache_config.block_size = 16 - vllm_config.validate_block_size() + vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 9cc7dc63a..b63cbd658 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -41,7 +41,6 @@ from vllm.distributed.parallel_state import ( ) from vllm.envs import enable_envs_cache from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.tracing import instrument, maybe_init_worker_tracer from vllm.utils.network_utils import ( get_distributed_init_method, @@ -580,9 +579,6 @@ class WorkerProc: self._init_message_queues(input_shm_handle, vllm_config) self.worker.load_model() - # Set block size based on the attention backends - current_platform.update_block_size_for_backend(vllm_config) - # Enable environment variable cache (e.g. assume no more # environment variable overrides after this point) enable_envs_cache() diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index 6c939a593..ad51526ae 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -385,11 +385,6 @@ class RayDistributedExecutor(Executor): self.collective_rpc("init_device") self.collective_rpc("load_model") - def _update_block_size(worker): - current_platform.update_block_size_for_backend(worker.vllm_config) - - self.collective_rpc(_update_block_size) - for pp_rank in range(self.parallel_config.pipeline_parallel_size): self.pp_tp_workers.append([]) for tp_rank in range(self.parallel_config.tensor_parallel_size): diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index 290c4dc8b..b9c7b5501 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -12,7 +12,6 @@ import torch.distributed as dist import vllm.envs as envs from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType @@ -47,7 +46,6 @@ class UniProcExecutor(Executor): self.driver_worker.init_worker(all_kwargs=[kwargs]) self.driver_worker.init_device() self.driver_worker.load_model() - current_platform.update_block_size_for_backend(self.vllm_config) def _distributed_args(self) -> tuple[str, int, int]: """Return (distributed_init_method, rank, local_rank).""" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 51c4f5805..9ef8584c7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -513,7 +513,6 @@ class GPUModelRunner( custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = ( tuple(logits_processors) if logits_processors is not None else () ) - placeholder_block_size = self.cache_config.block_size or 16 self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, # We need to use the encoder length for encoder-decoer @@ -523,8 +522,8 @@ class GPUModelRunner( device=self.device, pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), - block_sizes=[placeholder_block_size], - kernel_block_sizes=[placeholder_block_size], + block_sizes=[self.cache_config.block_size], + kernel_block_sizes=[self.cache_config.block_size], is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=build_logitsprocs( self.vllm_config,