[ROCm][Bugfix] Fix Mamba batched decode producing incorrect output (#32099)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Andreas Karatzas
2026-01-12 23:46:53 -06:00
committed by GitHub
parent 2a719e0865
commit 11b6af5280
2 changed files with 15 additions and 5 deletions


@@ -34,6 +34,7 @@
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_state_update,
 )
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
@@ -195,11 +196,12 @@ class MambaMixer(MambaBase, CustomOp):
     def _ssm_transform(
         self, x: torch.Tensor
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        if self.is_lora_enabled:
-            # Lora kernel requires contiguous tensor.
-            ssm_params = self.x_proj(x.contiguous())[0]
-        else:
-            ssm_params = self.x_proj(x)[0]
+        # LoRA kernel requires contiguous tensor.
+        # ROCm: Non-contiguous tensors cause incorrect GEMM
+        # results when batch > 1.
+        if self.is_lora_enabled or current_platform.is_rocm():
+            x = x.contiguous()
+        ssm_params = self.x_proj(x)[0]
         time_step, B, C = torch.split(
             ssm_params,
             [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
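Review note: a minimal, self-contained sketch (not part of this diff) of the pattern the fix applies — materializing the input as contiguous before the `x_proj` GEMM. The tensor shapes and the `torch.nn.Linear` stand-in for `x_proj` are illustrative assumptions, not vLLM's actual layer.

```python
import torch

x_proj = torch.nn.Linear(16, 8, bias=False)  # stand-in for the mixer's x_proj

# A transposed view is one common way to end up with a non-contiguous input,
# e.g. after slicing or permuting fused intermediate states during decode.
buf = torch.randn(16, 4)  # (hidden, batch)
x = buf.t()               # (batch, hidden) view, non-contiguous
assert not x.is_contiguous()

# The fix: force a densely laid-out copy before the projection, as the diff
# does when LoRA is enabled or when running on ROCm (the batched-decode case).
x = x.contiguous()
out = x_proj(x)
print(out.shape)          # torch.Size([4, 8])
```

The `.contiguous()` call is a no-op (returns `self`) when the tensor is already contiguous, so the extra copy is only paid on the strided-view path the bug report describes.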