[ROCm][Bugfix][CI] Fix hybrid models and their tests (Mamba/Jamba/Bamba) (#32710)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
@@ -8,6 +8,7 @@ import pytest
 from tests.models.registry import HF_EXAMPLE_MODELS
 from tests.utils import multi_gpu_test
 from vllm.engine.arg_utils import EngineArgs
+from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 
@@ -577,6 +578,10 @@ def test_apc_multiple_prompts_all_cached_outputs(
         model, max_model_len, tensor_parallel_size=tensor_parallel_size
     )
     vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"
+    # Reduce the effects of batch variance on ROCm since batch invariance is not
+    # yet supported. See: https://github.com/vllm-project/vllm/issues/27433
+    if current_platform.is_rocm():
+        vllm_runner_kwargs["max_num_seqs"] = 4
 
     vllm_outputs_no_cache, _ = _get_vLLM_output(
         vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs
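For context, a minimal sketch (not part of this commit) of how the runner kwargs set in the diff above map onto an offline vLLM run. It assumes the LLM constructor forwards mamba_ssm_cache_dtype and max_num_seqs as engine arguments, and the model name is a placeholder chosen only because the commit targets Bamba-style hybrid models.

# Hypothetical usage sketch; mirrors the kwargs the patched test sets.
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform

engine_kwargs = {"mamba_ssm_cache_dtype": "float32"}
if current_platform.is_rocm():
    # Cap the batch size on ROCm to reduce batch-variance effects,
    # as batch invariance is not yet supported (see issue #27433).
    engine_kwargs["max_num_seqs"] = 4

# Placeholder model id, not taken from this commit.
llm = LLM(model="ibm-ai-platform/Bamba-9B", max_model_len=1024, **engine_kwargs)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)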