[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
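This change introduces a `mamba_cache_mode` cache-config option with three values: "none", "all", and "align". When prefix caching is enabled and no mode is set, it defaults to "all" for models that support Mamba prefix caching and to "align" otherwise; a requested "all" on an unsupported model falls back to "align". The "align" mode requires chunked prefill, is currently incompatible with speculative decoding, and pins `mamba_block_size` to the KV-cache block size so Mamba states are checkpointed on prefix-caching block boundaries. When prefix caching is disabled, the mode is forced back to "none".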
@@ -330,26 +330,54 @@ class MambaModelConfig(VerifyAndUpdateConfig):
         cache_config = vllm_config.cache_config
 
         if cache_config.enable_prefix_caching:
-            if model_config.supports_mamba_prefix_caching:
-                logger.info(
-                    "Warning: Prefix caching is currently enabled. "
-                    "Its support for Mamba layers is experimental. "
-                    "Please report any issues you may observe."
-                )
-                # By default, mamba block size will be set to max_model_len (see
-                # below). When enabling prefix caching, we align mamba block size
-                # to the block size as the basic granularity for prefix caching.
-                if cache_config.mamba_block_size is None:
-                    cache_config.mamba_block_size = cache_config.block_size
-            else:
-                logger.info(
-                    "Hybrid or mamba-based model detected without "
-                    "support for prefix caching: disabling."
-                )
-                cache_config.enable_prefix_caching = False
+            if cache_config.mamba_cache_mode == "none":
+                cache_config.mamba_cache_mode = (
+                    "all" if model_config.supports_mamba_prefix_caching else "align"
+                )
+                logger.warning(
+                    "Mamba cache mode is set to '%s' for %s by default "
+                    "when prefix caching is enabled",
+                    cache_config.mamba_cache_mode,
+                    model_config.architecture,
+                )
+            if (
+                cache_config.mamba_cache_mode == "all"
+                and not model_config.supports_mamba_prefix_caching
+            ):
+                cache_config.mamba_cache_mode = "align"
+                logger.warning(
+                    "Hybrid or mamba-based model detected without support "
+                    "for prefix caching with Mamba cache 'all' mode: "
+                    "falling back to 'align' mode."
+                )
+            if cache_config.mamba_cache_mode == "align":
+                assert vllm_config.scheduler_config.enable_chunked_prefill, (
+                    "Chunked prefill is required for mamba cache mode 'align'."
+                )
+                assert not vllm_config.speculative_config, (
+                    "Mamba cache mode 'align' is currently not compatible "
+                    "with speculative decoding."
+                )
+            logger.info(
+                "Warning: Prefix caching in Mamba cache '%s' "
+                "mode is currently enabled. "
+                "Its support for Mamba layers is experimental. "
+                "Please report any issues you may observe.",
+                cache_config.mamba_cache_mode,
+            )
+            # By default, mamba block size will be set to max_model_len (see
+            # below). When enabling prefix caching, we align mamba block size
+            # to the block size as the basic granularity for prefix caching.
+            if cache_config.mamba_block_size is None:
+                cache_config.mamba_block_size = cache_config.block_size
+        else:
+            if cache_config.mamba_cache_mode != "none":
+                cache_config.mamba_cache_mode = "none"
+                logger.warning(
+                    "Mamba cache mode is set to 'none' when prefix caching is disabled"
+                )
 
         if cache_config.mamba_block_size is None:
             cache_config.mamba_block_size = model_config.max_model_len
 
 
 class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
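For reference, the mode-selection rules above distill to a small decision table. The following standalone sketch reproduces that logic; the function name and signature are illustrative, not part of vLLM's API:

def resolve_mamba_cache_mode(
    requested_mode: str,
    enable_prefix_caching: bool,
    supports_mamba_prefix_caching: bool,
) -> str:
    """Return the effective mamba cache mode, mirroring the diff above."""
    if not enable_prefix_caching:
        # Prefix caching disabled: the mamba cache mode is forced to "none".
        return "none"
    if requested_mode == "none":
        # Default selection: "all" for models with full mamba prefix-caching
        # support, otherwise the more conservative "align" mode.
        return "all" if supports_mamba_prefix_caching else "align"
    if requested_mode == "all" and not supports_mamba_prefix_caching:
        # "all" was requested but is unsupported: fall back to "align".
        return "align"
    return requested_mode

assert resolve_mamba_cache_mode("none", True, True) == "all"
assert resolve_mamba_cache_mode("none", True, False) == "align"
assert resolve_mamba_cache_mode("all", True, False) == "align"
assert resolve_mamba_cache_mode("align", False, True) == "none"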
@@ -426,7 +454,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
         mamba_page_size = MambaSpec(
             shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
             dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
-            block_size=model_config.max_model_len,
+            block_size=-1,  # block_size doesn't matter for mamba page size
         ).page_size_bytes
 
         # Model may be marked as is_hybrid
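The page size here is the per-block byte footprint of the Mamba state, which is why block_size is irrelevant to it. A minimal sketch of how such a figure can be computed from state shapes and dtypes; the helper, shapes, and sizes below are illustrative assumptions, not MambaSpec's actual implementation:

import math

import torch

def mamba_page_size_bytes(shapes, dtypes) -> int:
    # One page holds every mamba state tensor for a layer; its size is the
    # sum over states of element count times bytes per element.
    return sum(
        math.prod(shape) * dtype.itemsize
        for shape, dtype in zip(shapes, dtypes, strict=True)
    )

# Hypothetical conv state and SSM state shapes.
shapes = ((8192, 4), (128, 64, 64))
dtypes = (torch.bfloat16, torch.float32)
print(mamba_page_size_bytes(shapes, dtypes))  # 65536 + 2097152 bytes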
@@ -435,7 +463,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
         if mamba_page_size == 0:
             return
 
-        if cache_config.enable_prefix_caching:
+        if cache_config.mamba_cache_mode == "all":
             # With prefix caching, select attention block size to
             # optimize for mamba kernel performance
 
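Note that the attention block-size optimization above now runs only in "all" mode: in "align" mode the mamba block size is pinned to the KV-cache block size (see the next hunk), so tuning the attention block size for mamba kernel performance does not apply there.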
@@ -479,6 +507,13 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
                 attn_block_size,
             )
 
+        # By default, mamba block size will be set to max_model_len.
+        # When enabling prefix caching and using align mamba cache
+        # mode, we align mamba block size to the block size as the
+        # basic granularity for prefix caching.
+        if cache_config.mamba_cache_mode == "align":
+            cache_config.mamba_block_size = cache_config.block_size
+
         # compute new attention page size
         attn_page_size = cache_config.block_size * attn_page_size_1_token
 
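A toy illustration of the granularity argument behind "align" mode (the helper and numbers are hypothetical): prefix-cache hits are granted in whole KV blocks, so Mamba states must be checkpointed on the same boundaries for a hit to be usable.

def cacheable_prefix_tokens(num_prefix_tokens: int, block_size: int) -> int:
    # Only complete blocks can be reused from the prefix cache.
    return (num_prefix_tokens // block_size) * block_size

block_size = 16                # KV-cache block size
mamba_block_size = block_size  # "align" mode: same granularity

# A 100-token shared prefix yields six reusable blocks (96 tokens);
# the trailing 4 tokens are recomputed.
assert cacheable_prefix_tokens(100, block_size) == 96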