[Kernel] Add more dtype support for GGUF kernels (#14043)

Signed-off-by: SzymonOzog <szymon.ozog@aleph-alpha.com>
Signed-off-by: SzymonOzog <szymon.ozog@gmail.com>
This commit is contained in:
Szymon Ożóg
2025-03-10 15:30:04 +01:00
committed by GitHub
parent b0746fae3d
commit 89cdaa83e7
6 changed files with 318 additions and 266 deletions

View File

@@ -32,7 +32,7 @@ class GGUFConfig(QuantizationConfig):
return "gguf"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
-return [torch.half]
+return [torch.half, torch.bfloat16, torch.float32]
@classmethod
def get_min_capability(cls) -> int:
@@ -134,6 +134,7 @@ class GGUFLinearMethod(LinearMethodBase):
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
+self.params_dtype = params_dtype
output_size_per_partition = sum(output_partition_sizes)
tensor_shape = (output_size_per_partition, input_size_per_partition)
@@ -326,7 +327,7 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
x_flat = x.flatten()
quant = torch.index_select(qweight, dim=0, index=x_flat)
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
-x_flat.shape[0])
+x_flat.shape[0]).to(self.params_dtype)
return dequant.view(*x.shape, hidden_size)