AQLM CUDA support (#3287)
Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
15
csrc/ops.h
15
csrc/ops.h
@@ -86,6 +86,21 @@ void gelu_fast(
|
||||
torch::Tensor& input);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor aqlm_gemm(
|
||||
const torch::Tensor& input,
|
||||
const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const torch::Tensor& codebook_partition_sizes,
|
||||
const std::optional<torch::Tensor>& bias
|
||||
);
|
||||
|
||||
torch::Tensor aqlm_dequant(
|
||||
const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& codebook_partition_sizes
|
||||
);
|
||||
|
||||
torch::Tensor awq_gemm(
|
||||
torch::Tensor _in_feats,
|
||||
torch::Tensor _kernel,
|
||||
|
||||
Reference in New Issue
Block a user