[Kernel][Core] Add AWQ support to the Marlin kernel (#6612)
Commit 396d92d5e0 (parent 25e778aa16), committed via GitHub.
@@ -19,10 +19,10 @@
  * Adapted from https://github.com/IST-DASLab/marlin
  */
 
-#include "../gptq_marlin/gptq_marlin.cuh"
-#include "../gptq_marlin/gptq_marlin_dtypes.cuh"
+#include "../gptq_marlin/marlin.cuh"
+#include "../gptq_marlin/marlin_dtypes.cuh"
 
-using namespace gptq_marlin;
+using namespace marlin;
 
 #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)      \
   static_assert(std::is_same<scalar_t, half>::value || \
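Note: this hunk only renames the shared headers and namespace from gptq_marlin to marlin, presumably because these helpers are no longer GPTQ-specific now that AWQ is supported. The STATIC_ASSERT_SCALAR_TYPE_VALID macro is cut off in the context above; a minimal sketch of how such a guard can continue (the nv_bfloat16 branch and the message wording are assumptions, not taken from this diff):

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <type_traits>

// Sketch only: restrict kernel instantiations to the two activation dtypes
// that the host dispatcher in the last hunk actually launches.
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)               \
  static_assert(std::is_same<scalar_t, half>::value ||          \
                    std::is_same<scalar_t, nv_bfloat16>::value, \
                "only float16 and bfloat16 are supported")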
@@ -1224,16 +1224,15 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
               ", size_k = ", size_k);
 
   // Verify B
-  TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
-              " is not divisible by tile_size = ", gptq_marlin::tile_size);
-  TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
+  TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
+              " is not divisible by tile_size = ", marlin::tile_size);
+  TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
               "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
-              ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
-  TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
+              ", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
+  TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
               "b_q_weight.size(1) = ", b_q_weight.size(1),
-              " is not divisible by tile_size = ", gptq_marlin::tile_size);
-  int actual_size_n =
-      (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
+              " is not divisible by tile_size = ", marlin::tile_size);
+  int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
   TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
               ", actual_size_n = ", actual_size_n);
 
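Note: the checks in this hunk pin down the packed-weight geometry — b_q_weight must have size_k / tile_size rows, and (b_q_weight.size(1) / tile_size) * pack_factor must reproduce size_n. A standalone sketch of that relationship (the helper name and the tile_size/pack_factor values in the example are illustrative, not part of the diff):

#include <cassert>
#include <cstdint>

// Expected shape of the Marlin-packed quantized weight, inverted from the
// TORCH_CHECKs above: rows = size_k / tile_size,
// cols = size_n * tile_size / pack_factor.
struct PackedShape { int64_t rows; int64_t cols; };

PackedShape expected_b_q_weight_shape(int64_t size_k, int64_t size_n,
                                      int64_t tile_size, int64_t pack_factor) {
  assert(size_k % tile_size == 0);  // mirrors the first check
  PackedShape s{size_k / tile_size, size_n * tile_size / pack_factor};
  // The kernel recomputes size_n as (cols / tile_size) * pack_factor and
  // requires it to match; verify the round trip here as well.
  assert(s.cols % tile_size == 0);
  assert((s.cols / tile_size) * pack_factor == size_n);
  return s;
}

int main() {
  // Illustrative numbers only: 16x16 tiles, 4 quantized values per 32-bit word.
  PackedShape s = expected_b_q_weight_shape(/*size_k=*/4096, /*size_n=*/4096,
                                            /*tile_size=*/16, /*pack_factor=*/4);
  (void)s;  // s.rows == 256, s.cols == 16384
  return 0;
}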
@@ -1274,11 +1273,9 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
     num_groups = b_scales.size(0);
 
   // Verify workspace size
-  TORCH_CHECK(
-      size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
-      ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
-  int min_workspace_size =
-      (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
+  TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
+              ", is not divisible by min_thread_n = ", marlin::min_thread_n);
+  int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
   TORCH_CHECK(workspace.numel() >= min_workspace_size,
               "workspace.numel = ", workspace.numel(),
               " is below min_workspace_size = ", min_workspace_size);
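Note: the workspace check only requires workspace.numel() >= (size_n / min_thread_n) * max_par. A hedged sketch of an allocation that satisfies it (min_thread_n and max_par come from marlin.cuh in the real code and are passed in here; the int32 dtype and zero-initialization are assumptions about how Marlin-style kernels use the buffer for cross-block synchronization, not something this hunk shows):

#include <torch/torch.h>

// Sketch: allocate a workspace large enough to pass the size check above.
torch::Tensor make_marlin_workspace(int64_t size_n, int64_t min_thread_n,
                                    int64_t max_par, torch::Device device) {
  TORCH_CHECK(size_n % min_thread_n == 0, "size_n = ", size_n,
              " is not divisible by min_thread_n = ", min_thread_n);
  const int64_t min_workspace_size = (size_n / min_thread_n) * max_par;
  // Assumption: int32, zero-filled; the diff only constrains numel().
  return torch::zeros(
      {min_workspace_size},
      torch::TensorOptions().dtype(torch::kInt32).device(device));
}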
@@ -1290,14 +1287,14 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         b_scales.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), num_bits, num_groups, group_size, dev,
         at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
-        gptq_marlin::max_par);
+        marlin::max_par);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     fp8_marlin::marlin_mm_f16i4<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
         c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), size_m,
         size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size,
         dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
-        gptq_marlin::max_par);
+        marlin::max_par);
   } else {
     TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16");
   }
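Note: this last hunk only re-qualifies the max_par argument, but it sits inside the dtype dispatch that selects the half or nv_bfloat16 instantiation of the launcher. A stripped-down sketch of that pattern (run_gemm is a hypothetical stand-in for fp8_marlin::marlin_mm_f16i4 with its argument list elided):

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <torch/torch.h>

// Hypothetical stand-in for the templated launcher; the real one forwards raw
// data pointers, problem sizes, the workspace, and tuning parameters.
template <typename scalar_t>
void run_gemm(const torch::Tensor& a) { (void)a; }

// Dispatch on the activation dtype, mirroring the if/else chain above.
void dispatch_fp8_marlin(const torch::Tensor& a) {
  if (a.scalar_type() == at::ScalarType::Half) {
    run_gemm<half>(a);
  } else if (a.scalar_type() == at::ScalarType::BFloat16) {
    run_gemm<nv_bfloat16>(a);
  } else {
    TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16");
  }
}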