[HMA] Fix corner case where hybrid page_size cannot be evenly divided (blk_size=64, tp=4) (#37467)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
Chendi.Xue
2026-03-30 11:47:30 -05:00
committed by GitHub
parent b4a2f3ac36
commit 3b1dbaad4e
12 changed files with 220 additions and 186 deletions

View File

@@ -851,6 +851,7 @@ def test_hybrid_attention_mamba_tensor_shapes():
vllm_ctx = vllm_config.compilation_config.static_forward_context
runner = GPUModelRunner(vllm_config, DEVICE)
current_platform.update_block_size_for_backend(vllm_config)
kv_cache_spec = runner.get_kv_cache_spec()
available_memory = 5 * GiB_bytes
@@ -1306,6 +1307,7 @@ def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks():
assert fwd_context is not None
runner = GPUModelRunner(vllm_config, DEVICE)
current_platform.update_block_size_for_backend(vllm_config)
kv_cache_spec = runner.get_kv_cache_spec()
available_memory = 5 * GiB_bytes

View File

@@ -38,6 +38,8 @@ class CacheConfig:
Accepts None (meaning "use default"). After construction, always int."""
user_specified_block_size: bool = field(default=False, init=False)
"""Whether block_size was explicitly provided. Derived automatically."""
user_specified_mamba_block_size: bool = field(default=False, init=False)
"""Whether mamba_block_size was explicitly provided. Derived automatically."""
gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1)
"""The fraction of GPU memory to be used for the model executor, which can
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
@@ -182,6 +184,7 @@ class CacheConfig:
"cpu_kvcache_space_bytes",
"mamba_page_size_padded",
"user_specified_block_size",
"user_specified_mamba_block_size",
"_block_size_resolved",
# Post-init/derived counters
"num_gpu_blocks",
@@ -214,6 +217,8 @@ class CacheConfig:
object.__setattr__(self, "block_size", self.DEFAULT_BLOCK_SIZE)
else:
object.__setattr__(self, "user_specified_block_size", True)
if self.mamba_block_size is not None:
object.__setattr__(self, "user_specified_mamba_block_size", True)
return self
@field_validator("calculate_kv_scales", mode="after")

View File

@@ -1,15 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from math import lcm
from typing import TYPE_CHECKING
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
from vllm.utils.math_utils import round_up
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
@@ -104,11 +99,11 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
@classmethod
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
"""
Ensure that page size of attention layers is greater than or
equal to the mamba layers. If not, automatically set the attention
block size to ensure that it is. If the attention page size is
strictly greater than the mamba page size, we pad the mamba page size
to make them equal.
Perform early validation and setup for hybrid attention/mamba models.
Block size alignment with mamba page sizes is handled later by
Platform.update_block_size_for_backend(), which runs after model
layers are constructed and the attention backend is known.
Args:
vllm_config: vLLM Config
@@ -118,6 +113,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
# Disable calculate_kv_scales for hybrid models: uninitialized
# recurrent state corrupts scales during the calibration pass.
# See issue: https://github.com/vllm-project/vllm/issues/37554
if cache_config.calculate_kv_scales:
logger.warning(
"Disabling calculate_kv_scales for hybrid model '%s'. "
@@ -129,140 +125,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
)
cache_config.calculate_kv_scales = False
# Save the user input before it gets modified by MambaModelConfig
mamba_block_size = cache_config.mamba_block_size
# Enable FULL_AND_PIECEWISE by default
MambaModelConfig.verify_and_update_config(vllm_config)
attention_config = vllm_config.attention_config
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
if cache_config.cache_dtype == "auto":
kv_cache_dtype = model_config.dtype
else:
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# get attention page size (for 1 token)
# Attention backend constraints:
# - FlashAttention (FA) requires block size to be multiple of 16
# - MLA (Multi-head Latent Attention) requires larger alignment:
# * CUTLASS_MLA backend: kernel_block_size 128 alignment
# * Other MLA backends: kernel_block_size 64 alignment
if model_config.use_mla:
use_cutlass_mla = (
attention_config.backend == AttentionBackendEnum.CUTLASS_MLA
)
kernel_block_alignment_size = 128 if use_cutlass_mla else 64
attn_page_size_1_token = MLAAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
else:
kernel_block_alignment_size = 16
attn_page_size_1_token = FullAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
model_cls, _ = ModelRegistry.resolve_model_cls(
model_config.architecture,
model_config=model_config,
)
# get mamba page size
mamba_page_size = MambaSpec(
shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
block_size=-1, # block_size doesn't matter for mamba page size
).page_size_bytes
# Model may be marked as is_hybrid
# but mamba is skipped via config,
# return directly
if mamba_page_size == 0:
return
if cache_config.mamba_cache_mode == "all":
# With prefix caching, select attention block size to
# optimize for mamba kernel performance
# Mamba2 SSD kernel uses a chunk_size, e.g. 256
# Align the block to the kernel: use lowest multiple of chunk_size
# of attention tokens that would fit mamba_page_size:
# e.g. for mamba page size = 788kB
# attn_1_token = 2kB -> fits ~394 tokens
# then round up to a multiple of 256 -> 512 tokens
# End result:
# attn_block_size = 512
# mamba_block_size = 512 (aligned to a multiple of chunk_size)
# TODO(tdoublep): this constraint can be relaxed fairly
# easily by changing the way we layout chunks in the
# mamba2 kernels.
base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
cache_config.mamba_block_size = attn_block_size
else:
# Without prefix caching, select minimum valid attention block size
# to minimize mamba state padding
# Calculate minimum attention block size that satisfies both:
# 1. Backend alignment requirements (kernel_block_alignment_size)
# 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
attn_block_size = kernel_block_alignment_size * cdiv(
mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
)
# override attention block size if it is too small,
# even if the user has explicitly set it
if cache_config.block_size < attn_block_size:
cache_config.block_size = attn_block_size
logger.info(
"Setting attention block size to %d tokens "
"to ensure that attention page size is >= mamba page size.",
attn_block_size,
)
# By default, mamba block size will be set to max_model_len.
# When enabling prefix caching and using align mamba cache
# mode, we align mamba block size to the block size as the
# basic granularity for prefix caching.
if cache_config.mamba_cache_mode == "align":
cache_config.mamba_block_size = cache_config.block_size
# compute new attention page size
attn_page_size = cache_config.block_size * attn_page_size_1_token
assert attn_page_size >= mamba_page_size
if attn_page_size == mamba_page_size:
# don't need to pad mamba page size
return
# pad mamba page size to exactly match attention
if (
cache_config.mamba_page_size_padded is None
or cache_config.mamba_page_size_padded != attn_page_size
):
cache_config.mamba_page_size_padded = attn_page_size
mamba_padding_pct = (
100 * (attn_page_size - mamba_page_size) / mamba_page_size
)
logger.info(
"Padding mamba page size by %.2f%% to ensure "
"that mamba page size and attention page size are "
"exactly equal.",
mamba_padding_pct,
)
class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
@staticmethod

View File

@@ -21,6 +21,7 @@ if TYPE_CHECKING:
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.selector import AttentionSelectorConfig
else:
FlexibleArgumentParser = object
@@ -424,29 +425,11 @@ class Platform:
pass
@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
"""
Ensure block_size is compatible with the attention backend.
"""
from vllm.config.cache import CacheConfig
cache_config = vllm_config.cache_config
if cache_config.user_specified_block_size:
# User specified --block-size; keep it.
return
model_config = vllm_config.model_config
# model_config may be None during testing.
# Skip hybrid models — their block_size is managed by
# HybridAttentionMambaModelConfig.
if model_config is None or model_config.is_hybrid:
cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE
return
from vllm.config.vllm import (
get_layers_from_vllm_config,
set_current_vllm_config,
)
def _find_non_ssm_backend(
cls, vllm_config: "VllmConfig"
) -> "type[AttentionBackend] | None":
"""Find the first non-SSM attention backend from model layers."""
from vllm.config.vllm import get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import (
AttentionLayerBase,
)
@@ -455,23 +438,181 @@ class Platform:
vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
)
if not attn_layers:
cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE
for layer in attn_layers.values():
b = layer.get_attn_backend()
if not b.is_ssm():
return b
return None
@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
"""
Ensure block_size is compatible with the attention backend.
For hybrid models, also aligns block_size with mamba page sizes.
"""
from vllm.config.cache import CacheConfig
from vllm.config.vllm import set_current_vllm_config
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
# model_config may be None during testing.
if not model_config:
return
first_layer = next(iter(attn_layers.values()))
backend_cls = first_layer.get_attn_backend()
backend_cls = cls._find_non_ssm_backend(vllm_config)
if backend_cls is None:
return
# Phase 1: Pick block size from backend (skip if user set --block-size)
if not cache_config.user_specified_block_size:
with set_current_vllm_config(vllm_config):
preferred = backend_cls.get_preferred_block_size(
CacheConfig.DEFAULT_BLOCK_SIZE
)
if preferred != CacheConfig.DEFAULT_BLOCK_SIZE:
logger.info(
"Setting kv cache block size to %d for %s backend.",
preferred,
backend_cls.get_name(),
)
cache_config.block_size = preferred
# Phase 2: Align block/mamba sizes for hybrid models
# (may override user settings).
if model_config.is_hybrid:
cls._align_hybrid_block_size(vllm_config, backend_cls)
@classmethod
def _align_hybrid_block_size(
cls,
vllm_config: "VllmConfig",
backend_cls: "type[AttentionBackend]",
) -> None:
"""
For hybrid attention/mamba models, ensure that the attention page
size is >= the mamba page size, and pad the mamba page size to match.
"""
from math import lcm
from vllm.config.vllm import set_current_vllm_config
from vllm.model_executor.models import ModelRegistry
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backend import MultipleOf
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
MambaSpec,
MLAAttentionSpec,
)
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
if cache_config.cache_dtype == "auto":
kv_cache_dtype = model_config.dtype
else:
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# Compute attention page size for 1 token
if model_config.use_mla:
attn_page_size_1_token = MLAAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
else:
attn_page_size_1_token = FullAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
# Compute mamba page size
model_cls, _ = ModelRegistry.resolve_model_cls(
model_config.architecture,
model_config=model_config,
)
mamba_page_size = MambaSpec(
shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
block_size=-1,
).page_size_bytes
if mamba_page_size == 0:
return
# mamba_block_size here should either be user specified value or None
mamba_block_size = (
cache_config.mamba_block_size
if cache_config.user_specified_mamba_block_size
else None
)
# Get kernel block alignment from the backend's supported sizes
with set_current_vllm_config(vllm_config):
preferred = backend_cls.get_preferred_block_size(
CacheConfig.DEFAULT_BLOCK_SIZE
kernel_block_alignment_size = max(
min(
s.base if isinstance(s, MultipleOf) else s
for s in backend_cls.get_supported_kernel_block_sizes()
),
cache_config.block_size,
)
if preferred != CacheConfig.DEFAULT_BLOCK_SIZE:
if cache_config.mamba_cache_mode == "all":
# With prefix caching, align to mamba chunk size for kernel perf
# TODO(tdoublep): this constraint can be relaxed fairly
# easily by changing the way we layout chunks in the
# mamba2 kernels.
base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
assert base_chunk_size is not None
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
cache_config.mamba_block_size = attn_block_size
else:
# Without prefix caching, use minimum block size that satisfies
# both backend alignment and mamba page size compatibility
attn_block_size = kernel_block_alignment_size * cdiv(
mamba_page_size,
kernel_block_alignment_size * attn_page_size_1_token,
)
if cache_config.block_size < attn_block_size:
cache_config.block_size = attn_block_size
logger.info(
"Setting kv cache block size to %d for %s backend.",
preferred,
backend_cls.get_name(),
"Setting attention block size to %d tokens "
"to ensure that attention page size is >= mamba page size.",
attn_block_size,
)
if cache_config.mamba_cache_mode == "align":
cache_config.mamba_block_size = cache_config.block_size
# Pad mamba page size to exactly match attention page size
attn_page_size = cache_config.block_size * attn_page_size_1_token
assert attn_page_size >= mamba_page_size
if attn_page_size == mamba_page_size:
return
if (
cache_config.mamba_page_size_padded is None
or cache_config.mamba_page_size_padded != attn_page_size
):
cache_config.mamba_page_size_padded = attn_page_size
mamba_padding_pct = (
100 * (attn_page_size - mamba_page_size) / mamba_page_size
)
logger.info(
"Padding mamba page size by %.2f%% to ensure "
"that mamba page size and attention page size are "
"exactly equal.",
mamba_padding_pct,
)
cache_config.block_size = preferred
@classmethod
def verify_model_arch(cls, model_arch: str) -> None:

View File

@@ -160,11 +160,7 @@ class XPUPlatform(Platform):
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
cache_config = vllm_config.cache_config
parallel_config = vllm_config.parallel_config
# in V1(or with chunked prefill) block_size is 64
if cache_config and not cache_config.user_specified_block_size:
cache_config.block_size = 64
# lazy import to avoid circular import
from vllm.config import CUDAGraphMode
@@ -221,12 +217,6 @@ class XPUPlatform(Platform):
# ref. https://openucx.readthedocs.io/en/master/faq.html
os.environ["UCX_MEMTYPE_CACHE"] = "n"
@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
# TODO: XPU still sets block_size in check_and_update_config.
# Move that logic here so block_size is chosen by the backend.
pass
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
return True

View File

@@ -311,6 +311,10 @@ class AttentionBackend(ABC):
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
return None
@classmethod
def is_ssm(cls) -> bool:
return False
class AttentionMetadata:
pass

View File

@@ -43,6 +43,7 @@ from vllm.config import (
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv, round_up
from vllm.v1.attention.backend import (
@@ -90,6 +91,12 @@ class FlashAttentionBackend(AttentionBackend):
forward_includes_kv_cache_update: bool = False
@classmethod
def get_preferred_block_size(cls, default_block_size: int) -> int:
if current_platform.is_xpu():
return max(default_block_size, 64)
return super().get_preferred_block_size(default_block_size)
@staticmethod
def get_name() -> str:
return "FLASH_ATTN"

View File

@@ -31,6 +31,10 @@ class GDNAttentionBackend(AttentionBackend):
def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
return GDNAttentionMetadataBuilder
@classmethod
def is_ssm(cls) -> bool:
return True
@dataclass
class GDNAttentionMetadata:

View File

@@ -27,6 +27,10 @@ class LinearAttentionBackend(AttentionBackend):
def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
return LinearAttentionMetadataBuilder
@classmethod
def is_ssm(cls) -> bool:
return True
@dataclass
class LinearAttentionMetadata:

View File

@@ -20,6 +20,10 @@ class Mamba1AttentionBackend(AttentionBackend):
def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]:
return Mamba1AttentionMetadataBuilder
@classmethod
def is_ssm(cls) -> bool:
return True
@dataclass
class Mamba1AttentionMetadata(BaseMambaAttentionMetadata):

View File

@@ -96,6 +96,10 @@ class Mamba2AttentionBackend(AttentionBackend):
def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
return Mamba2AttentionMetadataBuilder
@classmethod
def is_ssm(cls) -> bool:
return True
@dataclass
class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):

View File

@@ -18,6 +18,10 @@ class ShortConvAttentionBackend(AttentionBackend):
def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]:
return ShortConvAttentionMetadataBuilder
@classmethod
def is_ssm(cls) -> bool:
return True
@dataclass
class ShortConvAttentionMetadata(BaseMambaAttentionMetadata):