[Misc] Disambiguate quantized types via a new ScalarType (#6396)

This commit is contained in:
Lucas Wilkinson
2024-08-02 16:51:58 -04:00
committed by GitHub
parent b482b9a5b1
commit a8d604ca2a
29 changed files with 1111 additions and 356 deletions

View File

@@ -27,6 +27,7 @@
#include <iostream>
#include "common/base.h"
#include "core/scalar_type.hpp"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
@@ -86,7 +87,8 @@ __global__ void Marlin_24(
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
torch::Tensor& workspace, int64_t num_bits,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n,
int64_t size_k) {
TORCH_CHECK_NOT_IMPLEMENTED(
@@ -1025,13 +1027,14 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
torch::Tensor& workspace, int64_t num_bits,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n,
int64_t size_k) {
// Verify num_bits
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
int pack_factor = 32 / num_bits;
TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
"num_bits must be uint4b8 or uint8b128. Got = ", b_q_type->str());
int pack_factor = 32 / b_q_type->size_bits();
// Verify M
TORCH_CHECK(size_m == a.size(0),
@@ -1126,8 +1129,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
marlin_24::marlin_cuda_2_4(
a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(),
b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(),
num_bits, groupsize, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
thread_m, sms, max_par);
b_q_type->size_bits(), groupsize, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par);
return c;
}