[BugFix][Model] Jamba - Handle aborted requests, Add tests and fix cleanup bug (#6425)

Co-authored-by: Mor Zusman <morz@ai21.com>
This commit is contained in:
Mor Zusman
2024-07-16 04:32:55 +03:00
committed by GitHub
parent d6f3b3d5c4
commit 9ad32dacd9
5 changed files with 176 additions and 24 deletions

View File

@@ -1,5 +1,6 @@
import pytest
from tests.models.utils import check_outputs_equal
from vllm.worker.model_runner import _get_graph_batch_size
MODELS = ["ai21labs/Jamba-tiny-random"]
@@ -34,6 +35,34 @@ def test_models(
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
def test_batching(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# To pass the small model tests, we need full precision.
for_loop_outputs = []
with vllm_runner(model, dtype=dtype) as vllm_model:
for prompt in example_prompts:
for_loop_outputs.append(
vllm_model.generate_greedy([prompt], max_tokens)[0])
batched_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens)
check_outputs_equal(
outputs_0_lst=for_loop_outputs,
outputs_1_lst=batched_outputs,
name_0="for_loop_vllm",
name_1="batched_vllm",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [20])
@@ -60,6 +89,60 @@ def test_mamba_cache_cg_padding(
"Could be related to mamba cache not padded correctly")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
def test_models_preemption_recompute(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# Tests that outputs are identical with and w/o preemtions (recompute)
assert dtype == "float"
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_model.model.llm_engine.scheduler[
0].ENABLE_ARTIFICIAL_PREEMPT = True
preempt_vllm_outputs = vllm_model.generate_greedy(
example_prompts, max_tokens)
vllm_model.model.llm_engine.scheduler[
0].ENABLE_ARTIFICIAL_PREEMPT = False
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
check_outputs_equal(
outputs_0_lst=preempt_vllm_outputs,
outputs_1_lst=vllm_outputs,
name_0="vllm_preepmtions",
name_1="vllm",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
vllm_runner,
model: str,
dtype: str,
example_prompts,
) -> None:
# This test is for verifying that the Jamba inner state management doesn't
# collapse in case where the number of incoming requests and
# finished_requests_ids is larger than the maximum mamba block capacity.
# This could generally happen due to the fact that Jamba does support
# statelessness mechanism where it can cleanup new incoming requests in
# a single step.
try:
with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model:
vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
except ValueError:
pytest.fail("Jamba inner state wasn't cleaned up properly between"
"steps finished requests registered unnecessarily ")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_state_cleanup(