[ROCm][Bugfix] Fix Mamba batched decode producing incorrect output (#32099)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Andreas Karatzas
2026-01-12 23:46:53 -06:00
committed by GitHub
parent 2a719e0865
commit 11b6af5280
2 changed files with 15 additions and 5 deletions


@@ -63,6 +63,7 @@ from vllm.model_executor.models.utils import (
    maybe_prefix,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata
@@ -414,6 +415,13 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
    conv_state_indices=state_indices_tensor_d,
)
# ROCm: Ensure contiguous tensor for bcdt_proj linear layer.
# causal_conv1d_update returns a non-contiguous view (stride 8192
# instead of 4096 for shape [batch, 4096]), causing incorrect GEMM
# results when batch > 1 on ROCm.
if current_platform.is_rocm():
    hidden_states_d = hidden_states_d.contiguous()
B, C, dt = self._project_ssm_parameters(hidden_states_d)
# 3. State Space Model sequence transformation
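
For context, the standalone sketch below (not part of the commit, and not the actual Plamo2MambaMixer code path) reproduces the stride pattern the comment describes: a tensor whose shape says [batch, 4096] but whose row stride is still 8192 because it is a view into a wider buffer, and how .contiguous() packs it before a linear projection. The parent buffer, shapes, and the stand-in projection layer are illustrative assumptions.

import torch

# Hypothetical repro of the stride pattern, not vLLM code: slicing the first
# 4096 columns out of a [batch, 8192] buffer yields a view whose shape is
# [batch, 4096] but whose row stride remains 8192, i.e. non-contiguous memory.
parent = torch.randn(4, 8192)
hidden_states_d = parent[:, :4096]
print(hidden_states_d.shape)            # torch.Size([4, 4096])
print(hidden_states_d.stride())         # (8192, 1)
print(hidden_states_d.is_contiguous())  # False

# .contiguous() copies the rows into a packed buffer with stride (4096, 1),
# the dense row-major layout a downstream GEMM expects.
packed = hidden_states_d.contiguous()
print(packed.stride())                  # (4096, 1)
print(packed.is_contiguous())           # True

# Illustrative projection standing in for the bcdt_proj linear layer; feeding
# it the packed tensor avoids any dependence on the original view's strides.
proj = torch.nn.Linear(4096, 1024, bias=False)
out = proj(packed)
print(out.shape)                        # torch.Size([4, 1024])

Note that for batch == 1 the size-1 leading dimension makes the view effectively contiguous, which matches the commit's observation that the wrong results appear only when batch > 1; the fix is gated on current_platform.is_rocm() because the strided input was only observed to misbehave on that backend.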