From 3e472e81f99b5bcf494369ee2d26ee9d6ceeffe3 Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Thu, 5 Feb 2026 04:01:23 -0600 Subject: [PATCH] [ROCm][Bugfix][CI] Fix hybrid models and their tests (Mamba/Jamba/Bamba) (#32710) Signed-off-by: Andreas Karatzas Signed-off-by: Matthew Wong Co-authored-by: Matthew Wong --- tests/models/language/generation/test_hybrid.py | 5 +++++ vllm/model_executor/layers/mamba/mamba_mixer.py | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index c3e6d7899..2724f612c 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -8,6 +8,7 @@ import pytest from tests.models.registry import HF_EXAMPLE_MODELS from tests.utils import multi_gpu_test from vllm.engine.arg_utils import EngineArgs +from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher @@ -577,6 +578,10 @@ def test_apc_multiple_prompts_all_cached_outputs( model, max_model_len, tensor_parallel_size=tensor_parallel_size ) vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32" + # Reduce the effects of batch variance on ROCm since batch invariance is not + # yet supported. See: https://github.com/vllm-project/vllm/issues/27433 + if current_platform.is_rocm(): + vllm_runner_kwargs["max_num_seqs"] = 4 vllm_outputs_no_cache, _ = _get_vLLM_output( vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 134e1dfd6..adc643c38 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -214,6 +214,12 @@ class MambaMixer(MambaBase, CustomOp): time_step = self.dt_layernorm(time_step.contiguous()) B = self.b_layernorm(B.contiguous()) C = self.c_layernorm(C.contiguous()) + + # ROCm: tensor from split is non-contiguous, causing incorrect + # GEMM results in dt_proj. + if current_platform.is_rocm(): + time_step = time_step.contiguous() + discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) return discrete_time_step, B, C