[ROCm][Bugfix][CI] Fix hybrid models and their tests (Mamba/Jamba/Bamba) (#32710)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
This commit is contained in:
Andreas Karatzas
2026-02-05 04:01:23 -06:00
committed by GitHub
parent 038914b7c8
commit 3e472e81f9
2 changed files with 11 additions and 0 deletions

View File

@@ -214,6 +214,12 @@ class MambaMixer(MambaBase, CustomOp):
time_step = self.dt_layernorm(time_step.contiguous())
B = self.b_layernorm(B.contiguous())
C = self.c_layernorm(C.contiguous())
# ROCm: tensor from split is non-contiguous, causing incorrect
# GEMM results in dt_proj.
if current_platform.is_rocm():
time_step = time_step.contiguous()
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
return discrete_time_step, B, C