[V1] [Hybrid] Mamba2 Automatic Prefix Caching (#25752)
Signed-off-by: Stanislaw Wozniak <stw@zurich.ibm.com> Signed-off-by: Thomas Ortner <boh@zurich.ibm.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Thomas Ortner <boh@zurich.ibm.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
@@ -292,10 +292,33 @@ class MambaModelConfig(VerifyAndUpdateConfig):
|
||||
cache_config = vllm_config.cache_config
|
||||
compilation_config = vllm_config.compilation_config
|
||||
|
||||
# TODO(tdoublep): remove once prefix caching is enabled
|
||||
cache_config.enable_prefix_caching = False
|
||||
logger.info("Hybrid or mamba-based model detected: disabling prefix "
|
||||
"caching since it is not yet supported.")
|
||||
# Set mamba block size to max_model_len (this may get
|
||||
# overridden by prefix caching logic later)
|
||||
cache_config.mamba_block_size = model_config.max_model_len
|
||||
|
||||
# TODO(@tdoublep) find a better way to do this than whitelist
|
||||
MAMBA2_MODELS = [
|
||||
"BambaForCausalLM",
|
||||
"FalconH1ForCausalLM",
|
||||
"GraniteMoeHybridForCausalLM",
|
||||
"Mamba2ForCausalLM",
|
||||
"NemotronHForCausalLM",
|
||||
"Zamba2ForCausalLM",
|
||||
]
|
||||
if cache_config.enable_prefix_caching:
|
||||
if model_config.architecture in MAMBA2_MODELS:
|
||||
logger.info("Warning: Prefix caching is currently enabled. "
|
||||
"Its support for Mamba2 layers is experimental. "
|
||||
"Please report any issues you may observe.")
|
||||
else:
|
||||
logger.info("Hybrid or mamba-based model detected without "
|
||||
"support for prefix caching: disabling.")
|
||||
cache_config.enable_prefix_caching = False
|
||||
|
||||
# TODO(tdoublep): remove once cascade attention is supported
|
||||
logger.info("Disabling cascade attention since it is not supported "
|
||||
"for hybrid models.")
|
||||
model_config.disable_cascade_attn = True
|
||||
|
||||
# TODO(tdoublep): remove as full cuda graph support is added
|
||||
FCG_NOT_SUPPORTED_MODELS = [
|
||||
@@ -360,12 +383,38 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
block_size=model_config.max_model_len,
|
||||
).page_size_bytes
|
||||
|
||||
# some attention backends (e.g. FA) only support setting
|
||||
# block size to multiple of 16, so let's suggest a value
|
||||
# that would work (note: FA is currently not compatible
|
||||
# with mamba layers, use FlashInfer instead).
|
||||
attn_block_size = 16 * cdiv(mamba_page_size,
|
||||
16 * attn_page_size_1_token)
|
||||
if cache_config.enable_prefix_caching:
|
||||
# With prefix caching, select attention block size to
|
||||
# optimize for mamba kernel performance
|
||||
|
||||
# mamba SSD kernel uses a chunk_size, e.g. 256
|
||||
# Align the block to the kernel: use lowest multiple of chunk_size
|
||||
# of attention tokens that would fit mamba_page_size:
|
||||
# e.g. for mamba page size = 788kB
|
||||
# attn_1_token = 2kB -> fits ~394 tokens
|
||||
# then round up to a multiple of 256 -> 512 tokens
|
||||
# End result:
|
||||
# attn_block_size = 512
|
||||
# mamba_block_size = 512 (aligned to a multiple of chunk_size)
|
||||
# TODO(tdoublep): this constraint can be relaxed fairly
|
||||
# easily by changing the way we layout chunks in the
|
||||
# mamba2 kernels.
|
||||
chunk_size = model_config.get_mamba_chunk_size()
|
||||
attn_tokens_per_mamba_state = \
|
||||
cdiv(mamba_page_size, attn_page_size_1_token)
|
||||
attn_block_size = chunk_size * \
|
||||
cdiv(attn_tokens_per_mamba_state, chunk_size)
|
||||
cache_config.mamba_block_size = attn_block_size
|
||||
else:
|
||||
# Without prefix caching, select minimum valid attention block size
|
||||
# to minimize mamba state padding
|
||||
|
||||
# some attention backends (e.g. FA) only support setting
|
||||
# block size to multiple of 16, so let's suggest a value
|
||||
# that would work (note: FA is currently not compatible
|
||||
# with mamba layers, use FlashInfer instead).
|
||||
attn_block_size = 16 * cdiv(mamba_page_size,
|
||||
16 * attn_page_size_1_token)
|
||||
|
||||
# override attention block size if either (a) the
|
||||
# user has not set it or (b) the user has set it
|
||||
|
||||
Reference in New Issue
Block a user