[CPU][BugFix] Fix loading of w8a8int models with bias (#33582)
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
This commit is contained in:
@@ -86,9 +86,14 @@ class Dynamic4bitLinearKernel(MPLinearKernel):
|
||||
) # Float32 & Bfloat16 variants require float32 scales
|
||||
scales = scales.view(-1, 1) # Channel-wise scales
|
||||
if layer.bias is not None:
|
||||
layer.bias = layer.bias.to(
|
||||
torch.float32
|
||||
) # Float32 & Bfloat16 variants require float32 bias
|
||||
# Float32 & Bfloat16 variants require float32 bias
|
||||
replace_parameter(
|
||||
layer,
|
||||
"bias",
|
||||
torch.nn.Parameter(
|
||||
layer.bias.to(torch.float32), requires_grad=False
|
||||
),
|
||||
)
|
||||
else:
|
||||
# KleidiAI kernel requires bfloat16 scales with groupwise scheme
|
||||
scales = scales.to(torch.bfloat16)
|
||||
|
||||
Reference in New Issue
Block a user