[V1] [Hybrid] Mamba2 Automatic Prefix Caching (#25752)

Signed-off-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Thomas Ortner <boh@zurich.ibm.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
Commit: ea507c3a93 (parent: 9705fba7b7)
Author: Stan Wozniak
Date: 2025-10-04 06:34:22 +02:00 (committed by GitHub)

18 changed files with 917 additions and 147 deletions

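The diff below adds end-to-end tests for automatic prefix caching (APC) on hybrid Mamba2 models. For orientation, here is a minimal sketch (not part of this diff) of the offline usage the tests correspond to; the model id is a placeholder, and the engine flags mirror the ones the tests set:

# Minimal sketch, not part of this diff. The model id is a placeholder;
# the flags mirror what the APC tests below configure.
from vllm import LLM, SamplingParams

llm = LLM(
    model="<hybrid-mamba2-model>",     # placeholder checkpoint id
    enable_prefix_caching=True,        # the feature exercised by these tests
    mamba_ssm_cache_dtype="float32",   # the tests pin the SSM state cache to fp32
    gpu_memory_utilization=0.4,
)

prompts = ["The president of the United States is " * 4] * 2
params = SamplingParams(temperature=0.0, max_tokens=64)

# The second, identical prompt should reuse the prefix cache filled by the first.
outputs = llm.generate(prompts, params)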

@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable
import pytest
from tests.models.registry import HF_EXAMPLE_MODELS
@@ -8,7 +10,7 @@ from tests.utils import multi_gpu_test
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams
-from ...utils import check_logprobs_close
+from ...utils import check_logprobs_close, check_outputs_equal
# Mark all tests as hybrid
pytestmark = pytest.mark.hybrid_model
@@ -332,3 +334,413 @@ def test_fp32_cache_state(
name_0="hf",
name_1="vllm",
)
# Helper functions for the APC tests
def _get_vllm_runner_params(model, max_model_len, tensor_parallel_size=1):
    return {
        'model_name': model,
        'enable_prefix_caching': False,
        'max_model_len': max_model_len,
        'tensor_parallel_size': tensor_parallel_size,
        'gpu_memory_utilization': 0.4
    }

def _get_vLLM_output(vllm_runner,
                     kwargs,
                     prompts,
                     max_tokens,
                     num_logprobs,
                     num_repetitions=1,
                     vllm_model=None):
    # Run greedy generation num_repetitions times (with logprobs unless
    # num_logprobs < 0), creating a fresh runner unless one is passed in.
    outs = []
    if vllm_model is None:
        vllm_model = vllm_runner(**kwargs)
    for _ in range(num_repetitions):
        if num_logprobs < 0:
            vllm_output = vllm_model.generate_greedy(prompts, max_tokens)
        else:
            vllm_output = vllm_model.generate_greedy_logprobs(
                prompts, max_tokens, num_logprobs)
        outs.append(vllm_output)
    return outs, vllm_model
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_single_prompt(
hf_runner,
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
n_repetitions: int,
num_logprobs: int,
tensor_parallel_size: int,
) -> None:
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
except ValueError:
pass
compare_operator: Callable = check_logprobs_close \
if num_logprobs > 0 else check_outputs_equal # type: ignore
MULTIPLE = 300
# Sample prompts.
generated_prompts = [MULTIPLE * example_prompts[0]]
max_model_len = max(
len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(
model, max_model_len, tensor_parallel_size=tensor_parallel_size)
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts, max_tokens,
num_logprobs)
vllm_runner_kwargs['enable_prefix_caching'] = True
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts, max_tokens,
num_logprobs, n_repetitions)
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
# In the first repetition, the caches are filled
# In the second repetition, these caches are reused
compare_operator(
outputs_0_lst=vllm_outputs_no_cache[0],
outputs_1_lst=vllm_outputs_cache_itn,
name_0="vllm_no_cache",
name_1=f"vllm_cache_it_{r_idx + 1}",
)
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_single_prompt_block_align_alignment(
hf_runner,
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
n_repetitions: int,
num_logprobs: int,
tensor_parallel_size: int,
) -> None:
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
except ValueError:
pass
compare_operator: Callable = check_logprobs_close \
if num_logprobs > 0 else check_outputs_equal # type: ignore
MULTIPLE = 300
# Sample prompts. This custom prompt is used, as it causes the most issues
generated_prompts = ["The president of the United States is " * MULTIPLE]
max_model_len = max(
len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(
model, max_model_len, tensor_parallel_size=tensor_parallel_size)
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts, max_tokens,
num_logprobs)
vllm_runner_kwargs['enable_prefix_caching'] = True
with vllm_runner(**vllm_runner_kwargs) as vllm_model:
# Retrieve the default mamba state block size
mamba_block_size = vllm_model.llm.llm_engine.cache_config. \
mamba_block_size
# In case the hybrid model does not have the
# "mamba_block_size" assume a fixed constant
if mamba_block_size is None:
mamba_block_size = 512
mamba_block_size_multiplier = 10
for offsets in [
-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3
]:
vllm_runner_kwargs[
'max_num_batched_tokens'] = mamba_block_size_multiplier * \
mamba_block_size - offsets
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts,
max_tokens, num_logprobs,
n_repetitions)
# Check alignment of the output logits when using APC
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
# In the first repetition, the caches are filled
# In the second repetition, these caches are reused
compare_operator(
outputs_0_lst=vllm_outputs_no_cache[0],
outputs_1_lst=vllm_outputs_cache_itn,
name_0="vllm_no_cache",
name_1=f"vllm_cache_it_{r_idx + 1}",
)
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_multiple_prompts_all_cached_outputs(
hf_runner,
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
n_repetitions: int,
num_logprobs: int,
tensor_parallel_size: int,
) -> None:
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
except ValueError:
pass
compare_operator: Callable = check_logprobs_close \
if num_logprobs > 0 else check_outputs_equal # type: ignore
MULTIPLE = 300
# Sample prompts.
generated_prompts = [MULTIPLE * prompt for prompt in example_prompts]
max_model_len = max(
len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(
model, max_model_len, tensor_parallel_size=tensor_parallel_size)
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts, max_tokens,
num_logprobs)
vllm_runner_kwargs['enable_prefix_caching'] = True
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts, max_tokens,
num_logprobs, n_repetitions)
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
# In the first repetition, the caches are filled
# In the second repetition, these caches are reused
compare_operator(
outputs_0_lst=vllm_outputs_no_cache[0],
outputs_1_lst=vllm_outputs_cache_itn,
name_0="vllm_no_cache",
name_1=f"vllm_cache_it_{r_idx + 1}",
)
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_multiple_prompts_block_align_alignment(
hf_runner,
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
n_repetitions: int,
num_logprobs: int,
tensor_parallel_size: int,
) -> None:
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
except ValueError:
pass
compare_operator: Callable = check_logprobs_close \
if num_logprobs > 0 else check_outputs_equal # type: ignore
MULTIPLE = 300
# Sample prompts. This custom prompt is used, as it causes the most issues
prompt_text = "The president of the United States is "
prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31]
generated_prompts = [
prompt_text[offset:] * MULTIPLE for offset in prompt_offsets
]
max_model_len = max(
len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(model, max_model_len,
tensor_parallel_size)
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts, max_tokens,
num_logprobs)
vllm_runner_kwargs['enable_prefix_caching'] = True
with vllm_runner(**vllm_runner_kwargs) as vllm_model:
# Retrieve the default mamba state block size
mamba_block_size = vllm_model.llm.llm_engine.cache_config. \
mamba_block_size
# In case the hybrid model does not have the
# "mamba_block_size" assume a fixed constant
if mamba_block_size is None:
mamba_block_size = 512
mamba_block_size_multiplier = 10
for offsets in [
-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3
]:
vllm_runner_kwargs[
'max_num_batched_tokens'] = mamba_block_size_multiplier * \
mamba_block_size - offsets
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts,
max_tokens, num_logprobs,
n_repetitions)
# Check alignment of the output logits when using APC
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
# In the first repetition, the caches are filled
# In the second repetition, these caches are reused
compare_operator(
outputs_0_lst=vllm_outputs_no_cache[0],
outputs_1_lst=vllm_outputs_cache_itn,
name_0="vllm_no_cache",
name_1=f"vllm_cache_it_{r_idx + 1}",
)
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_multiple_prompts_partial_cached_outputs(
hf_runner,
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
n_repetitions: int,
num_logprobs: int,
tensor_parallel_size: int,
) -> None:
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
except ValueError:
pass
compare_operator: Callable = check_logprobs_close \
if num_logprobs > 0 else check_outputs_equal # type: ignore
MULTIPLE = 300
# Sample prompts.
generated_prompts = [MULTIPLE * prompt for prompt in example_prompts]
max_model_len = max(
len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(
model, max_model_len, tensor_parallel_size=tensor_parallel_size)
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts, max_tokens,
num_logprobs)
# Cache only part of all the prompts
vllm_runner_kwargs['enable_prefix_caching'] = True
vllm_outputs_partial_cache, vllm_model = _get_vLLM_output(
vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens,
num_logprobs)
compare_operator(
outputs_0_lst=vllm_outputs_no_cache[0][:3],
outputs_1_lst=vllm_outputs_partial_cache[0],
name_0="vllm_no_cache",
name_1="vllm_partial_cache",
)
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts,
max_tokens,
num_logprobs,
n_repetitions,
vllm_model=vllm_model)
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
# In the first repetition, the caches are filled
# In the second repetition, these caches are reused
compare_operator(
outputs_0_lst=vllm_outputs_no_cache[0],
outputs_1_lst=vllm_outputs_cache_itn,
name_0="vllm_no_cache",
name_1=f"vllm_cache_it_{r_idx + 1}",
)
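
The two *_block_align_alignment tests above deliberately misalign the chunked-prefill boundary with the mamba state block size by sweeping max_num_batched_tokens around a multiple of the block size. A small worked example of the swept values, assuming the fallback block size of 512 used in the tests:

# Worked example (assumes the fallback mamba_block_size of 512 used above).
mamba_block_size = 512
mamba_block_size_multiplier = 10
for offset in [-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3]:
    max_num_batched_tokens = mamba_block_size_multiplier * mamba_block_size - offset
    print(offset, max_num_batched_tokens)
# offsets -3, 3, 131, 253 give limits of 5123, 5117, 4989, 4867 tokens,
# i.e. prefill chunks that end slightly past, slightly short of, or well
# inside a mamba state block, which is what stresses APC block alignment.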