[AMD] Add support for GGUF quantization on ROCm (#10254)

This commit is contained in:
kliuae
2024-11-23 13:14:49 +08:00
committed by GitHub
parent 02a43f82a9
commit 7c25fe45a6
11 changed files with 234 additions and 211 deletions

View File

@@ -4,6 +4,8 @@
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "ggml-common.h"
#include "vecdotq.cuh"
#include "dequantize.cuh"
@@ -32,8 +34,8 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
-    amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
-    sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
+    amax = fmaxf(amax, VLLM_SHFL_XOR_SYNC_WIDTH(amax, mask, 32));
+    sum += VLLM_SHFL_XOR_SYNC_WIDTH(sum, mask, 32);
}
const float d = amax / 127;