[Kernel][Core] Add AWQ support to the Marlin kernel (#6612)
This commit is contained in:
committed by
GitHub
parent
25e778aa16
commit
396d92d5e0
@@ -30,7 +30,7 @@ inline std::string str(T x) {
|
||||
return std::to_string(x);
|
||||
}
|
||||
|
||||
namespace marlin {
|
||||
namespace marlin_dense {
|
||||
|
||||
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
@@ -1040,7 +1040,7 @@ void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m,
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace marlin
|
||||
} // namespace marlin_dense
|
||||
|
||||
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& workspace,
|
||||
@@ -1054,24 +1054,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
TORCH_CHECK(size_k == a.size(1),
|
||||
"Shape mismatch: a.size(1) = " + str(a.size(1)) +
|
||||
", size_k = " + str(size_k));
|
||||
TORCH_CHECK(size_k % marlin::tile_size == 0,
|
||||
"size_k = " + str(size_k) +
|
||||
" is not divisible by tile_size = " + str(marlin::tile_size));
|
||||
TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
|
||||
TORCH_CHECK(size_k % marlin_dense::tile_size == 0,
|
||||
"size_k = " + str(size_k) + " is not divisible by tile_size = " +
|
||||
str(marlin_dense::tile_size));
|
||||
TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0),
|
||||
"Shape mismatch: b_q_weight.size(0) = " +
|
||||
str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
|
||||
", tile_size = " + str(marlin::tile_size));
|
||||
", tile_size = " + str(marlin_dense::tile_size));
|
||||
|
||||
// Verify N
|
||||
TORCH_CHECK(b_scales.size(1) == size_n,
|
||||
"b_scales.size(1) = " + str(b_scales.size(1)) +
|
||||
", size_n = " + str(size_n));
|
||||
TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
|
||||
"b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
|
||||
" is not divisible by tile_size = " + str(marlin::tile_size));
|
||||
TORCH_CHECK(
|
||||
b_q_weight.size(1) % marlin_dense::tile_size == 0,
|
||||
"b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
|
||||
" is not divisible by tile_size = " + str(marlin_dense::tile_size));
|
||||
|
||||
int actual_size_n =
|
||||
(b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit;
|
||||
int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) *
|
||||
marlin_dense::pack_factor_4bit;
|
||||
TORCH_CHECK(
|
||||
size_n == actual_size_n,
|
||||
"size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
|
||||
@@ -1116,21 +1117,22 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
"Unexpected groupsize = " + str(groupsize));
|
||||
|
||||
// Verify workspace size
|
||||
TORCH_CHECK(
|
||||
size_n % marlin::min_thread_n == 0,
|
||||
"size_n = " + str(size_n) +
|
||||
", is not divisible by min_thread_n = " + str(marlin::min_thread_n));
|
||||
int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
|
||||
TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0,
|
||||
"size_n = " + str(size_n) +
|
||||
", is not divisible by min_thread_n = " +
|
||||
str(marlin_dense::min_thread_n));
|
||||
int min_workspace_size =
|
||||
(size_n / marlin_dense::min_thread_n) * marlin_dense::max_par;
|
||||
TORCH_CHECK(workspace.numel() >= min_workspace_size,
|
||||
"workspace.numel = " + str(workspace.numel()) +
|
||||
" is below min_workspace_size = " + str(min_workspace_size));
|
||||
|
||||
int dev = a.get_device();
|
||||
marlin::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
|
||||
b_scales.data_ptr(), size_m, size_n, size_k,
|
||||
workspace.data_ptr(), groupsize, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n,
|
||||
sms, marlin::max_par);
|
||||
marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
|
||||
b_scales.data_ptr(), size_m, size_n, size_k,
|
||||
workspace.data_ptr(), groupsize, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k,
|
||||
thread_n, sms, marlin_dense::max_par);
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user