[BugFix] Fix and optimize max_num_blocks_per_req calculation for MambaSpec (#34440)

Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
This commit is contained in:
Harry Huang
2026-02-13 16:13:14 +08:00
committed by GitHub
parent bcf0731aa0
commit 7a8a46ddcb

View File

@@ -5698,28 +5698,23 @@ class GPUModelRunner(
kv_cache_config: The KV cache configuration.
kernel_block_sizes: The kernel block sizes for each KV cache group.
"""
block_sizes = [
kv_cache_group.kv_cache_spec.block_size
for kv_cache_group in kv_cache_config.kv_cache_groups
if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
]
block_sizes = []
max_num_blocks = []
max_model_len = max(self.max_model_len, self.max_encoder_len)
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
for kv_cache_group in kv_cache_config.kv_cache_groups:
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
continue
block_size = kv_cache_group.kv_cache_spec.block_size
block_sizes.append(block_size)
max_num_blocks_per_req = cdiv(
max_model_len, block_sizes[i] * get_total_cp_world_size()
max_model_len, block_size * get_total_cp_world_size()
)
if isinstance(kv_cache_group.kv_cache_spec, MambaSpec):
mamba_blocks_per_req = (
max_num_blocks_per_req = (
max_num_blocks_per_req
if self.cache_config.enable_prefix_caching
else 1
) + kv_cache_group.kv_cache_spec.num_speculative_blocks
max_num_blocks_per_req = max(
max_num_blocks_per_req, mamba_blocks_per_req
)
max_num_blocks.append(max_num_blocks_per_req)
if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [