[Hybrid]: Decouple Kernel Block Size from KV Page Size (#24486)
Signed-off-by: lizhiyuan <uniartisan2017@gmail.com>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
@@ -365,6 +365,23 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
             block_size=model_config.max_model_len,
         ).page_size_bytes
+
+        # Model may be marked as is_hybrid
+        # but mamba is skipped via config,
+        # return directly
+        if mamba_page_size == 0:
+            return
+
+        # Attention backend constraints:
+        # - FlashAttention (FA) requires block size to be multiple of 16
+        # - MLA (Multi-head Latent Attention) requires larger alignment:
+        #   * CUTLASS_MLA backend: 128-byte alignment
+        #   * Other MLA backends: 64-byte alignment
+        if model_config.use_mla:
+            use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
+            kernel_block_alignment_size = 128 if use_cutlass_mla else 64
+        else:
+            kernel_block_alignment_size = 16

         if cache_config.enable_prefix_caching:
             # With prefix caching, select attention block size to
             # optimize for mamba kernel performance
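The alignment rule in the hunk above can be restated as a small standalone sketch. The helper name and the non-CUTLASS backend string are invented for illustration; only the 16/64/128 constants and the CUTLASS_MLA comparison are taken from the change itself.

from typing import Optional

def pick_kernel_block_alignment_size(use_mla: bool, attention_backend: Optional[str]) -> int:
    # MLA backends need coarser block-size alignment, with CUTLASS_MLA the coarsest;
    # non-MLA backends such as FlashAttention only need a multiple of 16.
    if use_mla:
        return 128 if attention_backend == "CUTLASS_MLA" else 64
    return 16

assert pick_kernel_block_alignment_size(False, None) == 16
assert pick_kernel_block_alignment_size(True, "CUTLASS_MLA") == 128
assert pick_kernel_block_alignment_size(True, "OTHER_MLA_BACKEND") == 64  # hypothetical backend name

Whatever value this branch produces is then used below as the unit that the attention block size must be a multiple of.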
@@ -381,19 +398,28 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
             # TODO(tdoublep): this constraint can be relaxed fairly
             # easily by changing the way we layout chunks in the
             # mamba2 kernels.
-            chunk_size = model_config.get_mamba_chunk_size()
+
+            from math import gcd
+
+            def lcm(a, b):
+                return a * b // gcd(a, b)
+
+            base_chunk_size = model_config.get_mamba_chunk_size()
+            attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
+
+            chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
+            attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
+            cache_config.mamba_block_size = attn_block_size
         else:
             # Without prefix caching, select minimum valid attention block size
             # to minimize mamba state padding

-            # some attention backends (e.g. FA) only support setting
-            # block size to multiple of 16, so let's suggest a value
-            # that would work (note: FA is currently not compatible
-            # with mamba layers, use FlashInfer instead).
-            attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token)
+            # Calculate minimum attention block size that satisfies both:
+            # 1. Backend alignment requirements (kernel_block_alignment_size)
+            # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
+            attn_block_size = kernel_block_alignment_size * cdiv(
+                mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
+            )

         # override attention block size if either (a) the
         # user has not set it or (b) the user has set it
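To make the two attn_block_size formulas concrete, here is a self-contained sketch with invented numbers. cdiv and lcm are re-implemented locally so it runs on its own; the alignment, page sizes, and chunk size are hypothetical values chosen only to make the rounding visible.

from math import gcd

def cdiv(a: int, b: int) -> int:
    # ceiling division, as used in the changed code
    return -(-a // b)

def lcm(a: int, b: int) -> int:
    return a * b // gcd(a, b)

kernel_block_alignment_size = 64   # e.g. an MLA backend other than CUTLASS_MLA
attn_page_size_1_token = 576       # hypothetical KV-cache bytes per token
mamba_page_size = 280_000          # hypothetical bytes for one mamba state page
base_chunk_size = 256              # hypothetical mamba2 chunk size

# Without prefix caching: smallest aligned block whose attention page
# is at least as large as one mamba state page.
attn_block_size = kernel_block_alignment_size * cdiv(
    mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
)
assert attn_block_size == 512
assert attn_block_size % kernel_block_alignment_size == 0
assert attn_block_size * attn_page_size_1_token >= mamba_page_size

# With prefix caching: additionally round the per-state token count up to a
# multiple of the alignment-adjusted mamba2 chunk size so that block
# boundaries line up with chunk boundaries.
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)   # 487
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)                # 256
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)  # 512
assert attn_block_size % chunk_size == 0
assert attn_block_size * attn_page_size_1_token >= mamba_page_size

With these numbers both paths land on a 512-token block, whose attention page (512 * 576 bytes) just covers the mamba state page, matching the compatibility condition noted in the comment (attn_page_size >= mamba_page_size).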