[Misc] Disambiguate quantized types via a new ScalarType (#6396)
This commit is contained in:
@@ -21,6 +21,7 @@
|
||||
|
||||
#include "marlin.cuh"
|
||||
#include "marlin_dtypes.cuh"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
static_assert(std::is_same<scalar_t, half>::value || \
|
||||
@@ -71,14 +72,15 @@ __global__ void Marlin(
|
||||
bool use_fp32_reduce // whether to use fp32 global reduce
|
||||
) {}
|
||||
|
||||
} // namespace gptq_marlin
|
||||
} // namespace marlin
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& b_zeros,
|
||||
torch::Tensor& g_idx, torch::Tensor& perm,
|
||||
torch::Tensor& workspace, int64_t num_bits,
|
||||
torch::Tensor& workspace,
|
||||
vllm::ScalarTypeTorchPtr const& b_q_type,
|
||||
int64_t size_m, int64_t size_n, int64_t size_k,
|
||||
bool is_k_full) {
|
||||
bool is_k_full, bool has_zp) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
|
||||
return torch::empty({1, 1});
|
||||
@@ -1963,18 +1965,29 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
|
||||
|
||||
template <typename scalar_t>
|
||||
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
|
||||
void* s, void* zp, void* g_idx, void* perm, void* a_tmp,
|
||||
int prob_m, int prob_n, int prob_k, void* workspace,
|
||||
int num_bits, bool has_act_order, bool is_k_full,
|
||||
bool has_zp, int num_groups, int group_size, int dev,
|
||||
cudaStream_t stream, int thread_k, int thread_n, int sms,
|
||||
int max_par, bool use_fp32_reduce) {
|
||||
TORCH_CHECK(num_bits == 4 || num_bits == 8,
|
||||
"num_bits must be 4 or 8. Got = ", num_bits);
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m,
|
||||
int prob_n, int prob_k, void* workspace,
|
||||
vllm::ScalarType const& q_type, bool has_act_order,
|
||||
bool is_k_full, bool has_zp, int num_groups, int group_size,
|
||||
int dev, cudaStream_t stream, int thread_k, int thread_n,
|
||||
int sms, int max_par, bool use_fp32_reduce) {
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4 || q_type == vllm::kU8,
|
||||
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
|
||||
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
|
||||
q_type.str());
|
||||
}
|
||||
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
", ", prob_n, ", ", prob_k, "]");
|
||||
|
||||
// TODO: remove alias when we start supporting other 8bit types
|
||||
int num_bits = q_type.size_bits();
|
||||
int tot_m = prob_m;
|
||||
int tot_m_blocks = div_ceil(tot_m, 16);
|
||||
int pad = 16 * tot_m_blocks - tot_m;
|
||||
@@ -2126,19 +2139,28 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gptq_marlin
|
||||
} // namespace marlin
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& b_zeros,
|
||||
torch::Tensor& g_idx, torch::Tensor& perm,
|
||||
torch::Tensor& workspace, int64_t num_bits,
|
||||
torch::Tensor& workspace,
|
||||
vllm::ScalarTypeTorchPtr const& b_q_type,
|
||||
int64_t size_m, int64_t size_n, int64_t size_k,
|
||||
bool is_k_full, bool has_zp,
|
||||
bool use_fp32_reduce) {
|
||||
// Verify num_bits
|
||||
TORCH_CHECK(num_bits == 4 || num_bits == 8,
|
||||
"num_bits must be 4 or 8. Got = ", num_bits);
|
||||
int pack_factor = 32 / num_bits;
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8,
|
||||
"b_q_type must be u4 or u8 when has_zp = True. Got = ",
|
||||
b_q_type->str());
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
|
||||
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
|
||||
b_q_type->str());
|
||||
}
|
||||
|
||||
int pack_factor = 32 / b_q_type->size_bits();
|
||||
|
||||
// Verify A
|
||||
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
|
||||
@@ -2265,21 +2287,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
|
||||
int dev = a.get_device();
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
marlin::marlin_mm_f16i4<half>(
|
||||
marlin::marlin_mm<half>(
|
||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
|
||||
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
||||
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
|
||||
workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
|
||||
workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
marlin::marlin_mm_f16i4<nv_bfloat16>(
|
||||
marlin::marlin_mm<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
|
||||
workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
|
||||
workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user