[CI] Fix conch kernel crash on 3D input by reshaping to 2D before GEMM (#38178)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-03-26 11:46:03 -05:00
committed by GitHub
parent b9dbc5c4ab
commit a8e48a7b85

View File

@@ -134,8 +134,11 @@ class ConchLinearKernel(MPLinearKernel):
if group_size == -1:
group_size = x.shape[-1]
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (self.config.partition_weight_shape[1],)
output = mixed_precision_gemm(
x=x,
x=x_2d,
w_q_packed=w_q.data,
w_s=w_s.data,
w_zp=w_zp.data if w_zp is not None else None,
@@ -147,4 +150,4 @@ class ConchLinearKernel(MPLinearKernel):
if bias is not None:
output.add_(bias) # In-place add
return output
return output.reshape(out_shape)