[Kernel] Add more dtype support for GGUF kernels (#14043)
Signed-off-by: SzymonOzog <szymon.ozog@aleph-alpha.com> Signed-off-by: SzymonOzog <szymon.ozog@gmail.com>
This commit is contained in:
@@ -32,7 +32,7 @@ class GGUFConfig(QuantizationConfig):
|
||||
return "gguf"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
return [torch.half]
|
||||
return [torch.half, torch.bfloat16, torch.float32]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@@ -134,6 +134,7 @@ class GGUFLinearMethod(LinearMethodBase):
|
||||
output_partition_sizes: List[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
self.params_dtype = params_dtype
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
tensor_shape = (output_size_per_partition, input_size_per_partition)
|
||||
@@ -326,7 +327,7 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
|
||||
x_flat = x.flatten()
|
||||
quant = torch.index_select(qweight, dim=0, index=x_flat)
|
||||
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
|
||||
x_flat.shape[0])
|
||||
x_flat.shape[0]).to(self.params_dtype)
|
||||
return dequant.view(*x.shape, hidden_size)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user