[Kernel][Model] Improve continuous batching for Jamba and Mamba (#9189)
@@ -1,5 +1,6 @@
 import pytest
 
+from tests.utils import multi_gpu_test
 from vllm.sampling_params import SamplingParams
 from vllm.worker.model_runner import _get_graph_batch_size
 
@@ -270,6 +271,30 @@ def test_state_cleanup(
         "could be related to finished_requests_ids")
 
 
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [64])
+def test_jamba_distributed_produces_identical_generation(
+        vllm_runner, model: str, dtype: str, max_tokens: int,
+        example_prompts) -> None:
+
+    with vllm_runner(model, dtype=dtype, tensor_parallel_size=2) as vllm_model:
+        vllm_outputs_tp_2 = vllm_model.generate_greedy(example_prompts,
+                                                       max_tokens)
+
+    with vllm_runner(model, dtype=dtype, tensor_parallel_size=1) as vllm_model:
+        vllm_outputs_tp_1 = vllm_model.generate_greedy(example_prompts,
+                                                       max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_outputs_tp_1,
+        outputs_1_lst=vllm_outputs_tp_2,
+        name_0="vllm_tp_1",
+        name_1="vllm_tp_2",
+    )
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
 def test_model_print(
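Note on what the new distributed test asserts: it runs greedy generation for the same prompts with tensor_parallel_size=2 and again with tensor_parallel_size=1, then requires the two sets of outputs to be identical. The sketch below is a minimal stand-in for the check_outputs_equal helper referenced at the call site; its keyword signature mirrors the diff, but the body and the assumed (token_ids, text) output shape are assumptions for illustration, not vLLM's actual test utility.

from typing import Sequence, Tuple

# Sketch only: vLLM's real check_outputs_equal lives in its test utilities
# and may differ. Each output is assumed to be a (token_ids, text) tuple,
# which is the shape generate_greedy-style helpers typically return.
def check_outputs_equal(
    *,
    outputs_0_lst: Sequence[Tuple[Sequence[int], str]],
    outputs_1_lst: Sequence[Tuple[Sequence[int], str]],
    name_0: str,
    name_1: str,
) -> None:
    assert len(outputs_0_lst) == len(outputs_1_lst)
    for i, (out_0, out_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)):
        ids_0, text_0 = out_0
        ids_1, text_1 = out_1
        # Greedy decoding is deterministic, so both the token ids and the
        # decoded text are expected to match exactly across TP sizes.
        assert ids_0 == ids_1 and text_0 == text_1, (
            f"Prompt {i}: {name_0!r} != {name_1!r}\n"
            f"{name_0}: {text_0!r}\n{name_1}: {text_1!r}")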