[V1] [Hybrid] Refactor mamba state shape calculation; enable V1 via cli (#20840)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
@@ -3,9 +3,14 @@
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -200,6 +205,91 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
|
||||
}
|
||||
|
||||
|
||||
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Reconcile the attention and mamba page sizes for hybrid models.

        The attention page size must be at least as large as the mamba
        page size. When the configured attention block size is unset or
        too small, it is raised to the smallest multiple of 16 tokens
        that satisfies this. If the resulting attention page size is
        strictly larger than the mamba page size, the padded mamba page
        size on the cache config is set to the attention page size so
        that both are equal.

        This is a no-op unless ``envs.VLLM_USE_V1`` is truthy.

        Args:
            vllm_config: vLLM Config; its ``cache_config`` may be
                mutated in place (``block_size`` and
                ``mamba_page_size_padded``).
        """
        if not envs.VLLM_USE_V1:
            return

        cache_cfg = vllm_config.cache_config
        model_cfg = vllm_config.model_config
        parallel_cfg = vllm_config.parallel_config

        # "auto" means: store the cache in the model's own dtype.
        cache_dtype = (model_cfg.dtype if cache_cfg.cache_dtype == "auto" else
                       STR_DTYPE_TO_TORCH_DTYPE[cache_cfg.cache_dtype])

        # Bytes taken by one token in an attention page (block_size=1).
        single_token_spec = FullAttentionSpec(
            block_size=1,
            num_kv_heads=model_cfg.get_num_kv_heads(parallel_cfg),
            head_size=model_cfg.get_head_size(),
            dtype=cache_dtype,
            use_mla=model_cfg.use_mla,
        )
        bytes_per_token = single_token_spec.page_size_bytes

        resolved = ModelRegistry.resolve_model_cls(
            model_cfg._model_info.architecture)
        model_cls = resolved[0]

        # Bytes taken by one mamba page.
        mamba_spec = MambaSpec(
            shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
            dtype=cache_dtype,
            block_size=model_cfg.max_model_len,
        )
        mamba_bytes = mamba_spec.page_size_bytes

        # Some attention backends (e.g. FA) only accept block sizes that
        # are multiples of 16, so round the required block size up to the
        # next such multiple (note: FA is currently not compatible with
        # mamba layers, use FlashInfer instead).
        min_block_size = 16 * cdiv(mamba_bytes, 16 * bytes_per_token)

        # Override when the user left the block size unset or chose one
        # that is too small to cover a mamba page.
        requested = cache_cfg.block_size
        if requested is None or requested < min_block_size:
            cache_cfg.block_size = min_block_size
            logger.info(
                "Setting attention block size to %d tokens "
                "to ensure that attention page size is >= mamba page size.",
                min_block_size)

        # Attention page size under the (possibly updated) block size.
        attn_bytes = cache_cfg.block_size * bytes_per_token

        assert attn_bytes >= mamba_bytes

        # Pages already match: no mamba padding required.
        if attn_bytes == mamba_bytes:
            return

        # Padding already configured to the right value: nothing to do.
        current_padded = cache_cfg.mamba_page_size_padded
        if current_padded is not None and current_padded == attn_bytes:
            return

        # Pad the mamba page so it exactly matches the attention page.
        cache_cfg.mamba_page_size_padded = attn_bytes
        padding_pct = 100 * (attn_bytes - mamba_bytes) / mamba_bytes
        logger.info(
            "Padding mamba page size by %.2f%% to ensure "
            "that mamba page size and attention page size are "
            "exactly equal.", padding_pct)
|
||||
|
||||
|
||||
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"GteModel": SnowflakeGteNewModelConfig,
|
||||
"GteNewModel": GteNewModelConfig,
|
||||
|
||||
Reference in New Issue
Block a user