[Kernel] Add more dtype support for GGUF dequantization (#15879)

Signed-off-by: lukas.bluebaum <lukas.bluebaum@aleph-alpha.com>
This commit is contained in:
LukasBluebaum
2025-04-02 10:58:48 +02:00
committed by GitHub
parent 101f1481f9
commit 90969fb39a
9 changed files with 80 additions and 50 deletions

View File

@@ -117,7 +117,7 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
elif qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
-        weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
+        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
y = x @ weight.T
else:
# Raise an error if the quantization type is not supported.
@@ -377,7 +377,7 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
x_flat = x.flatten()
quant = torch.index_select(qweight, dim=0, index=x_flat)
         dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
-                                      x_flat.shape[0]).to(self.params_dtype)
+                                      x_flat.shape[0], self.params_dtype)
return dequant.view(*x.shape, hidden_size)