[ROCm][Bugfix][CI] Fix hybrid models and their tests (Mamba/Jamba/Bamba) (#32710)
Signed-off-by: Andreas Karatzas <akaratza@amd.com> Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com> Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
This commit is contained in:
@@ -8,6 +8,7 @@ import pytest
|
||||
from tests.models.registry import HF_EXAMPLE_MODELS
|
||||
from tests.utils import multi_gpu_test
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||
|
||||
@@ -577,6 +578,10 @@ def test_apc_multiple_prompts_all_cached_outputs(
|
||||
model, max_model_len, tensor_parallel_size=tensor_parallel_size
|
||||
)
|
||||
vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"
|
||||
# Reduce the effects of batch variance on ROCm since batch invariance is not
|
||||
# yet supported. See: https://github.com/vllm-project/vllm/issues/27433
|
||||
if current_platform.is_rocm():
|
||||
vllm_runner_kwargs["max_num_seqs"] = 4
|
||||
|
||||
vllm_outputs_no_cache, _ = _get_vLLM_output(
|
||||
vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs
|
||||
|
||||
@@ -214,6 +214,12 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
time_step = self.dt_layernorm(time_step.contiguous())
|
||||
B = self.b_layernorm(B.contiguous())
|
||||
C = self.c_layernorm(C.contiguous())
|
||||
|
||||
# ROCm: tensor from split is non-contiguous, causing incorrect
|
||||
# GEMM results in dt_proj.
|
||||
if current_platform.is_rocm():
|
||||
time_step = time_step.contiguous()
|
||||
|
||||
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
|
||||
return discrete_time_step, B, C
|
||||
|
||||
|
||||
Reference in New Issue
Block a user