[Kernel][Model] Varlen prefill + Prefill chunking support for mamba kernels and Jamba model (#8533)
@@ -1,18 +1,16 @@
 import pytest

 from vllm.sampling_params import SamplingParams
 from vllm.worker.model_runner import _get_graph_batch_size

 from ...utils import check_outputs_equal

-MODELS = ["ai21labs/Jamba-tiny-random"]
+MODELS = ["ai21labs/Jamba-tiny-dev"]


-# Fails due to usage of MoE as MLP (E=1), which is different than the HF impl
-# TODO: Fix this with trained model
-@pytest.mark.skip()
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("max_tokens", [10])
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [96])
 def test_models(
     hf_runner,
     vllm_runner,
@@ -22,7 +20,14 @@ def test_models(
     max_tokens: int,
 ) -> None:

-    with hf_runner(model, dtype=dtype) as hf_model:
+    with hf_runner(
+            model,
+            dtype=dtype,
+            model_kwargs={
+                "use_mamba_kernels":
+                False,  # mamba kernels are not installed, so HF
+                # doesn't use them
+            }) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

     with vllm_runner(model, dtype=dtype) as vllm_model:
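
Note: hf_runner forwards model_kwargs to transformers' from_pretrained, so the change above simply loads the HF Jamba model with its optional mamba kernels disabled. A minimal standalone sketch of the same load, assuming the transformers Jamba port (which accepts use_mamba_kernels as a config override and falls back to a pure-PyTorch mamba path when it is False):

# Illustrative sketch only, not the test harness itself.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "ai21labs/Jamba-tiny-dev"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# use_mamba_kernels=False (assumed config override) avoids requiring the
# optional mamba-ssm / causal-conv1d packages on the test machine.
hf_model = AutoModelForCausalLM.from_pretrained(model_id,
                                                use_mamba_kernels=False)
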
@@ -38,8 +43,8 @@ def test_models(


 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["half"])
-@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [96])
 def test_batching(
     vllm_runner,
     example_prompts,
@@ -65,6 +70,107 @@ def test_batching(
     )


+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float16"])
+@pytest.mark.parametrize("max_tokens", [10])
+def test_mamba_prefill_chunking_with_parallel_sampling(
+        hf_runner, vllm_runner, example_prompts, model: str, dtype: str,
+        max_tokens: int) -> None:
+    # Tests prefill chunking in conjunction with n > 1. In this case,
+    # prefill is populated with decoding tokens and we test that it
+    # doesn't fail. This test might fail if the cache is not allocated
+    # correctly for n > 1 decoding steps inside a chunked prefill
+    # forward pass (where we have both prefills and decoding together).
+    sampling_params = SamplingParams(n=3,
+                                     temperature=1,
+                                     seed=0,
+                                     max_tokens=max_tokens)
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            enable_chunked_prefill=True,
+            max_num_batched_tokens=30,
+            max_num_seqs=10  # forces prefill chunks with decoding
+    ) as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)
+
+
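
Note: the small max_num_batched_tokens budget above forces the scheduler to split each prompt's prefill across several forward passes, so later chunks share a batch with decode tokens from the n=3 samples of earlier requests. A rough repro sketch outside the test harness, assuming vLLM's offline LLM entry point and the same engine flags as the test:

# Illustrative sketch: chunked prefill mixed with n > 1 decoding.
from vllm import LLM, SamplingParams

llm = LLM(model="ai21labs/Jamba-tiny-dev",
          dtype="float16",
          enable_chunked_prefill=True,
          max_num_batched_tokens=30,  # smaller than the prompts -> chunking
          max_num_seqs=10)
params = SamplingParams(n=3, temperature=1, seed=0, max_tokens=10)
outputs = llm.generate(["Describe the Mamba state space model."], params)
for sample in outputs[0].outputs:  # one CompletionOutput per parallel sample
    print(sample.text)
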
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [10])
|
||||
def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
|
||||
model: str, dtype: str,
|
||||
max_tokens: int) -> None:
|
||||
# numeric error during prefill chucking produces different generation
|
||||
# compared to w/o prefill chunking for those examples, removed them for now
|
||||
example_prompts.pop(7)
|
||||
example_prompts.pop(2)
|
||||
example_prompts.pop(1)
|
||||
|
||||
with hf_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
model_kwargs={
|
||||
"use_mamba_kernels":
|
||||
False, # mamba kernels are not installed so HF
|
||||
# don't use them
|
||||
}) as hf_model:
|
||||
non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
with vllm_runner(model,
|
||||
dtype=dtype,
|
||||
enable_chunked_prefill=True,
|
||||
max_num_batched_tokens=5,
|
||||
max_num_seqs=2) as vllm_model:
|
||||
chunked = vllm_model.generate_greedy(example_prompts,
|
||||
max_tokens=max_tokens)
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=chunked,
|
||||
outputs_1_lst=non_chunked,
|
||||
name_0="chunked",
|
||||
name_1="non_chunked",
|
||||
)
|
||||
|
||||
|
||||
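
Note: check_outputs_equal comes from the shared test utilities and asserts that the two generations match prompt by prompt. For readers without the repo, a rough standalone equivalent (a hypothetical helper, not the repo's implementation):

from typing import List, Tuple

def outputs_equal(outputs_0: List[Tuple[List[int], str]],
                  outputs_1: List[Tuple[List[int], str]],
                  name_0: str, name_1: str) -> None:
    # Each entry is a (token_ids, text) pair, one per prompt.
    assert len(outputs_0) == len(outputs_1)
    for i, ((ids_0, text_0), (ids_1, text_1)) in enumerate(
            zip(outputs_0, outputs_1)):
        assert ids_0 == ids_1, (
            f"prompt {i}: {name_0}={text_0!r} != {name_1}={text_1!r}")
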
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [15])
|
||||
def test_parallel_sampling(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
for_loop_outputs = []
|
||||
for _ in range(10):
|
||||
for_loop_outputs.append(
|
||||
# using example_prompts index 1 instead of 0 since with 0 the
|
||||
# logprobs get really close and the test doesn't pass
|
||||
vllm_model.generate_greedy([example_prompts[1]], max_tokens)
|
||||
[0])
|
||||
sampling_params = SamplingParams(n=10,
|
||||
temperature=0.001,
|
||||
seed=0,
|
||||
max_tokens=max_tokens)
|
||||
n_lt_1_outputs = vllm_model.generate([example_prompts[1]],
|
||||
sampling_params)
|
||||
token_ids, texts = n_lt_1_outputs[0]
|
||||
n_lt_1_outputs = [(token_id, text)
|
||||
for token_id, text in zip(token_ids, texts)]
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=n_lt_1_outputs,
|
||||
outputs_1_lst=for_loop_outputs,
|
||||
name_0="vllm_n_lt_1_outputs",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
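
Note: the test above leans on seeded, near-greedy sampling: with the seed fixed and temperature close to zero, each of the n parallel samples should decode the same tokens as repeated greedy runs. A condensed sketch of that invariant against the offline API (same assumptions as the earlier sketch):

# Illustrative sketch: all near-greedy seeded samples should be identical.
from vllm import LLM, SamplingParams

llm = LLM(model="ai21labs/Jamba-tiny-dev", dtype="bfloat16")
params = SamplingParams(n=10, temperature=0.001, seed=0, max_tokens=15)
out = llm.generate(["The quick brown fox"], params)
unique_texts = {sample.text for sample in out[0].outputs}
assert len(unique_texts) == 1  # every parallel sample decoded the same text
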
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [20])
|
||||
|
||||