[Kernel] Pass a device pointer into the quantize kernel for the scales (#5159)
This commit is contained in:
committed by
GitHub
parent
0ab278ca31
commit
cbb2f59cc8
@@ -28,9 +28,10 @@ namespace vllm {
|
||||
template <typename scalar_t, typename scale_type>
|
||||
__global__ void static_scaled_int8_quant_kernel(
|
||||
const scalar_t* __restrict__ input, int8_t* __restrict__ out,
|
||||
scale_type scale, const int hidden_size) {
|
||||
const scale_type* scale_ptr, const int hidden_size) {
|
||||
const int tid = threadIdx.x;
|
||||
const int token_idx = blockIdx.x;
|
||||
scale_type scale = *scale_ptr;
|
||||
|
||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||
out[token_idx * hidden_size + i] =
|
||||
@@ -39,11 +40,13 @@ __global__ void static_scaled_int8_quant_kernel(
|
||||
}
|
||||
} // namespace vllm
|
||||
|
||||
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor& input, // [..., hidden_size]
|
||||
float scale) {
|
||||
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor const& input, // [..., hidden_size]
|
||||
torch::Tensor const& scale) {
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(scale.numel() == 1);
|
||||
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
dim3 grid(num_tokens);
|
||||
@@ -53,7 +56,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
|
||||
vllm::static_scaled_int8_quant_kernel<scalar_t, float>
|
||||
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
|
||||
out.data_ptr<int8_t>(), scale,
|
||||
hidden_size);
|
||||
out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), hidden_size);
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user