[V1] [Hybrid] Mamba2 Automatic Prefix Caching (#25752)

Signed-off-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Thomas Ortner <boh@zurich.ibm.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Stan Wozniak
2025-10-04 06:34:22 +02:00
committed by GitHub
parent 9705fba7b7
commit ea507c3a93
18 changed files with 917 additions and 147 deletions

View File

@@ -292,10 +292,33 @@ class MambaModelConfig(VerifyAndUpdateConfig):
cache_config = vllm_config.cache_config
compilation_config = vllm_config.compilation_config
# TODO(tdoublep): remove once prefix caching is enabled
cache_config.enable_prefix_caching = False
logger.info("Hybrid or mamba-based model detected: disabling prefix "
"caching since it is not yet supported.")
# Set mamba block size to max_model_len (this may get
# overridden by prefix caching logic later)
cache_config.mamba_block_size = model_config.max_model_len
# TODO(@tdoublep) find a better way to do this than whitelist
MAMBA2_MODELS = [
"BambaForCausalLM",
"FalconH1ForCausalLM",
"GraniteMoeHybridForCausalLM",
"Mamba2ForCausalLM",
"NemotronHForCausalLM",
"Zamba2ForCausalLM",
]
if cache_config.enable_prefix_caching:
if model_config.architecture in MAMBA2_MODELS:
logger.info("Warning: Prefix caching is currently enabled. "
"Its support for Mamba2 layers is experimental. "
"Please report any issues you may observe.")
else:
logger.info("Hybrid or mamba-based model detected without "
"support for prefix caching: disabling.")
cache_config.enable_prefix_caching = False
# TODO(tdoublep): remove once cascade attention is supported
logger.info("Disabling cascade attention since it is not supported "
"for hybrid models.")
model_config.disable_cascade_attn = True
# TODO(tdoublep): remove as full cuda graph support is added
FCG_NOT_SUPPORTED_MODELS = [
@@ -360,12 +383,38 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
block_size=model_config.max_model_len,
).page_size_bytes
# some attention backends (e.g. FA) only support setting
# block size to multiple of 16, so let's suggest a value
# that would work (note: FA is currently not compatible
# with mamba layers, use FlashInfer instead).
attn_block_size = 16 * cdiv(mamba_page_size,
16 * attn_page_size_1_token)
if cache_config.enable_prefix_caching:
# With prefix caching, select attention block size to
# optimize for mamba kernel performance
# mamba SSD kernel uses a chunk_size, e.g. 256
# Align the block to the kernel: take the number of attention tokens
# needed to cover mamba_page_size and round it up to a multiple of chunk_size:
# e.g. for mamba page size = 788kB
# attn_1_token = 2kB -> fits ~394 tokens
# then round up to a multiple of 256 -> 512 tokens
# End result:
# attn_block_size = 512
# mamba_block_size = 512 (aligned to a multiple of chunk_size)
# TODO(tdoublep): this constraint can be relaxed fairly
# easily by changing the way we layout chunks in the
# mamba2 kernels.
chunk_size = model_config.get_mamba_chunk_size()
attn_tokens_per_mamba_state = \
cdiv(mamba_page_size, attn_page_size_1_token)
attn_block_size = chunk_size * \
cdiv(attn_tokens_per_mamba_state, chunk_size)
cache_config.mamba_block_size = attn_block_size
else:
# Without prefix caching, select minimum valid attention block size
# to minimize mamba state padding
# some attention backends (e.g. FA) only support setting
# block size to multiple of 16, so let's suggest a value
# that would work (note: FA is currently not compatible
# with mamba layers, use FlashInfer instead).
attn_block_size = 16 * cdiv(mamba_page_size,
16 * attn_page_size_1_token)
# override attention block size if either (a) the
# user has not set it or (b) the user has set it