[2/n] Migrate per_token_group_quant to torch stable ABI (#36058)

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
This commit is contained in:
mikaylagawarecki
2026-03-25 13:15:13 -04:00
committed by GitHub
parent 1ac2ef2e53
commit bf4cc9ed2d
22 changed files with 207 additions and 133 deletions

View File

@@ -6,15 +6,46 @@
// Register ops with STABLE_TORCH_LIBRARY for libtorch stable ABI compatibility.
// Note: We register under namespace "_C" so ops are accessible as
// torch.ops._C.<op_name> for compatibility with existing code.
STABLE_TORCH_LIBRARY_FRAGMENT(_C, m) {
STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
#ifndef USE_ROCM
m.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
#endif
#ifndef USE_ROCM
// Compute per-token-group FP8 quantized tensor and scaling factor.
// The dummy arguments are here so we can correctly fuse with RMSNorm.
ops.def(
"per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
"output_s, "
"int group_size, float eps, float fp8_min, float fp8_max, bool "
"scale_ue8m0, bool dummy_is_scale_transposed, bool dummy_is_tma_aligned "
") -> ()");
// Compute per-token-group 8-bit quantized tensor and UE8M0-packed,
// TMA-aligned scales for DeepGEMM.
ops.def(
"per_token_group_fp8_quant_packed(Tensor input, Tensor! output_q, "
"Tensor! output_s_packed, int group_size, float eps, float fp8_min, "
"float fp8_max) -> ()");
// Compute per-token-group INT8 quantized tensor and scaling factor.
ops.def(
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
"()");
#endif
}
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
#ifndef USE_ROCM
m.impl("permute_cols", TORCH_BOX(&permute_cols));
ops.impl("permute_cols", TORCH_BOX(&permute_cols));
#endif
#ifndef USE_ROCM
// Per-token group quantization
ops.impl("per_token_group_fp8_quant", TORCH_BOX(&per_token_group_quant_fp8));
ops.impl("per_token_group_fp8_quant_packed",
TORCH_BOX(&per_token_group_quant_8bit_packed));
ops.impl("per_token_group_quant_int8",
TORCH_BOX(&per_token_group_quant_int8));
#endif
}