diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 2724f612c..e853f65db 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -7,6 +7,7 @@ import pytest from tests.models.registry import HF_EXAMPLE_MODELS from tests.utils import multi_gpu_test +from vllm import LLM from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams @@ -769,3 +770,30 @@ def test_apc_multiple_prompts_partial_cached_outputs( name_0="vllm_no_cache", name_1=f"vllm_cache_it_{r_idx + 1}", ) + + +# we have to use a real large model to get reasonable results +# the model can't be a hybrid model as we need block_size 16 +@pytest.mark.parametrize("model", ["tiiuae/falcon-mamba-7b"]) +def test_apc_common_prefix_same_batch( + model: str, + monkeypatch, +) -> None: + # Required to put the two requests in the same batch + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + llm = LLM( + model=model, + enforce_eager=True, + block_size=16, + mamba_block_size=16, + enable_prefix_caching=True, + seed=42, + ) + prompts = [ + "hello what is one plus one what is one plus one what is one plus one the answer is", # noqa: E501 + "hello what is one plus one what is one plus one what is one plus one the answer is", # noqa: E501 + ] + sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=20) + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + assert "two" in output.outputs[0].text diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 287b8ad98..e2c924a61 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -857,6 +857,8 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]): # Should have blocks for all groups assert len(blocks.get_block_ids()) == num_groups + manager.new_step_starts() + # Second request: should hit cached blocks for common prefix req1 = make_request("1", common_token_ids + [4] * 5, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index d8f9d69c7..eaa95dfe4 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -247,6 +247,11 @@ class KVCacheCoordinator(ABC): ) -> tuple[tuple[list[KVCacheBlock], ...], int]: pass + def new_step_starts(self) -> None: + """Called when a new step is started.""" + for manager in self.single_type_managers: + manager.new_step_starts() + class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): """ diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 2caed0493..7f8d80475 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -488,3 +488,7 @@ class KVCacheManager: ) -> KVCacheBlocks: # Only create new KVCacheBlocks for non-empty blocks return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks + + def new_step_starts(self) -> None: + """Called when a new step is started.""" + self.coordinator.new_step_starts() diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index aa3bc6e2c..cfd6baabb 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -347,6 +347,8 @@ class Scheduler(SchedulerInterface): # For logging. scheduled_timestamp = time.monotonic() + self.kv_cache_manager.new_step_starts() + # First, schedule the RUNNING requests. req_index = 0 while req_index < len(self.running) and token_budget > 0: diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 96660dc6f..0b6b7ed42 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -7,7 +7,11 @@ from collections.abc import Sequence from vllm.utils.math_utils import cdiv from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock +from vllm.v1.core.kv_cache_utils import ( + BlockHashList, + BlockHashWithGroupId, + KVCacheBlock, +) from vllm.v1.kv_cache_interface import ( ChunkedLocalAttentionSpec, CrossAttentionSpec, @@ -396,6 +400,10 @@ class SingleTypeKVCacheManager(ABC): # The default behavior is to not skip any tokens. return 0 + def new_step_starts(self) -> None: + # do nothing by default + return None + class FullAttentionManager(SingleTypeKVCacheManager): @classmethod @@ -742,8 +750,11 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): class MambaManager(SingleTypeKVCacheManager): - def __init__(self, kv_cache_spec: MambaSpec, **kwargs) -> None: - super().__init__(kv_cache_spec, **kwargs) + def __init__( + self, kv_cache_spec: MambaSpec, block_pool: BlockPool, **kwargs + ) -> None: + super().__init__(kv_cache_spec, block_pool, **kwargs) + self.cached_blocks_this_step: set[BlockHashWithGroupId] = set() self.mamba_cache_mode = kv_cache_spec.mamba_cache_mode self.num_speculative_blocks: int = kv_cache_spec.num_speculative_blocks if self.mamba_cache_mode == "align": @@ -838,6 +849,15 @@ class MambaManager(SingleTypeKVCacheManager): num_tokens_main_model: int, ) -> int: assert isinstance(self.kv_cache_spec, MambaSpec) + if ( + len(new_computed_blocks) > 0 + and new_computed_blocks[-1].block_hash in self.cached_blocks_this_step + ): + # Mamba can't rely on blocks generated by other requests in the current step + # To put it in the next step, we return num_gpu_blocks + 1 so + # that kv_cache_manager will think there is no enough blocks to allocte now + # and don't schedule it in the current step. + return self.block_pool.num_gpu_blocks + 1 if self.mamba_cache_mode != "align": # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. @@ -972,6 +992,20 @@ class MambaManager(SingleTypeKVCacheManager): """ return num_computed_tokens - 1 + def cache_blocks(self, request: Request, num_tokens: int) -> None: + num_cached_blocks_before = self.num_cached_block.get(request.request_id, 0) + super().cache_blocks(request, num_tokens) + num_cached_blocks_after = self.num_cached_block.get(request.request_id, 0) + if num_cached_blocks_after > num_cached_blocks_before: + for block in self.req_to_blocks[request.request_id][ + num_cached_blocks_before:num_cached_blocks_after + ]: + assert block.block_hash is not None + self.cached_blocks_this_step.add(block.block_hash) + + def new_step_starts(self) -> None: + self.cached_blocks_this_step.clear() + class CrossAttentionManager(SingleTypeKVCacheManager): """Manager for cross-attention KV cache in encoder-decoder models."""