[ROCm][Bugfix][CI] Fix hybrid models and their tests (Mamba/Jamba/Bamba) (#32710)
Signed-off-by: Andreas Karatzas <akaratza@amd.com> Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com> Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
This commit is contained in:
@@ -214,6 +214,12 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
time_step = self.dt_layernorm(time_step.contiguous())
|
||||
B = self.b_layernorm(B.contiguous())
|
||||
C = self.c_layernorm(C.contiguous())
|
||||
|
||||
# ROCm: tensor from split is non-contiguous, causing incorrect
|
||||
# GEMM results in dt_proj.
|
||||
if current_platform.is_rocm():
|
||||
time_step = time_step.contiguous()
|
||||
|
||||
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
|
||||
return discrete_time_step, B, C
|
||||
|
||||
|
||||
Reference in New Issue
Block a user