[Hybrid]: Decouple Kernel Block Size from KV Page Size (#24486)

Signed-off-by: lizhiyuan <uniartisan2017@gmail.com>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
Author: Zhiyuan Li
Date: 2025-10-09 14:43:39 +08:00
Committed by: GitHub
Parent: d17f0fbf30
Commit: d24cf322e1
18 changed files with 573 additions and 55 deletions


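Note: the hunks below use vLLM's cdiv helper, which is ceiling division on integers. A minimal equivalent, useful for following the arithmetic by hand:

    def cdiv(a: int, b: int) -> int:
        # Ceiling division: smallest integer n such that n * b >= a (for positive b).
        return -(a // -b)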
@@ -365,6 +365,23 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
block_size=model_config.max_model_len,
).page_size_bytes
# The model may be marked as is_hybrid, but mamba layers
# can be skipped via the config; in that case there is
# nothing to size, so return directly.
if mamba_page_size == 0:
return
# Attention backend constraints:
# - FlashAttention (FA) requires the block size to be a multiple of 16
# - MLA (Multi-head Latent Attention) requires larger multiples:
#   * CUTLASS_MLA backend: block size must be a multiple of 128
#   * Other MLA backends: block size must be a multiple of 64
if model_config.use_mla:
use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
kernel_block_alignment_size = 128 if use_cutlass_mla else 64
else:
kernel_block_alignment_size = 16
if cache_config.enable_prefix_caching:
# With prefix caching, select attention block size to
# optimize for mamba kernel performance
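The alignment selection in the first hunk can be exercised on its own; a minimal sketch of the same branching, where the standalone function and the environment lookup are illustrative stand-ins for the config/envs plumbing in vLLM:

    import os

    def pick_kernel_block_alignment(use_mla: bool) -> int:
        # Mirrors the hunk above: CUTLASS_MLA wants block sizes in multiples
        # of 128, other MLA backends in multiples of 64, and non-MLA backends
        # (e.g. FlashAttention) in multiples of 16.
        if use_mla:
            use_cutlass_mla = os.environ.get("VLLM_ATTENTION_BACKEND") == "CUTLASS_MLA"
            return 128 if use_cutlass_mla else 64
        return 16

    # Example: a non-MLA model gets the FlashAttention-style multiple of 16.
    assert pick_kernel_block_alignment(use_mla=False) == 16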
@@ -381,19 +398,28 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
# TODO(tdoublep): this constraint can be relaxed fairly
# easily by changing the way we lay out chunks in the
# mamba2 kernels.
chunk_size = model_config.get_mamba_chunk_size()
from math import gcd
def lcm(a, b):
return a * b // gcd(a, b)
base_chunk_size = model_config.get_mamba_chunk_size()
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
cache_config.mamba_block_size = attn_block_size
else:
# Without prefix caching, select minimum valid attention block size
# to minimize mamba state padding
# some attention backends (e.g. FA) only support setting
# block size to multiple of 16, so let's suggest a value
# that would work (note: FA is currently not compatible
# with mamba layers, use FlashInfer instead).
attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token)
# Calculate minimum attention block size that satisfies both:
# 1. Backend alignment requirements (kernel_block_alignment_size)
# 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
attn_block_size = kernel_block_alignment_size * cdiv(
mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
)
# override attention block size if either (a) the
# user has not set it or (b) the user has set it
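For reference, both block-size formulas in the second hunk can be replayed with plain numbers; the page sizes and chunk size below are illustrative assumptions, not values taken from any real model:

    from math import gcd

    def cdiv(a: int, b: int) -> int:
        return -(a // -b)

    def lcm(a: int, b: int) -> int:
        return a * b // gcd(a, b)

    # Illustrative inputs (assumed).
    mamba_page_size = 200_000         # bytes per mamba state page
    attn_page_size_1_token = 2_048    # bytes of attention KV cache per token
    kernel_block_alignment_size = 16  # FlashAttention-style constraint
    base_chunk_size = 256             # mamba2 chunk size

    # Without prefix caching: smallest aligned block whose KV pages cover the mamba page.
    no_pc = kernel_block_alignment_size * cdiv(
        mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
    )
    print(no_pc)  # 112 = 16 * ceil(200000 / (16 * 2048))

    # With prefix caching: round the per-state token count up to a multiple of
    # lcm(base_chunk_size, kernel_block_alignment_size).
    attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)  # 98
    chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)               # 256
    with_pc = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)         # 256
    print(with_pc)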