[Perf] Cuda Kernel for Int8 Per Token Group Quant (#21476)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
|
||||
#include "../per_token_group_quant_8bit.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
@@ -120,7 +122,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
|
||||
torch::Tensor& output_q,
|
||||
torch::Tensor& output_s, int64_t group_size,
|
||||
double eps, double min_8bit, double max_8bit,
|
||||
bool scale_ue8m0 = false) {
|
||||
bool scale_ue8m0) {
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(output_q.is_contiguous());
|
||||
|
||||
@@ -198,6 +200,8 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
|
||||
input.scalar_type(), "per_token_group_quant_8bit", ([&] {
|
||||
if (dst_type == at::ScalarType::Float8_e4m3fn) {
|
||||
LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
|
||||
} else if (dst_type == at::ScalarType::Char) {
|
||||
LAUNCH_KERNEL(scalar_t, int8_t);
|
||||
}
|
||||
}));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user