[ROCm][Bugfix] Fix Mamba batched decode producing incorrect output (#32099)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@@ -34,6 +34,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_state_update,
 )
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
 
@@ -195,11 +196,12 @@ class MambaMixer(MambaBase, CustomOp):
     def _ssm_transform(
         self, x: torch.Tensor
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        if self.is_lora_enabled:
-            # Lora kernel requires contiguous tensor.
-            ssm_params = self.x_proj(x.contiguous())[0]
-        else:
-            ssm_params = self.x_proj(x)[0]
+        # LoRA kernel requires contiguous tensor.
+        # ROCm: Non-contiguous tensors cause incorrect GEMM
+        # results when batch > 1.
+        if self.is_lora_enabled or current_platform.is_rocm():
+            x = x.contiguous()
+        ssm_params = self.x_proj(x)[0]
         time_step, B, C = torch.split(
             ssm_params,
             [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
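For context (not part of the commit): a minimal PyTorch sketch of how a non-contiguous activation can reach the x_proj GEMM during batched decode, and why forcing contiguity is a safe workaround. The tensor names and shapes here are hypothetical.

    import torch

    # Slicing the last token out of a (batch, seq_len, hidden) buffer yields
    # a strided view: shape (4, 16) with strides (128, 1) instead of (16, 1).
    hidden = torch.randn(4, 8, 16)   # hypothetical (batch, seq_len, hidden_size)
    x = hidden[:, -1, :]
    print(x.is_contiguous())         # False

    proj = torch.nn.Linear(16, 6, bias=False)  # stand-in for self.x_proj

    # A correct GEMM backend returns identical results for the strided view
    # and its contiguous copy; the ROCm bug fixed here broke that equivalence
    # for batch > 1, which is why the patch calls x.contiguous() up front.
    y_view = proj(x)
    y_copy = proj(x.contiguous())
    print(torch.allclose(y_view, y_copy))      # True on a correct backend

The patch applies .contiguous() unconditionally on ROCm (and for LoRA, which already required it), trading a possible extra copy for correct results.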