[BugFix] Avoid prefix cache hit in the same schedule step for mamba layers (#29387)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Author: Chen Zhang
Date: 2026-02-09 23:41:16 -08:00
Committed by: GitHub
Parent: dab1de9f38
Commit: 97fa8f6590
6 changed files with 78 additions and 3 deletions
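The regression test added below exercises the scenario in the title: two requests with an identical prefix scheduled in the same step must not get a prefix-cache hit against each other's blocks, since for mamba (state-space) layers the cached state for those blocks has not been computed yet when the second request is scheduled. The sketch below is only an illustration of that guard idea with made-up names (CachedBlock, scheduled_step, usable_prefix_blocks); it is not vLLM's actual scheduler or KV-cache-manager code.

from dataclasses import dataclass


@dataclass
class CachedBlock:
    block_hash: int
    scheduled_step: int  # schedule step in which this block's contents are computed


def usable_prefix_blocks(
    candidate_blocks: list[CachedBlock], current_step: int
) -> list[CachedBlock]:
    """Keep the longest cached prefix, stopping at the first block that was
    only scheduled in the current step and so has no valid mamba state yet."""
    prefix: list[CachedBlock] = []
    for block in candidate_blocks:
        if block.scheduled_step == current_step:
            break  # same-step block: treat it as a cache miss
        prefix.append(block)
    return prefix


# Two identical prompts arrive in the same step (step 7): the second request
# sees blocks registered by the first, but none of them count as a hit.
same_step = [CachedBlock(block_hash=h, scheduled_step=7) for h in (1, 2)]
assert usable_prefix_blocks(same_step, current_step=7) == []

# Blocks computed in an earlier step remain reusable as a cached prefix.
earlier = [CachedBlock(block_hash=h, scheduled_step=3) for h in (1, 2)]
assert len(usable_prefix_blocks(earlier, current_step=7)) == 2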


@@ -7,6 +7,7 @@ import pytest
from tests.models.registry import HF_EXAMPLE_MODELS
from tests.utils import multi_gpu_test
from vllm import LLM
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
@@ -769,3 +770,30 @@ def test_apc_multiple_prompts_partial_cached_outputs(
name_0="vllm_no_cache",
name_1=f"vllm_cache_it_{r_idx + 1}",
)


# we have to use a real large model to get reasonable results
# the model can't be a hybrid model as we need block_size 16
@pytest.mark.parametrize("model", ["tiiuae/falcon-mamba-7b"])
def test_apc_common_prefix_same_batch(
    model: str,
    monkeypatch,
) -> None:
    # Required to put the two requests in the same batch
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    llm = LLM(
        model=model,
        enforce_eager=True,
        block_size=16,
        mamba_block_size=16,
        enable_prefix_caching=True,
        seed=42,
    )
    prompts = [
        "hello what is one plus one what is one plus one what is one plus one the answer is",  # noqa: E501
        "hello what is one plus one what is one plus one what is one plus one the answer is",  # noqa: E501
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=20)
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        assert "two" in output.outputs[0].text