[Perf] Cuda Kernel for Int8 Per Token Group Quant (#21476)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2025-07-25 20:07:07 -04:00
committed by GitHub
parent 41d3082c41
commit 75d29cf4e1
6 changed files with 47 additions and 3 deletions

View File

@@ -1,6 +1,8 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include "../per_token_group_quant_8bit.h"
#include <cmath>
#include <cuda_fp16.h>
@@ -120,7 +122,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double min_8bit, double max_8bit,
bool scale_ue8m0 = false) {
bool scale_ue8m0) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(output_q.is_contiguous());
@@ -198,6 +200,8 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
input.scalar_type(), "per_token_group_quant_8bit", ([&] {
if (dst_type == at::ScalarType::Float8_e4m3fn) {
LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
} else if (dst_type == at::ScalarType::Char) {
LAUNCH_KERNEL(scalar_t, int8_t);
}
}));