[V1] [Hybrid] Mamba1 Automatic Prefix Caching (#26377)
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
73444b7b56
commit
00b31a36a2
@@ -19,6 +19,8 @@ pytestmark = pytest.mark.hybrid_model
|
||||
# meaning that it will be used in all tests in this file
|
||||
# The rest of the models will only be tested by test_models
|
||||
|
||||
+APC_MULTIPLY_BY = 300
|
||||
|
||||
SSM_MODELS = [
|
||||
"state-spaces/mamba-130m-hf",
|
||||
"tiiuae/falcon-mamba-tiny-dev",
|
||||
@@ -380,7 +382,7 @@ def _get_vLLM_output(
|
||||
return outs, vllm_model
|
||||
|
||||
|
||||
-@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
|
||||
+@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("n_repetitions", [2])
|
||||
# If num_logprobs is set to -1, then the stringent version
|
||||
@@ -410,10 +412,8 @@ def test_apc_single_prompt(
|
||||
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
)
|
||||
|
||||
-MULTIPLE = 300
|
||||
|
||||
# Sample prompts.
|
||||
-generated_prompts = [MULTIPLE * example_prompts[0]]
|
||||
+generated_prompts = [APC_MULTIPLY_BY * example_prompts[0]]
|
||||
|
||||
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
vllm_runner_kwargs = _get_vllm_runner_params(
|
||||
@@ -446,7 +446,7 @@ def test_apc_single_prompt(
|
||||
)
|
||||
|
||||
|
||||
-@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
|
||||
+@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("n_repetitions", [2])
|
||||
# If num_logprobs is set to -1, then the stringent version
|
||||
@@ -476,10 +476,8 @@ def test_apc_single_prompt_block_align_alignment(
|
||||
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
)
|
||||
|
||||
-MULTIPLE = 300
|
||||
|
||||
# Sample prompts. This custom prompt is used, as it causes the most issues
|
||||
-generated_prompts = ["The president of the United States is " * MULTIPLE]
|
||||
+generated_prompts = ["The president of the United States is " * APC_MULTIPLY_BY]
|
||||
|
||||
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
vllm_runner_kwargs = _get_vllm_runner_params(
|
||||
@@ -528,7 +526,7 @@ def test_apc_single_prompt_block_align_alignment(
|
||||
)
|
||||
|
||||
|
||||
-@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
|
||||
+@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("n_repetitions", [2])
|
||||
# If num_logprobs is set to -1, then the stringent version
|
||||
@@ -558,10 +556,8 @@ def test_apc_multiple_prompts_all_cached_outputs(
|
||||
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
)
|
||||
|
||||
-MULTIPLE = 300
|
||||
|
||||
# Sample prompts.
|
||||
-generated_prompts = [MULTIPLE * prompt for prompt in example_prompts]
|
||||
+generated_prompts = [APC_MULTIPLY_BY * prompt for prompt in example_prompts]
|
||||
|
||||
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
vllm_runner_kwargs = _get_vllm_runner_params(
|
||||
@@ -595,7 +591,7 @@ def test_apc_multiple_prompts_all_cached_outputs(
|
||||
)
|
||||
|
||||
|
||||
-@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
|
||||
+@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("n_repetitions", [2])
|
||||
# If num_logprobs is set to -1, then the stringent version
|
||||
@@ -625,12 +621,12 @@ def test_apc_multiple_prompts_block_align_alignment(
|
||||
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
)
|
||||
|
||||
-MULTIPLE = 300
|
||||
|
||||
# Sample prompts. This custom prompt is used, as it causes the most issues
|
||||
prompt_text = "The president of the United States is "
|
||||
prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31]
|
||||
-generated_prompts = [prompt_text[offset:] * MULTIPLE for offset in prompt_offsets]
|
||||
+generated_prompts = [
|
||||
+prompt_text[offset:] * APC_MULTIPLY_BY for offset in prompt_offsets
|
||||
+]
|
||||
|
||||
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
vllm_runner_kwargs = _get_vllm_runner_params(
|
||||
@@ -679,7 +675,7 @@ def test_apc_multiple_prompts_block_align_alignment(
|
||||
)
|
||||
|
||||
|
||||
-@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
|
||||
+@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("n_repetitions", [2])
|
||||
# If num_logprobs is set to -1, then the stringent version
|
||||
@@ -709,10 +705,8 @@ def test_apc_multiple_prompts_partial_cached_outputs(
|
||||
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
)
|
||||
|
||||
-MULTIPLE = 300
|
||||
|
||||
# Sample prompts.
|
||||
-generated_prompts = [MULTIPLE * prompt for prompt in example_prompts]
|
||||
+generated_prompts = [APC_MULTIPLY_BY * prompt for prompt in example_prompts]
|
||||
|
||||
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
vllm_runner_kwargs = _get_vllm_runner_params(
|
||||
|
||||
Reference in New Issue
Block a user