[V1] [Hybrid] Mamba1 Automatic Prefix Caching (#26377)

Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
Asaf Joseph Gardin
2025-11-02 14:16:23 +02:00
committed by GitHub
parent 73444b7b56
commit 00b31a36a2
16 changed files with 442 additions and 153 deletions

View File

@@ -19,6 +19,8 @@ pytestmark = pytest.mark.hybrid_model
# meaning that it will be used in all tests in this file
# The rest of the models will only be tested by test_models
APC_MULTIPLY_BY = 300
SSM_MODELS = [
"state-spaces/mamba-130m-hf",
"tiiuae/falcon-mamba-tiny-dev",
@@ -380,7 +382,7 @@ def _get_vLLM_output(
return outs, vllm_model
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
@@ -410,10 +412,8 @@ def test_apc_single_prompt(
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
)
MULTIPLE = 300
# Sample prompts.
generated_prompts = [MULTIPLE * example_prompts[0]]
generated_prompts = [APC_MULTIPLY_BY * example_prompts[0]]
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(
@@ -446,7 +446,7 @@ def test_apc_single_prompt(
)
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
@@ -476,10 +476,8 @@ def test_apc_single_prompt_block_align_alignment(
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
)
MULTIPLE = 300
# Sample prompts. This custom prompt is used, as it causes the most issues
generated_prompts = ["The president of the United States is " * MULTIPLE]
generated_prompts = ["The president of the United States is " * APC_MULTIPLY_BY]
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(
@@ -528,7 +526,7 @@ def test_apc_single_prompt_block_align_alignment(
)
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
@@ -558,10 +556,8 @@ def test_apc_multiple_prompts_all_cached_outputs(
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
)
MULTIPLE = 300
# Sample prompts.
generated_prompts = [MULTIPLE * prompt for prompt in example_prompts]
generated_prompts = [APC_MULTIPLY_BY * prompt for prompt in example_prompts]
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(
@@ -595,7 +591,7 @@ def test_apc_multiple_prompts_all_cached_outputs(
)
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
@@ -625,12 +621,12 @@ def test_apc_multiple_prompts_block_align_alignment(
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
)
MULTIPLE = 300
# Sample prompts. This custom prompt is used, as it causes the most issues
prompt_text = "The president of the United States is "
prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31]
generated_prompts = [prompt_text[offset:] * MULTIPLE for offset in prompt_offsets]
generated_prompts = [
prompt_text[offset:] * APC_MULTIPLY_BY for offset in prompt_offsets
]
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(
@@ -679,7 +675,7 @@ def test_apc_multiple_prompts_block_align_alignment(
)
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
@@ -709,10 +705,8 @@ def test_apc_multiple_prompts_partial_cached_outputs(
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
)
MULTIPLE = 300
# Sample prompts.
generated_prompts = [MULTIPLE * prompt for prompt in example_prompts]
generated_prompts = [APC_MULTIPLY_BY * prompt for prompt in example_prompts]
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(