[Kernel] Add more dtype support for GGUF dequantization (#15879)
Signed-off-by: lukas.bluebaum <lukas.bluebaum@aleph-alpha.com>
This commit is contained in:
@@ -117,7 +117,7 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
|
||||
elif qweight_type in DEQUANT_TYPES:
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
|
||||
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
|
||||
weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
|
||||
weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
|
||||
y = x @ weight.T
|
||||
else:
|
||||
# Raise an error if the quantization type is not supported.
|
||||
@@ -377,7 +377,7 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
|
||||
x_flat = x.flatten()
|
||||
quant = torch.index_select(qweight, dim=0, index=x_flat)
|
||||
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
|
||||
x_flat.shape[0]).to(self.params_dtype)
|
||||
x_flat.shape[0], self.params_dtype)
|
||||
return dequant.view(*x.shape, hidden_size)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user