[HMA]Fix corner case when hybrid page_size can not be evenly divided issue (blk_size=64,tp=4) (#37467)
Signed-off-by: Chendi Xue <chendi.xue@intel.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
@@ -1,15 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from copy import deepcopy
|
||||
from math import lcm
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
@@ -104,11 +99,11 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
@classmethod
|
||||
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
"""
|
||||
Ensure that page size of attention layers is greater than or
|
||||
equal to the mamba layers. If not, automatically set the attention
|
||||
block size to ensure that it is. If the attention page size is
|
||||
strictly greater than the mamba page size, we pad the mamba page size
|
||||
to make them equal.
|
||||
Perform early validation and setup for hybrid attention/mamba models.
|
||||
|
||||
Block size alignment with mamba page sizes is handled later by
|
||||
Platform.update_block_size_for_backend(), which runs after model
|
||||
layers are constructed and the attention backend is known.
|
||||
|
||||
Args:
|
||||
vllm_config: vLLM Config
|
||||
@@ -118,6 +113,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
# Disable calculate_kv_scales for hybrid models: uninitialized
|
||||
# recurrent state corrupts scales during the calibration pass.
|
||||
# See issue: https://github.com/vllm-project/vllm/issues/37554
|
||||
|
||||
if cache_config.calculate_kv_scales:
|
||||
logger.warning(
|
||||
"Disabling calculate_kv_scales for hybrid model '%s'. "
|
||||
@@ -129,140 +125,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
)
|
||||
cache_config.calculate_kv_scales = False
|
||||
|
||||
# Save the user input before it gets modified by MambaModelConfig
|
||||
mamba_block_size = cache_config.mamba_block_size
|
||||
# Enable FULL_AND_PIECEWISE by default
|
||||
MambaModelConfig.verify_and_update_config(vllm_config)
|
||||
|
||||
attention_config = vllm_config.attention_config
|
||||
cache_config = vllm_config.cache_config
|
||||
model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
if cache_config.cache_dtype == "auto":
|
||||
kv_cache_dtype = model_config.dtype
|
||||
else:
|
||||
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
||||
|
||||
# get attention page size (for 1 token)
|
||||
# Attention backend constraints:
|
||||
# - FlashAttention (FA) requires block size to be multiple of 16
|
||||
# - MLA (Multi-head Latent Attention) requires larger alignment:
|
||||
# * CUTLASS_MLA backend: kernel_block_size 128 alignment
|
||||
# * Other MLA backends: kernel_block_size 64 alignment
|
||||
if model_config.use_mla:
|
||||
use_cutlass_mla = (
|
||||
attention_config.backend == AttentionBackendEnum.CUTLASS_MLA
|
||||
)
|
||||
kernel_block_alignment_size = 128 if use_cutlass_mla else 64
|
||||
attn_page_size_1_token = MLAAttentionSpec(
|
||||
block_size=1,
|
||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||
head_size=model_config.get_head_size(),
|
||||
dtype=kv_cache_dtype,
|
||||
).page_size_bytes
|
||||
else:
|
||||
kernel_block_alignment_size = 16
|
||||
attn_page_size_1_token = FullAttentionSpec(
|
||||
block_size=1,
|
||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||
head_size=model_config.get_head_size(),
|
||||
dtype=kv_cache_dtype,
|
||||
).page_size_bytes
|
||||
|
||||
model_cls, _ = ModelRegistry.resolve_model_cls(
|
||||
model_config.architecture,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
# get mamba page size
|
||||
mamba_page_size = MambaSpec(
|
||||
shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
|
||||
dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
|
||||
block_size=-1, # block_size doesn't matter for mamba page size
|
||||
).page_size_bytes
|
||||
|
||||
# Model may be marked as is_hybrid
|
||||
# but mamba is skipped via config,
|
||||
# return directly
|
||||
if mamba_page_size == 0:
|
||||
return
|
||||
|
||||
if cache_config.mamba_cache_mode == "all":
|
||||
# With prefix caching, select attention block size to
|
||||
# optimize for mamba kernel performance
|
||||
|
||||
# Mamba2 SSD kernel uses a chunk_size, e.g. 256
|
||||
# Align the block to the kernel: use lowest multiple of chunk_size
|
||||
# of attention tokens that would fit mamba_page_size:
|
||||
# e.g. for mamba page size = 788kB
|
||||
# attn_1_token = 2kB -> fits ~394 tokens
|
||||
# then round up to a multiple of 256 -> 512 tokens
|
||||
# End result:
|
||||
# attn_block_size = 512
|
||||
# mamba_block_size = 512 (aligned to a multiple of chunk_size)
|
||||
# TODO(tdoublep): this constraint can be relaxed fairly
|
||||
# easily by changing the way we layout chunks in the
|
||||
# mamba2 kernels.
|
||||
|
||||
base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
|
||||
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
|
||||
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
|
||||
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
|
||||
cache_config.mamba_block_size = attn_block_size
|
||||
else:
|
||||
# Without prefix caching, select minimum valid attention block size
|
||||
# to minimize mamba state padding
|
||||
|
||||
# Calculate minimum attention block size that satisfies both:
|
||||
# 1. Backend alignment requirements (kernel_block_alignment_size)
|
||||
# 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
|
||||
attn_block_size = kernel_block_alignment_size * cdiv(
|
||||
mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
|
||||
)
|
||||
|
||||
# override attention block size if it is too small,
|
||||
# even if the user has explicitly set it
|
||||
if cache_config.block_size < attn_block_size:
|
||||
cache_config.block_size = attn_block_size
|
||||
logger.info(
|
||||
"Setting attention block size to %d tokens "
|
||||
"to ensure that attention page size is >= mamba page size.",
|
||||
attn_block_size,
|
||||
)
|
||||
|
||||
# By default, mamba block size will be set to max_model_len.
|
||||
# When enabling prefix caching and using align mamba cache
|
||||
# mode, we align mamba block size to the block size as the
|
||||
# basic granularity for prefix caching.
|
||||
if cache_config.mamba_cache_mode == "align":
|
||||
cache_config.mamba_block_size = cache_config.block_size
|
||||
|
||||
# compute new attention page size
|
||||
attn_page_size = cache_config.block_size * attn_page_size_1_token
|
||||
|
||||
assert attn_page_size >= mamba_page_size
|
||||
|
||||
if attn_page_size == mamba_page_size:
|
||||
# don't need to pad mamba page size
|
||||
return
|
||||
|
||||
# pad mamba page size to exactly match attention
|
||||
if (
|
||||
cache_config.mamba_page_size_padded is None
|
||||
or cache_config.mamba_page_size_padded != attn_page_size
|
||||
):
|
||||
cache_config.mamba_page_size_padded = attn_page_size
|
||||
mamba_padding_pct = (
|
||||
100 * (attn_page_size - mamba_page_size) / mamba_page_size
|
||||
)
|
||||
logger.info(
|
||||
"Padding mamba page size by %.2f%% to ensure "
|
||||
"that mamba page size and attention page size are "
|
||||
"exactly equal.",
|
||||
mamba_padding_pct,
|
||||
)
|
||||
|
||||
|
||||
class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user