[BugFix] Avoid prefix cache hit in the same schedule step for mamba layers (#29387)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from tests.models.registry import HF_EXAMPLE_MODELS
|
||||
from tests.utils import multi_gpu_test
|
||||
from vllm import LLM
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@@ -769,3 +770,30 @@ def test_apc_multiple_prompts_partial_cached_outputs(
|
||||
name_0="vllm_no_cache",
|
||||
name_1=f"vllm_cache_it_{r_idx + 1}",
|
||||
)
|
||||
|
||||
|
||||
# we have to use a real large model to get reasonable results
|
||||
# the model can't be a hybrid model as we need block_size 16
|
||||
@pytest.mark.parametrize("model", ["tiiuae/falcon-mamba-7b"])
|
||||
def test_apc_common_prefix_same_batch(
|
||||
model: str,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
# Required to put the two requests in the same batch
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
llm = LLM(
|
||||
model=model,
|
||||
enforce_eager=True,
|
||||
block_size=16,
|
||||
mamba_block_size=16,
|
||||
enable_prefix_caching=True,
|
||||
seed=42,
|
||||
)
|
||||
prompts = [
|
||||
"hello what is one plus one what is one plus one what is one plus one the answer is", # noqa: E501
|
||||
"hello what is one plus one what is one plus one what is one plus one the answer is", # noqa: E501
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=20)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
for output in outputs:
|
||||
assert "two" in output.outputs[0].text
|
||||
|
||||
@@ -857,6 +857,8 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]):
|
||||
# Should have blocks for all groups
|
||||
assert len(blocks.get_block_ids()) == num_groups
|
||||
|
||||
manager.new_step_starts()
|
||||
|
||||
# Second request: should hit cached blocks for common prefix
|
||||
req1 = make_request("1", common_token_ids + [4] * 5, block_size, hash_fn)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
|
||||
@@ -247,6 +247,11 @@ class KVCacheCoordinator(ABC):
|
||||
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
|
||||
pass
|
||||
|
||||
def new_step_starts(self) -> None:
|
||||
"""Called when a new step is started."""
|
||||
for manager in self.single_type_managers:
|
||||
manager.new_step_starts()
|
||||
|
||||
|
||||
class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
|
||||
"""
|
||||
|
||||
@@ -488,3 +488,7 @@ class KVCacheManager:
|
||||
) -> KVCacheBlocks:
|
||||
# Only create new KVCacheBlocks for non-empty blocks
|
||||
return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks
|
||||
|
||||
def new_step_starts(self) -> None:
|
||||
"""Called when a new step is started."""
|
||||
self.coordinator.new_step_starts()
|
||||
|
||||
@@ -347,6 +347,8 @@ class Scheduler(SchedulerInterface):
|
||||
# For logging.
|
||||
scheduled_timestamp = time.monotonic()
|
||||
|
||||
self.kv_cache_manager.new_step_starts()
|
||||
|
||||
# First, schedule the RUNNING requests.
|
||||
req_index = 0
|
||||
while req_index < len(self.running) and token_budget > 0:
|
||||
|
||||
@@ -7,7 +7,11 @@ from collections.abc import Sequence
|
||||
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock
|
||||
from vllm.v1.core.kv_cache_utils import (
|
||||
BlockHashList,
|
||||
BlockHashWithGroupId,
|
||||
KVCacheBlock,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
ChunkedLocalAttentionSpec,
|
||||
CrossAttentionSpec,
|
||||
@@ -396,6 +400,10 @@ class SingleTypeKVCacheManager(ABC):
|
||||
# The default behavior is to not skip any tokens.
|
||||
return 0
|
||||
|
||||
def new_step_starts(self) -> None:
|
||||
# do nothing by default
|
||||
return None
|
||||
|
||||
|
||||
class FullAttentionManager(SingleTypeKVCacheManager):
|
||||
@classmethod
|
||||
@@ -742,8 +750,11 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
||||
|
||||
|
||||
class MambaManager(SingleTypeKVCacheManager):
|
||||
def __init__(self, kv_cache_spec: MambaSpec, **kwargs) -> None:
|
||||
super().__init__(kv_cache_spec, **kwargs)
|
||||
def __init__(
|
||||
self, kv_cache_spec: MambaSpec, block_pool: BlockPool, **kwargs
|
||||
) -> None:
|
||||
super().__init__(kv_cache_spec, block_pool, **kwargs)
|
||||
self.cached_blocks_this_step: set[BlockHashWithGroupId] = set()
|
||||
self.mamba_cache_mode = kv_cache_spec.mamba_cache_mode
|
||||
self.num_speculative_blocks: int = kv_cache_spec.num_speculative_blocks
|
||||
if self.mamba_cache_mode == "align":
|
||||
@@ -838,6 +849,15 @@ class MambaManager(SingleTypeKVCacheManager):
|
||||
num_tokens_main_model: int,
|
||||
) -> int:
|
||||
assert isinstance(self.kv_cache_spec, MambaSpec)
|
||||
if (
|
||||
len(new_computed_blocks) > 0
|
||||
and new_computed_blocks[-1].block_hash in self.cached_blocks_this_step
|
||||
):
|
||||
# Mamba can't rely on blocks generated by other requests in the current step
|
||||
# To put it in the next step, we return num_gpu_blocks + 1 so
|
||||
# that kv_cache_manager will think there is no enough blocks to allocte now
|
||||
# and don't schedule it in the current step.
|
||||
return self.block_pool.num_gpu_blocks + 1
|
||||
if self.mamba_cache_mode != "align":
|
||||
# Allocate extra `num_speculative_blocks` blocks for
|
||||
# speculative decoding (MTP/EAGLE) with linear attention.
|
||||
@@ -972,6 +992,20 @@ class MambaManager(SingleTypeKVCacheManager):
|
||||
"""
|
||||
return num_computed_tokens - 1
|
||||
|
||||
def cache_blocks(self, request: Request, num_tokens: int) -> None:
|
||||
num_cached_blocks_before = self.num_cached_block.get(request.request_id, 0)
|
||||
super().cache_blocks(request, num_tokens)
|
||||
num_cached_blocks_after = self.num_cached_block.get(request.request_id, 0)
|
||||
if num_cached_blocks_after > num_cached_blocks_before:
|
||||
for block in self.req_to_blocks[request.request_id][
|
||||
num_cached_blocks_before:num_cached_blocks_after
|
||||
]:
|
||||
assert block.block_hash is not None
|
||||
self.cached_blocks_this_step.add(block.block_hash)
|
||||
|
||||
def new_step_starts(self) -> None:
|
||||
self.cached_blocks_this_step.clear()
|
||||
|
||||
|
||||
class CrossAttentionManager(SingleTypeKVCacheManager):
|
||||
"""Manager for cross-attention KV cache in encoder-decoder models."""
|
||||
|
||||
Reference in New Issue
Block a user