[2/n] Migrate per_token_group_quant to torch stable ABI (#36058)
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
This commit is contained in:
@@ -653,34 +653,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor");
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Compute per-token-group FP8 quantized tensor and scaling factor.
|
||||
// The dummy arguments are here so we can correctly fuse with RMSNorm.
|
||||
ops.def(
|
||||
"per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
|
||||
"output_s, "
|
||||
"int group_size, float eps, float fp8_min, float fp8_max, bool "
|
||||
"scale_ue8m0, bool dummy_is_scale_transposed, bool dummy_is_tma_aligned "
|
||||
") -> ()");
|
||||
ops.impl("per_token_group_fp8_quant", torch::kCUDA,
|
||||
&per_token_group_quant_fp8);
|
||||
|
||||
// Compute per-token-group 8-bit quantized tensor and UE8M0-packed,
|
||||
// TMA-aligned scales for DeepGEMM.
|
||||
ops.def(
|
||||
"per_token_group_fp8_quant_packed(Tensor input, Tensor! output_q, "
|
||||
"Tensor! output_s_packed, int group_size, float eps, float fp8_min, "
|
||||
"float fp8_max) -> ()");
|
||||
ops.impl("per_token_group_fp8_quant_packed", torch::kCUDA,
|
||||
&per_token_group_quant_8bit_packed);
|
||||
|
||||
// Compute per-token-group INT8 quantized tensor and scaling factor.
|
||||
ops.def(
|
||||
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
|
||||
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
|
||||
"()");
|
||||
ops.impl("per_token_group_quant_int8", torch::kCUDA,
|
||||
&per_token_group_quant_int8);
|
||||
|
||||
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
|
||||
ops.def(
|
||||
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
|
||||
|
||||
Reference in New Issue
Block a user