[CI] Fix conch kernel crash on 3D input by reshaping to 2D before GEMM (#38178)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -134,8 +134,11 @@ class ConchLinearKernel(MPLinearKernel):
|
||||
if group_size == -1:
|
||||
group_size = x.shape[-1]
|
||||
|
||||
x_2d = x.reshape(-1, x.shape[-1])
|
||||
out_shape = x.shape[:-1] + (self.config.partition_weight_shape[1],)
|
||||
|
||||
output = mixed_precision_gemm(
|
||||
x=x,
|
||||
x=x_2d,
|
||||
w_q_packed=w_q.data,
|
||||
w_s=w_s.data,
|
||||
w_zp=w_zp.data if w_zp is not None else None,
|
||||
@@ -147,4 +150,4 @@ class ConchLinearKernel(MPLinearKernel):
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output
|
||||
return output.reshape(out_shape)
|
||||
|
||||
Reference in New Issue
Block a user