[Perf] Cuda Kernel for Int8 Per Token Group Quant (#21476)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2025-07-25 20:07:07 -04:00
committed by GitHub
parent 41d3082c41
commit 75d29cf4e1
6 changed files with 47 additions and 3 deletions

View File

@@ -1,6 +1,8 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include "../per_token_group_quant_8bit.h"
#include <cmath>
#include "../../dispatch_utils.h"
@@ -336,3 +338,11 @@ void dynamic_scaled_int8_quant(
}
});
}
void per_token_group_quant_int8(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double int8_min, double int8_max) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
int8_min, int8_max);
}