diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0e2e381f2..c9fc056be 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5698,28 +5698,23 @@ class GPUModelRunner( kv_cache_config: The KV cache configuration. kernel_block_sizes: The kernel block sizes for each KV cache group. """ - block_sizes = [ - kv_cache_group.kv_cache_spec.block_size - for kv_cache_group in kv_cache_config.kv_cache_groups - if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec) - ] + block_sizes = [] max_num_blocks = [] max_model_len = max(self.max_model_len, self.max_encoder_len) - for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): + for kv_cache_group in kv_cache_config.kv_cache_groups: if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec): continue + block_size = kv_cache_group.kv_cache_spec.block_size + block_sizes.append(block_size) max_num_blocks_per_req = cdiv( - max_model_len, block_sizes[i] * get_total_cp_world_size() + max_model_len, block_size * get_total_cp_world_size() ) if isinstance(kv_cache_group.kv_cache_spec, MambaSpec): - mamba_blocks_per_req = ( + max_num_blocks_per_req = ( max_num_blocks_per_req if self.cache_config.enable_prefix_caching else 1 ) + kv_cache_group.kv_cache_spec.num_speculative_blocks - max_num_blocks_per_req = max( - max_num_blocks_per_req, mamba_blocks_per_req - ) max_num_blocks.append(max_num_blocks_per_req) if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [