[Kernel] Add more dtype support for GGUF kernels (#14043)

Signed-off-by: SzymonOzog <szymon.ozog@aleph-alpha.com>
Signed-off-by: SzymonOzog <szymon.ozog@gmail.com>

Author: Szymon Ożóg
Date: 2025-03-10 15:30:04 +01:00
Committed by: GitHub
Parent: b0746fae3d
Commit: 89cdaa83e7
6 changed files with 318 additions and 266 deletions


@@ -22,7 +22,7 @@ def get_gguf_sample_tensors(
     return GGUFReader(sample_file).tensors
 
 
-DTYPES = [torch.half]
+DTYPES = [torch.half, torch.bfloat16, torch.float32]
 # Hidden_size for testing, must match the sample file in HF repo,
 # we have `hidden_size = 256, 1024` for test in HF repo currently.
 HIDDEN_SIZES = [256, 1024]
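
As context for the hunk below, a minimal sketch (illustrative, not part of this commit) of how stacked @pytest.mark.parametrize decorators fan out: pytest collects one test case per element of the cross product, so widening DTYPES from one entry to three triples the generated cases for every test that keeps using it.

import pytest
import torch

DTYPES = [torch.half, torch.bfloat16, torch.float32]
HIDDEN_SIZES = [256, 1024]

@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
def test_fanout(hidden_size: int, dtype: torch.dtype):
    # Collected as 3 dtypes x 2 hidden sizes = 6 test cases.
    assert dtype in DTYPES and hidden_size in HIDDEN_SIZES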
@@ -52,7 +52,7 @@ QUANT_TYPES = [
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
-@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("dtype", [torch.half])
 @pytest.mark.parametrize("quant_type", QUANT_TYPES)
 @torch.inference_mode()
 def test_dequantize(hidden_size: int, dtype: torch.dtype,
@@ -122,7 +122,13 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
     ref_output = x @ weight.T
     qweight = torch.tensor(tensor.data, device="cuda")
-    output = ops.ggml_mul_mat_a8(qweight, x, quant_type,
-                                 qweight.shape[0]).to(dtype)
-    torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
+    output = ops.ggml_mul_mat_a8(qweight, x, quant_type, qweight.shape[0])
+    atols = {torch.half: 1, torch.bfloat16: 1.5, torch.float: 1.2}
+    # The test matrix has inputs centered around 0; bfloat16's lower
+    # precision accumulates error that can greatly inflate rtol, since
+    # the outputs are also very close to 0.
+    rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1}
+    torch.testing.assert_close(output,
+                               ref_output,
+                               atol=atols[dtype],
+                               rtol=rtols[dtype])
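
To see why the bfloat16 rtol is so loose, here is a minimal sketch (an illustration under assumed random inputs, not code from this commit): with zero-centered inputs, some entries of x @ W.T land arbitrarily close to 0, while the rounding error accumulated over the reduction dimension stays at the scale of the inputs, so the worst-case relative error explodes even when the absolute error is modest.

import torch

torch.manual_seed(0)

# Zero-centered inputs, as in the test matrix.
x = torch.randn(128, 1024)
w = torch.randn(256, 1024)

ref = x @ w.T  # float32 reference
out = (x.to(torch.bfloat16) @ w.to(torch.bfloat16).T).float()

abs_err = (out - ref).abs()
rel_err = abs_err / ref.abs()
print(f"max abs error:   {abs_err.max().item():.3f}")    # modest
print(f"min |ref| entry: {ref.abs().min().item():.2e}")  # near zero
print(f"max rel error:   {rel_err.max().item():.2e}")    # can be enormous

The atol entries in the diff bound the modest absolute error; the huge bfloat16 rtol exists only to stop near-zero reference entries like the one above from failing the comparison.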