[Misc] Disambiguate quantized types via a new ScalarType (#6396)
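Context for the diff below: a bare `int64_t num_bits` cannot distinguish quantized formats that share a bit-width, e.g. GPTQ's zero-pointed 4-bit type (uint4b8, where the stored value is the real value plus 8) from a plain unsigned or signed 4-bit integer. Passing a `vllm::ScalarType` (here via `vllm::ScalarTypeTorchPtr`) carries the bias alongside the width, so a kernel can check for exactly the formats it supports. A minimal standalone sketch of the idea; the fields and constants here are illustrative stand-ins, not the real class from core/scalar_type.hpp:

// Minimal sketch (illustrative only, not the real vllm::ScalarType):
// carrying a bias alongside the bit width lets two same-width formats
// compare as distinct types, which a bare num_bits cannot do.
#include <cstdint>
#include <iostream>
#include <string>

struct ScalarType {
  uint8_t size_bits;  // storage width in bits
  int32_t bias;       // stored value = real value + bias

  bool operator==(ScalarType const& o) const {
    return size_bits == o.size_bits && bias == o.bias;
  }
  std::string str() const {
    return "uint" + std::to_string(size_bits) +
           (bias ? "b" + std::to_string(bias) : "");
  }
};

// Stand-ins for vllm::kU4B8 and vllm::kU8B128.
constexpr ScalarType kU4B8{4, 8};
constexpr ScalarType kU8B128{8, 128};

int main() {
  ScalarType b_q_type = kU4B8;
  // Same arithmetic as the kernel: a 32-bit word holds 32/size_bits values.
  int pack_factor = 32 / b_q_type.size_bits;
  std::cout << b_q_type.str() << ": pack_factor = " << pack_factor << "\n";
  return 0;  // prints "uint4b8: pack_factor = 8"
}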
@@ -27,6 +27,7 @@
 #include <iostream>
 
 #include "common/base.h"
+#include "core/scalar_type.hpp"
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
 
@@ -86,7 +87,8 @@ __global__ void Marlin_24(
 torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                   torch::Tensor& b_meta,
                                   torch::Tensor& b_scales,
-                                  torch::Tensor& workspace, int64_t num_bits,
+                                  torch::Tensor& workspace,
+                                  vllm::ScalarTypeTorchPtr const& b_q_type,
                                   int64_t size_m, int64_t size_n,
                                   int64_t size_k) {
   TORCH_CHECK_NOT_IMPLEMENTED(
@@ -1025,13 +1027,14 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
 torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                   torch::Tensor& b_meta,
                                   torch::Tensor& b_scales,
-                                  torch::Tensor& workspace, int64_t num_bits,
+                                  torch::Tensor& workspace,
+                                  vllm::ScalarTypeTorchPtr const& b_q_type,
                                   int64_t size_m, int64_t size_n,
                                   int64_t size_k) {
   // Verify num_bits
-  TORCH_CHECK(num_bits == 4 || num_bits == 8,
-              "num_bits must be 4 or 8. Got = ", num_bits);
-  int pack_factor = 32 / num_bits;
+  TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
+              "num_bits must be uint4b8 or uint8b128. Got = ", b_q_type->str());
+  int pack_factor = 32 / b_q_type->size_bits();
 
   // Verify M
   TORCH_CHECK(size_m == a.size(0),
@@ -1126,8 +1129,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   marlin_24::marlin_cuda_2_4(
       a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(),
       b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(),
-      num_bits, groupsize, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
-      thread_m, sms, max_par);
+      b_q_type->size_bits(), groupsize, dev,
+      at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par);
 
   return c;
 }
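Worth noting from the last hunk: for the two supported types, the new expression `32 / b_q_type->size_bits()` produces the same pack factors as the old `32 / num_bits` (8 for uint4b8, 4 for uint8b128), so the weight packing is unchanged; only the interface became unambiguous. A throwaway compile-time check of that arithmetic (standalone, not vLLM code):

// Standalone sanity check of the pack-factor arithmetic; the widths
// 4 and 8 are the size_bits of kU4B8 and kU8B128 respectively.
constexpr int pack_factor(int size_bits) { return 32 / size_bits; }

static_assert(pack_factor(4) == 8, "uint4b8 packs 8 values per int32");
static_assert(pack_factor(8) == 4, "uint8b128 packs 4 values per int32");

int main() { return 0; }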