[ROCm][Bugfix] Fix Mamba batched decode producing incorrect output (#32099)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -63,6 +63,7 @@ from vllm.model_executor.models.utils import (
    maybe_prefix,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata
@@ -414,6 +415,13 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
            conv_state_indices=state_indices_tensor_d,
        )

        # ROCm: Ensure contiguous tensor for bcdt_proj linear layer.
        # causal_conv1d_update returns a non-contiguous view (stride 8192
        # instead of 4096 for shape [batch, 4096]), causing incorrect GEMM
        # results when batch > 1 on ROCm.
        if current_platform.is_rocm():
            hidden_states_d = hidden_states_d.contiguous()

        B, C, dt = self._project_ssm_parameters(hidden_states_d)

        # 3. State Space Model sequence transformation
Reference in New Issue
Block a user