[BugFix] Avoid prefix cache hit in the same schedule step for mamba layers (#29387)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2026-02-09 23:41:16 -08:00
committed by GitHub
parent dab1de9f38
commit 97fa8f6590
6 changed files with 78 additions and 3 deletions

View File

@@ -7,6 +7,7 @@ import pytest
from tests.models.registry import HF_EXAMPLE_MODELS
from tests.utils import multi_gpu_test
from vllm import LLM
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
@@ -769,3 +770,30 @@ def test_apc_multiple_prompts_partial_cached_outputs(
name_0="vllm_no_cache",
name_1=f"vllm_cache_it_{r_idx + 1}",
)
# we have to use a real large model to get reasonable results
# the model can't be a hybrid model as we need block_size 16
@pytest.mark.parametrize("model", ["tiiuae/falcon-mamba-7b"])
def test_apc_common_prefix_same_batch(
model: str,
monkeypatch,
) -> None:
# Required to put the two requests in the same batch
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
llm = LLM(
model=model,
enforce_eager=True,
block_size=16,
mamba_block_size=16,
enable_prefix_caching=True,
seed=42,
)
prompts = [
"hello what is one plus one what is one plus one what is one plus one the answer is", # noqa: E501
"hello what is one plus one what is one plus one what is one plus one the answer is", # noqa: E501
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=20)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
assert "two" in output.outputs[0].text

View File

@@ -857,6 +857,8 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]):
# Should have blocks for all groups
assert len(blocks.get_block_ids()) == num_groups
manager.new_step_starts()
# Second request: should hit cached blocks for common prefix
req1 = make_request("1", common_token_ids + [4] * 5, block_size, hash_fn)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)

View File

@@ -247,6 +247,11 @@ class KVCacheCoordinator(ABC):
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
pass
def new_step_starts(self) -> None:
"""Called when a new step is started."""
for manager in self.single_type_managers:
manager.new_step_starts()
class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
"""

View File

@@ -488,3 +488,7 @@ class KVCacheManager:
) -> KVCacheBlocks:
# Only create new KVCacheBlocks for non-empty blocks
return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks
def new_step_starts(self) -> None:
"""Called when a new step is started."""
self.coordinator.new_step_starts()

View File

@@ -347,6 +347,8 @@ class Scheduler(SchedulerInterface):
# For logging.
scheduled_timestamp = time.monotonic()
self.kv_cache_manager.new_step_starts()
# First, schedule the RUNNING requests.
req_index = 0
while req_index < len(self.running) and token_budget > 0:

View File

@@ -7,7 +7,11 @@ from collections.abc import Sequence
from vllm.utils.math_utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock
from vllm.v1.core.kv_cache_utils import (
BlockHashList,
BlockHashWithGroupId,
KVCacheBlock,
)
from vllm.v1.kv_cache_interface import (
ChunkedLocalAttentionSpec,
CrossAttentionSpec,
@@ -396,6 +400,10 @@ class SingleTypeKVCacheManager(ABC):
# The default behavior is to not skip any tokens.
return 0
def new_step_starts(self) -> None:
# do nothing by default
return None
class FullAttentionManager(SingleTypeKVCacheManager):
@classmethod
@@ -742,8 +750,11 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
class MambaManager(SingleTypeKVCacheManager):
def __init__(self, kv_cache_spec: MambaSpec, **kwargs) -> None:
super().__init__(kv_cache_spec, **kwargs)
def __init__(
self, kv_cache_spec: MambaSpec, block_pool: BlockPool, **kwargs
) -> None:
super().__init__(kv_cache_spec, block_pool, **kwargs)
self.cached_blocks_this_step: set[BlockHashWithGroupId] = set()
self.mamba_cache_mode = kv_cache_spec.mamba_cache_mode
self.num_speculative_blocks: int = kv_cache_spec.num_speculative_blocks
if self.mamba_cache_mode == "align":
@@ -838,6 +849,15 @@ class MambaManager(SingleTypeKVCacheManager):
num_tokens_main_model: int,
) -> int:
assert isinstance(self.kv_cache_spec, MambaSpec)
if (
len(new_computed_blocks) > 0
and new_computed_blocks[-1].block_hash in self.cached_blocks_this_step
):
# Mamba can't rely on blocks generated by other requests in the current step
# To put it in the next step, we return num_gpu_blocks + 1 so
# that kv_cache_manager will think there is no enough blocks to allocte now
# and don't schedule it in the current step.
return self.block_pool.num_gpu_blocks + 1
if self.mamba_cache_mode != "align":
# Allocate extra `num_speculative_blocks` blocks for
# speculative decoding (MTP/EAGLE) with linear attention.
@@ -972,6 +992,20 @@ class MambaManager(SingleTypeKVCacheManager):
"""
return num_computed_tokens - 1
def cache_blocks(self, request: Request, num_tokens: int) -> None:
num_cached_blocks_before = self.num_cached_block.get(request.request_id, 0)
super().cache_blocks(request, num_tokens)
num_cached_blocks_after = self.num_cached_block.get(request.request_id, 0)
if num_cached_blocks_after > num_cached_blocks_before:
for block in self.req_to_blocks[request.request_id][
num_cached_blocks_before:num_cached_blocks_after
]:
assert block.block_hash is not None
self.cached_blocks_this_step.add(block.block_hash)
def new_step_starts(self) -> None:
self.cached_blocks_this_step.clear()
class CrossAttentionManager(SingleTypeKVCacheManager):
"""Manager for cross-attention KV cache in encoder-decoder models."""