[Kernel][Core] Add AWQ support to the Marlin kernel (#6612)

This commit is contained in:
Alexander Matveev
2024-07-21 19:41:42 -04:00
committed by GitHub
parent 25e778aa16
commit 396d92d5e0
21 changed files with 1594 additions and 276 deletions

View File

@@ -30,7 +30,7 @@ inline std::string str(T x) {
return std::to_string(x);
}
namespace marlin {
namespace marlin_dense {
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
@@ -1040,7 +1040,7 @@ void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m,
}
}
} // namespace marlin
} // namespace marlin_dense
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
@@ -1054,24 +1054,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
TORCH_CHECK(size_k == a.size(1),
"Shape mismatch: a.size(1) = " + str(a.size(1)) +
", size_k = " + str(size_k));
TORCH_CHECK(size_k % marlin::tile_size == 0,
"size_k = " + str(size_k) +
" is not divisible by tile_size = " + str(marlin::tile_size));
TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
TORCH_CHECK(size_k % marlin_dense::tile_size == 0,
"size_k = " + str(size_k) + " is not divisible by tile_size = " +
str(marlin_dense::tile_size));
TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = " +
str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
", tile_size = " + str(marlin::tile_size));
", tile_size = " + str(marlin_dense::tile_size));
// Verify N
TORCH_CHECK(b_scales.size(1) == size_n,
"b_scales.size(1) = " + str(b_scales.size(1)) +
", size_n = " + str(size_n));
TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
"b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
" is not divisible by tile_size = " + str(marlin::tile_size));
TORCH_CHECK(
b_q_weight.size(1) % marlin_dense::tile_size == 0,
"b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
" is not divisible by tile_size = " + str(marlin_dense::tile_size));
int actual_size_n =
(b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit;
int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) *
marlin_dense::pack_factor_4bit;
TORCH_CHECK(
size_n == actual_size_n,
"size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
@@ -1116,21 +1117,22 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
"Unexpected groupsize = " + str(groupsize));
// Verify workspace size
TORCH_CHECK(
size_n % marlin::min_thread_n == 0,
"size_n = " + str(size_n) +
", is not divisible by min_thread_n = " + str(marlin::min_thread_n));
int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0,
"size_n = " + str(size_n) +
", is not divisible by min_thread_n = " +
str(marlin_dense::min_thread_n));
int min_workspace_size =
(size_n / marlin_dense::min_thread_n) * marlin_dense::max_par;
TORCH_CHECK(workspace.numel() >= min_workspace_size,
"workspace.numel = " + str(workspace.numel()) +
" is below min_workspace_size = " + str(min_workspace_size));
int dev = a.get_device();
marlin::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
b_scales.data_ptr(), size_m, size_n, size_k,
workspace.data_ptr(), groupsize, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n,
sms, marlin::max_par);
marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
b_scales.data_ptr(), size_m, size_n, size_k,
workspace.data_ptr(), groupsize, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k,
thread_n, sms, marlin_dense::max_par);
return c;
}