[BugFix] Fix and optimize max_num_blocks_per_req calculation for MambaSpec (#34440)
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
This commit is contained in:
@@ -5698,28 +5698,23 @@ class GPUModelRunner(
|
||||
kv_cache_config: The KV cache configuration.
|
||||
kernel_block_sizes: The kernel block sizes for each KV cache group.
|
||||
"""
|
||||
block_sizes = [
|
||||
kv_cache_group.kv_cache_spec.block_size
|
||||
for kv_cache_group in kv_cache_config.kv_cache_groups
|
||||
if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
|
||||
]
|
||||
block_sizes = []
|
||||
max_num_blocks = []
|
||||
max_model_len = max(self.max_model_len, self.max_encoder_len)
|
||||
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
|
||||
for kv_cache_group in kv_cache_config.kv_cache_groups:
|
||||
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
|
||||
continue
|
||||
block_size = kv_cache_group.kv_cache_spec.block_size
|
||||
block_sizes.append(block_size)
|
||||
max_num_blocks_per_req = cdiv(
|
||||
max_model_len, block_sizes[i] * get_total_cp_world_size()
|
||||
max_model_len, block_size * get_total_cp_world_size()
|
||||
)
|
||||
if isinstance(kv_cache_group.kv_cache_spec, MambaSpec):
|
||||
mamba_blocks_per_req = (
|
||||
max_num_blocks_per_req = (
|
||||
max_num_blocks_per_req
|
||||
if self.cache_config.enable_prefix_caching
|
||||
else 1
|
||||
) + kv_cache_group.kv_cache_spec.num_speculative_blocks
|
||||
max_num_blocks_per_req = max(
|
||||
max_num_blocks_per_req, mamba_blocks_per_req
|
||||
)
|
||||
max_num_blocks.append(max_num_blocks_per_req)
|
||||
|
||||
if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
|
||||
|
||||
Reference in New Issue
Block a user