diff --git a/vllm/config/cache.py b/vllm/config/cache.py index cf2977622..1734f6b15 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -5,7 +5,7 @@ import hashlib from dataclasses import field from typing import TYPE_CHECKING, Any, Literal -from pydantic import Field, SkipValidation, field_validator +from pydantic import Field, SkipValidation, field_validator, model_validator from pydantic.dataclasses import dataclass from vllm.config.utils import config @@ -90,8 +90,10 @@ class CacheConfig: mamba_page_size_padded: int | None = None """ Optional override for mamba page size; used by hybrid mamba/attention models to ensure exact alignment with attention page size.""" - mamba_block_size: int | None = None - """Size of a contiguous cache block in number of tokens for mamba cache.""" + mamba_block_size: int | None = Field(default=None, gt=0) + """Size of a contiguous cache block in number of tokens for mamba cache. + Can be set only when prefix caching is enabled. + Value must be a multiple of 8 to align with causal_conv1d kernel.""" mamba_cache_dtype: MambaDType = "auto" """The data type to use for the Mamba cache (both the conv as well as the ssm state). If set to 'auto', the data type will be inferred from the model @@ -183,3 +185,11 @@ class CacheConfig: raise ValueError("Too large swap space. " + msg) elif cpu_memory_usage > 0.4 * total_cpu_memory: logger.warning("Possibly too large swap space. %s", msg) + + @model_validator(mode="after") + def validate_mamba_block_size(self) -> "CacheConfig": + if self.mamba_block_size is not None and not self.enable_prefix_caching: + raise ValueError( + "--mamba-block-size can only be set with --enable-prefix-caching" + ) + return self diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 24f9d18dc..ede470b08 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -535,6 +535,7 @@ class EngineArgs: calculate_kv_scales: bool = CacheConfig.calculate_kv_scales mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype + mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size") additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config") @@ -893,6 +894,9 @@ class EngineArgs: cache_group.add_argument( "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"] ) + cache_group.add_argument( + "--mamba-block-size", **cache_kwargs["mamba_block_size"] + ) # Multimodal related configs multimodal_kwargs = get_kwargs(MultiModalConfig) @@ -1390,6 +1394,7 @@ class EngineArgs: kv_sharing_fast_prefill=self.kv_sharing_fast_prefill, mamba_cache_dtype=self.mamba_cache_dtype, mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype, + mamba_block_size=self.mamba_block_size, ) ray_runtime_env = None diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 493b74bdd..ac5949cda 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -291,9 +291,8 @@ class MambaModelConfig(VerifyAndUpdateConfig): model_config = vllm_config.model_config cache_config = vllm_config.cache_config - # Set mamba block size to max_model_len (this may get - # override by prefix caching logic later) - cache_config.mamba_block_size = model_config.max_model_len + if cache_config.mamba_block_size is None: + cache_config.mamba_block_size = model_config.max_model_len if cache_config.enable_prefix_caching: if model_config.supports_mamba_prefix_caching: @@ -333,6 +332,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): if not envs.VLLM_USE_V1: return + # Save the user input before it gets modified by MambaModelConfig + mamba_block_size = vllm_config.cache_config.mamba_block_size # Enable FULL_AND_PIECEWISE by default MambaModelConfig.verify_and_update_config(vllm_config) @@ -386,7 +387,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): # With prefix caching, select attention block size to # optimize for mamba kernel performance - # mamba SSD kernel uses a chunk_size, e.g. 256 + # Mamba2 SSD kernel uses a chunk_size, e.g. 256 # Align the block to the kernel: use lowest multiple of chunk_size # of attention tokens that would fit mamba_page_size: # e.g. for mamba page size = 788kB @@ -404,7 +405,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): def lcm(a, b): return a * b // gcd(a, b) - base_chunk_size = model_config.get_mamba_chunk_size() + base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size() + attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token) chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)