[Kernel] Add more dtype support for GGUF kernels (#14043)
Signed-off-by: SzymonOzog <szymon.ozog@aleph-alpha.com>
Signed-off-by: SzymonOzog <szymon.ozog@gmail.com>
@@ -22,7 +22,7 @@ def get_gguf_sample_tensors(
     return GGUFReader(sample_file).tensors
 
 
-DTYPES = [torch.half]
+DTYPES = [torch.half, torch.bfloat16, torch.float32]
 # Hidden_size for testing, must match the sample file in HF repo,
 # we have `hidden_size = 256, 1024` for test in HF repo currently.
 HIDDEN_SIZES = [256, 1024]
@@ -52,7 +52,7 @@ QUANT_TYPES = [
 
 
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
-@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("dtype", [torch.half])
 @pytest.mark.parametrize("quant_type", QUANT_TYPES)
 @torch.inference_mode()
 def test_dequantize(hidden_size: int, dtype: torch.dtype,
@@ -122,7 +122,13 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
     ref_output = x @ weight.T
 
     qweight = torch.tensor(tensor.data, device="cuda")
-    output = ops.ggml_mul_mat_a8(qweight, x, quant_type,
-                                 qweight.shape[0]).to(dtype)
-
-    torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
+    output = ops.ggml_mul_mat_a8(qweight, x, quant_type, qweight.shape[0])
+    atols = {torch.half: 1, torch.bfloat16: 1.5, torch.float: 1.2}
+    # test matrix has inputs centered around 0 and lower precision from
+    # bfloat16 tends to accumulate and can greatly inflate rtol
+    # since outputs are also very close to 0
+    rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1}
+    torch.testing.assert_close(output,
+                               ref_output,
+                               atol=atols[dtype],
+                               rtol=rtols[dtype])
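A minimal sketch (not part of the patch; all names below are illustrative) of the effect the new comment describes, and why the bfloat16 rtol is so large: with inputs centered around 0, the rounding error from bfloat16 is roughly constant in absolute terms, while the reference outputs themselves can land arbitrarily close to 0, so the relative error can explode even when the kernel is behaving correctly.

import torch

torch.manual_seed(0)
# Zero-centered inputs, similar in spirit to the test matrices in test_mmq.
a = torch.randn(256, 1024)
b = torch.randn(1024)

ref = a @ b                                                  # float32 reference
out = (a.to(torch.bfloat16) @ b.to(torch.bfloat16)).float()  # same product with bfloat16 inputs

abs_err = (out - ref).abs()
rel_err = abs_err / ref.abs()
# Absolute error stays modest; relative error blows up on rows where ref ~= 0,
# which is why the test keys atol/rtol on dtype instead of using one pair.
print(f"max abs err: {abs_err.max().item():.3f}")
print(f"max rel err: {rel_err.max().item():.1f}")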