diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py
index 596c74c73..c22a309ce 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -34,6 +34,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_state_update,
 )
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
 
@@ -195,11 +196,12 @@ class MambaMixer(MambaBase, CustomOp):
     def _ssm_transform(
         self, x: torch.Tensor
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        if self.is_lora_enabled:
-            # Lora kernel requires contiguous tensor.
-            ssm_params = self.x_proj(x.contiguous())[0]
-        else:
-            ssm_params = self.x_proj(x)[0]
+        # LoRA kernel requires contiguous tensor.
+        # ROCm: Non-contiguous tensors cause incorrect GEMM
+        # results when batch > 1.
+        if self.is_lora_enabled or current_platform.is_rocm():
+            x = x.contiguous()
+        ssm_params = self.x_proj(x)[0]
         time_step, B, C = torch.split(
             ssm_params,
             [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py
index 225e131ec..45512d23d 100644
--- a/vllm/model_executor/models/plamo2.py
+++ b/vllm/model_executor/models/plamo2.py
@@ -63,6 +63,7 @@ from vllm.model_executor.models.utils import (
     maybe_prefix,
 )
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backend import AttentionMetadata
@@ -414,6 +415,13 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
             conv_state_indices=state_indices_tensor_d,
         )
 
+        # ROCm: Ensure contiguous tensor for bcdt_proj linear layer.
+        # causal_conv1d_update returns a non-contiguous view (stride 8192
+        # instead of 4096 for shape [batch, 4096]), causing incorrect GEMM
+        # results when batch > 1 on ROCm.
+        if current_platform.is_rocm():
+            hidden_states_d = hidden_states_d.contiguous()
+
         B, C, dt = self._project_ssm_parameters(hidden_states_d)
 
         # 3. State Space Model sequence transformation
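
For reviewers, a minimal, self-contained sketch of the failure mode the plamo2.py comment describes (not part of this diff; the buffer construction and `proj` layer are illustrative stand-ins, not the actual kernel or `bcdt_proj`):

```python
import torch

batch, hidden = 4, 4096

# Simulate a kernel writing into a wider buffer and handing back a view:
# shape [batch, 4096] but row stride 8192, matching the strides cited in
# the diff comment.
buf = torch.randn(batch, 2 * hidden)
view = buf[:, :hidden]

assert view.shape == (batch, hidden)
assert view.stride() == (2 * hidden, 1)  # stride 8192 instead of 4096
assert not view.is_contiguous()          # only non-contiguous for batch > 1

# The fix applied in this diff: force a dense copy before the GEMM.
fixed = view.contiguous()
assert fixed.stride() == (hidden, 1)

# On an affected ROCm build the two projections below would diverge for
# batch > 1; on CPU and CUDA they match, which is why the workaround is
# gated on current_platform.is_rocm().
proj = torch.nn.Linear(hidden, 8, bias=False)
assert torch.allclose(proj(view), proj(fixed))
```

Note that `Tensor.contiguous()` returns the tensor itself when it is already densely laid out, so hoisting the call out of the LoRA-only branch in `_ssm_transform` adds no extra copy on paths where the input was contiguous to begin with.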