From bf4cc9ed2d1f2075a1dbdeed233a9a44a7d48247 Mon Sep 17 00:00:00 2001 From: mikaylagawarecki Date: Wed, 25 Mar 2026 13:15:13 -0400 Subject: [PATCH] [2/n] Migrate per_token_group_quant to torch stable ABI (#36058) Signed-off-by: Mikayla Gawarecki --- CMakeLists.txt | 9 +- csrc/cache_kernels.cu | 3 +- csrc/layernorm_kernels.cu | 2 +- csrc/layernorm_quant_kernels.cu | 2 +- csrc/libtorch_stable/dispatch_utils.h | 60 ++++++++++++ csrc/libtorch_stable/ops.h | 21 ++++ .../quantization/vectorization.cuh | 4 +- .../quantization/vectorization_utils.cuh | 0 .../w8a8/fp8/per_token_group_quant.cu | 96 ++++++++++--------- .../w8a8/int8/per_token_group_quant.cu | 12 +++ .../w8a8/per_token_group_quant_8bit.h | 10 ++ csrc/libtorch_stable/torch_bindings.cpp | 39 +++++++- csrc/libtorch_stable/torch_utils.h | 4 +- csrc/ops.h | 19 ---- .../fused_kernels/layernorm_utils.cuh | 2 +- .../fused_kernels/quant_conversions.cuh | 2 +- csrc/quantization/w8a8/fp8/common.cu | 2 +- csrc/quantization/w8a8/fp8/common.cuh | 2 +- .../w8a8/int8/per_token_group_quant.cu | 12 --- csrc/quantization/w8a8/int8/scaled_quant.cu | 2 +- .../w8a8/per_token_group_quant_8bit.h | 9 -- csrc/torch_bindings.cpp | 28 ------ 22 files changed, 207 insertions(+), 133 deletions(-) create mode 100644 csrc/libtorch_stable/dispatch_utils.h rename csrc/{ => libtorch_stable}/quantization/vectorization.cuh (88%) rename csrc/{ => libtorch_stable}/quantization/vectorization_utils.cuh (100%) rename csrc/{ => libtorch_stable}/quantization/w8a8/fp8/per_token_group_quant.cu (83%) create mode 100644 csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu create mode 100644 csrc/libtorch_stable/quantization/w8a8/per_token_group_quant_8bit.h delete mode 100644 csrc/quantization/w8a8/int8/per_token_group_quant.cu delete mode 100644 csrc/quantization/w8a8/per_token_group_quant_8bit.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 202eb2b4c..afc02f7fb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -343,9 +343,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" - "csrc/cutlass_extensions/common.cpp" - "csrc/quantization/w8a8/fp8/per_token_group_quant.cu" - "csrc/quantization/w8a8/int8/per_token_group_quant.cu") + "csrc/cutlass_extensions/common.cpp") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" @@ -969,7 +967,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/libtorch_stable/torch_bindings.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") - list(APPEND VLLM_STABLE_EXT_SRC "csrc/libtorch_stable/permute_cols.cu") + list(APPEND VLLM_STABLE_EXT_SRC + "csrc/libtorch_stable/permute_cols.cu" + "csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu" + "csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu") endif() if(VLLM_GPU_LANG STREQUAL "CUDA") diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 4b07f9b53..2b3906df9 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -7,7 +7,8 @@ #include "cuda_utils.h" #include "cuda_compat.h" #include "dispatch_utils.h" -#include "quantization/vectorization_utils.cuh" + +#include "libtorch_stable/quantization/vectorization_utils.cuh" #include "concat_mla_q.cuh" #ifdef USE_ROCM diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index dfc67b933..9766103f7 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -2,7 +2,7 @@ #include "dispatch_utils.h" #include "cub_helpers.h" #include "core/batch_invariant.hpp" -#include "quantization/vectorization_utils.cuh" +#include "libtorch_stable/quantization/vectorization_utils.cuh" #include #include diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu index 0880b8d50..f96386252 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/layernorm_quant_kernels.cu @@ -10,7 +10,7 @@ #include "dispatch_utils.h" #include "cub_helpers.h" #include "core/batch_invariant.hpp" -#include "quantization/vectorization_utils.cuh" +#include "libtorch_stable/quantization/vectorization_utils.cuh" #include #include diff --git a/csrc/libtorch_stable/dispatch_utils.h b/csrc/libtorch_stable/dispatch_utils.h new file mode 100644 index 000000000..5ebba72b1 --- /dev/null +++ b/csrc/libtorch_stable/dispatch_utils.h @@ -0,0 +1,60 @@ +/* + * Stable ABI compatible dispatch utilities for vLLM. + * Adapted from dispatch_utils.h to use PyTorch's header-only (THO_*) macros + * instead of the ATen (AT_*) macros. + * + * These macros use: + * - THO_DISPATCH_SWITCH instead of AT_DISPATCH_SWITCH + * - THO_DISPATCH_CASE instead of AT_DISPATCH_CASE + * - torch::headeronly::ScalarType instead of at::ScalarType + * + * Add more macros here as needed when migrating additional kernels. + */ +#pragma once + +#include +#include +#include + +// Need a special dispatch case macro since we will nest the FP8 dispatch. +// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'. +#define VLLM_STABLE_DISPATCH_FP8_CASE(enum_type, ...) \ + THO_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__) + +#define VLLM_STABLE_DISPATCH_CASE_FLOATING_TYPES(...) \ + THO_DISPATCH_CASE(torch::headeronly::ScalarType::Float, __VA_ARGS__) \ + THO_DISPATCH_CASE(torch::headeronly::ScalarType::Half, __VA_ARGS__) \ + THO_DISPATCH_CASE(torch::headeronly::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_STABLE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + THO_DISPATCH_SWITCH(TYPE, NAME, \ + VLLM_STABLE_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +// FP8 type dispatch - ROCm uses FNUZ format, CUDA uses OCP format +#ifdef USE_ROCM + #define VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(...) \ + VLLM_STABLE_DISPATCH_FP8_CASE( \ + torch::headeronly::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ + VLLM_STABLE_DISPATCH_FP8_CASE( \ + torch::headeronly::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) +#else + #define VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(...) \ + VLLM_STABLE_DISPATCH_FP8_CASE( \ + torch::headeronly::ScalarType::Float8_e4m3fn, __VA_ARGS__) +#endif + +// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'. +// See VLLM_STABLE_DISPATCH_FP8_CASE above. +#define VLLM_STABLE_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \ + THO_DISPATCH_SWITCH(TYPE, NAME, \ + VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__)) + +// Boolean dispatch +#define VLLM_STABLE_DISPATCH_BOOL(expr, const_expr, ...) \ + if (expr) { \ + constexpr bool const_expr = true; \ + __VA_ARGS__(); \ + } else { \ + constexpr bool const_expr = false; \ + __VA_ARGS__(); \ + } diff --git a/csrc/libtorch_stable/ops.h b/csrc/libtorch_stable/ops.h index 5fe1492b8..b74c5c505 100644 --- a/csrc/libtorch_stable/ops.h +++ b/csrc/libtorch_stable/ops.h @@ -6,4 +6,25 @@ #ifndef USE_ROCM torch::stable::Tensor permute_cols(torch::stable::Tensor const& A, torch::stable::Tensor const& perm); + +void per_token_group_quant_fp8(const torch::stable::Tensor& input, + torch::stable::Tensor& output_q, + torch::stable::Tensor& output_s, + int64_t group_size, double eps, double fp8_min, + double fp8_max, bool scale_ue8m0, + bool dummy_is_scale_transposed, + bool dummy_is_tma_aligned); + +// Fused activation quantisation + DeepGEMM-compatible UE8M0-packed scales. +void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input, + torch::stable::Tensor& output_q, + torch::stable::Tensor& output_s_packed, + int64_t group_size, double eps, + double min_8bit, double max_8bit); + +void per_token_group_quant_int8(const torch::stable::Tensor& input, + torch::stable::Tensor& output_q, + torch::stable::Tensor& output_s, + int64_t group_size, double eps, double int8_min, + double int8_max); #endif diff --git a/csrc/quantization/vectorization.cuh b/csrc/libtorch_stable/quantization/vectorization.cuh similarity index 88% rename from csrc/quantization/vectorization.cuh rename to csrc/libtorch_stable/quantization/vectorization.cuh index 11d57a5fa..9d5eea00e 100644 --- a/csrc/quantization/vectorization.cuh +++ b/csrc/libtorch_stable/quantization/vectorization.cuh @@ -4,8 +4,8 @@ */ // Include both AMD and NVIDIA fp8 types to avoid circular import -#include -#include +#include +#include namespace vllm { diff --git a/csrc/quantization/vectorization_utils.cuh b/csrc/libtorch_stable/quantization/vectorization_utils.cuh similarity index 100% rename from csrc/quantization/vectorization_utils.cuh rename to csrc/libtorch_stable/quantization/vectorization_utils.cuh diff --git a/csrc/quantization/w8a8/fp8/per_token_group_quant.cu b/csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu similarity index 83% rename from csrc/quantization/w8a8/fp8/per_token_group_quant.cu rename to csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu index 5174625ad..69b6564be 100644 --- a/csrc/quantization/w8a8/fp8/per_token_group_quant.cu +++ b/csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu @@ -1,16 +1,18 @@ -#include +#include +#include +#include +#include -#include "quantization/w8a8/per_token_group_quant_8bit.h" +#include "libtorch_stable/quantization/w8a8/per_token_group_quant_8bit.h" #include #include -#include - -#include "quantization/vectorization.cuh" -#include "quantization/vectorization_utils.cuh" -#include "dispatch_utils.h" +#include "libtorch_stable/quantization/vectorization.cuh" +#include "libtorch_stable/quantization/vectorization_utils.cuh" +#include "libtorch_stable/dispatch_utils.h" +#include "libtorch_stable/torch_utils.h" __device__ __forceinline__ float GroupReduceMax(float val) { unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff; @@ -154,20 +156,20 @@ inline int GetGroupsPerBlock(int64_t num_groups) { return 1; } -void per_token_group_quant_8bit(const torch::Tensor& input, - torch::Tensor& output_q, - torch::Tensor& output_s, int64_t group_size, - double eps, double min_8bit, double max_8bit, - bool scale_ue8m0) { - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(output_q.is_contiguous()); +void per_token_group_quant_8bit(const torch::stable::Tensor& input, + torch::stable::Tensor& output_q, + torch::stable::Tensor& output_s, + int64_t group_size, double eps, double min_8bit, + double max_8bit, bool scale_ue8m0) { + STD_TORCH_CHECK(input.is_contiguous()); + STD_TORCH_CHECK(output_q.is_contiguous()); const int num_groups = input.numel() / group_size; - TORCH_CHECK(input.numel() % group_size == 0); - TORCH_CHECK(output_s.dim() == 2); + STD_TORCH_CHECK(input.numel() % group_size == 0); + STD_TORCH_CHECK(output_s.dim() == 2); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = get_current_cuda_stream(); constexpr int THREADS_PER_GROUP = 16; @@ -222,11 +224,11 @@ void per_token_group_quant_8bit(const torch::Tensor& input, } \ } while (0) - VLLM_DISPATCH_FLOATING_TYPES( + VLLM_STABLE_DISPATCH_FLOATING_TYPES( input.scalar_type(), "per_token_group_quant_8bit", ([&] { - if (dst_type == at::ScalarType::Float8_e4m3fn) { + if (dst_type == torch::headeronly::ScalarType::Float8_e4m3fn) { LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3); - } else if (dst_type == at::ScalarType::Char) { + } else if (dst_type == torch::headeronly::ScalarType::Char) { LAUNCH_KERNEL(scalar_t, int8_t); } })); @@ -294,41 +296,42 @@ __global__ void per_token_group_quant_8bit_packed_kernel( threads_per_group, y_s, min_8bit, max_8bit); } -void per_token_group_quant_8bit_packed(const torch::Tensor& input, - torch::Tensor& output_q, - torch::Tensor& output_s_packed, +void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input, + torch::stable::Tensor& output_q, + torch::stable::Tensor& output_s_packed, int64_t group_size, double eps, double min_8bit, double max_8bit) { - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(output_q.is_contiguous()); + STD_TORCH_CHECK(input.is_contiguous()); + STD_TORCH_CHECK(output_q.is_contiguous()); const int64_t k = input.size(-1); - TORCH_CHECK(k % group_size == 0, "Last dimension (", k, - ") must be divisible by group_size (", group_size, ")."); + STD_TORCH_CHECK(k % group_size == 0, "Last dimension (", k, + ") must be divisible by group_size (", group_size, ")."); const int64_t mn = input.numel() / k; const int64_t groups_per_row = k / group_size; const int64_t num_groups = mn * groups_per_row; - TORCH_CHECK(output_s_packed.dim() == 2, - "output_s_packed must be 2D, got dim=", output_s_packed.dim(), - "."); + STD_TORCH_CHECK(output_s_packed.dim() == 2, + "output_s_packed must be 2D, got dim=", output_s_packed.dim(), + "."); const int64_t k_num_packed_sfk = (groups_per_row + 3) / 4; const int64_t tma_aligned_mn = ((mn + 3) / 4) * 4; - TORCH_CHECK(output_s_packed.scalar_type() == at::ScalarType::Int, - "output_s_packed must have dtype int32 for UE8M0-packed scales."); + STD_TORCH_CHECK( + output_s_packed.scalar_type() == torch::headeronly::ScalarType::Int, + "output_s_packed must have dtype int32 for UE8M0-packed scales."); // DeepGEMM expects SFA scales in MN-major form with shape // [mn, ceil_div(K, 128 * 4)] and TMA-aligned stride on the last // dimension. - TORCH_CHECK(output_s_packed.size(0) == mn && - output_s_packed.size(1) == k_num_packed_sfk, - "output_s_packed shape must be [", mn, ", ", k_num_packed_sfk, - "], but got [", output_s_packed.size(0), ", ", - output_s_packed.size(1), "]."); + STD_TORCH_CHECK(output_s_packed.size(0) == mn && + output_s_packed.size(1) == k_num_packed_sfk, + "output_s_packed shape must be [", mn, ", ", k_num_packed_sfk, + "], but got [", output_s_packed.size(0), ", ", + output_s_packed.size(1), "]."); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = get_current_cuda_stream(); constexpr int THREADS_PER_GROUP = 16; @@ -340,7 +343,7 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input, // zero-initialize packed scales, since we use atomicOr to accumulate // exponents from different groups. - output_s_packed.zero_(); + torch::stable::zero_(output_s_packed); #define LAUNCH_PACKED_KERNEL(T, DST_DTYPE) \ do { \ @@ -359,14 +362,14 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input, static_cast(max_8bit)); \ } while (0) - VLLM_DISPATCH_FLOATING_TYPES( + VLLM_STABLE_DISPATCH_FLOATING_TYPES( input.scalar_type(), "per_token_group_quant_8bit_packed", ([&] { - if (dst_type == at::ScalarType::Float8_e4m3fn) { + if (dst_type == torch::headeronly::ScalarType::Float8_e4m3fn) { LAUNCH_PACKED_KERNEL(scalar_t, __nv_fp8_e4m3); - } else if (dst_type == at::ScalarType::Char) { + } else if (dst_type == torch::headeronly::ScalarType::Char) { LAUNCH_PACKED_KERNEL(scalar_t, int8_t); } else { - TORCH_CHECK( + STD_TORCH_CHECK( false, "per_token_group_quant_8bit_packed only supports FP8/INT8 " "outputs."); @@ -376,12 +379,13 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input, #undef LAUNCH_PACKED_KERNEL } -void per_token_group_quant_fp8(const torch::Tensor& input, - torch::Tensor& output_q, torch::Tensor& output_s, +void per_token_group_quant_fp8(const torch::stable::Tensor& input, + torch::stable::Tensor& output_q, + torch::stable::Tensor& output_s, int64_t group_size, double eps, double fp8_min, double fp8_max, bool scale_ue8m0, bool dummy_is_scale_transposed = false, bool dummy_is_tma_aligned = false) { per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0); -} \ No newline at end of file +} diff --git a/csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu b/csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu new file mode 100644 index 000000000..2ffbee7ca --- /dev/null +++ b/csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu @@ -0,0 +1,12 @@ +#include + +#include "libtorch_stable/quantization/w8a8/per_token_group_quant_8bit.h" + +void per_token_group_quant_int8(const torch::stable::Tensor& input, + torch::stable::Tensor& output_q, + torch::stable::Tensor& output_s, + int64_t group_size, double eps, double int8_min, + double int8_max) { + per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, + int8_min, int8_max); +} diff --git a/csrc/libtorch_stable/quantization/w8a8/per_token_group_quant_8bit.h b/csrc/libtorch_stable/quantization/w8a8/per_token_group_quant_8bit.h new file mode 100644 index 000000000..d67fd2b33 --- /dev/null +++ b/csrc/libtorch_stable/quantization/w8a8/per_token_group_quant_8bit.h @@ -0,0 +1,10 @@ +#pragma once + +#include + +// 8-bit per-token-group quantization helper used by both FP8 and INT8 +void per_token_group_quant_8bit(const torch::stable::Tensor& input, + torch::stable::Tensor& output_q, + torch::stable::Tensor& output_s, + int64_t group_size, double eps, double min_8bit, + double max_8bit, bool scale_ue8m0 = false); diff --git a/csrc/libtorch_stable/torch_bindings.cpp b/csrc/libtorch_stable/torch_bindings.cpp index 0c0ecaa01..d3b4c395b 100644 --- a/csrc/libtorch_stable/torch_bindings.cpp +++ b/csrc/libtorch_stable/torch_bindings.cpp @@ -6,15 +6,46 @@ // Register ops with STABLE_TORCH_LIBRARY for libtorch stable ABI compatibility. // Note: We register under namespace "_C" so ops are accessible as // torch.ops._C. for compatibility with existing code. -STABLE_TORCH_LIBRARY_FRAGMENT(_C, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) { #ifndef USE_ROCM - m.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); + ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); +#endif + +#ifndef USE_ROCM + // Compute per-token-group FP8 quantized tensor and scaling factor. + // The dummy arguments are here so we can correctly fuse with RMSNorm. + ops.def( + "per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! " + "output_s, " + "int group_size, float eps, float fp8_min, float fp8_max, bool " + "scale_ue8m0, bool dummy_is_scale_transposed, bool dummy_is_tma_aligned " + ") -> ()"); + // Compute per-token-group 8-bit quantized tensor and UE8M0-packed, + // TMA-aligned scales for DeepGEMM. + ops.def( + "per_token_group_fp8_quant_packed(Tensor input, Tensor! output_q, " + "Tensor! output_s_packed, int group_size, float eps, float fp8_min, " + "float fp8_max) -> ()"); + // Compute per-token-group INT8 quantized tensor and scaling factor. + ops.def( + "per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! " + "output_s, int group_size, float eps, float int8_min, float int8_max) -> " + "()"); #endif } -STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) { #ifndef USE_ROCM - m.impl("permute_cols", TORCH_BOX(&permute_cols)); + ops.impl("permute_cols", TORCH_BOX(&permute_cols)); +#endif + +#ifndef USE_ROCM + // Per-token group quantization + ops.impl("per_token_group_fp8_quant", TORCH_BOX(&per_token_group_quant_fp8)); + ops.impl("per_token_group_fp8_quant_packed", + TORCH_BOX(&per_token_group_quant_8bit_packed)); + ops.impl("per_token_group_quant_int8", + TORCH_BOX(&per_token_group_quant_int8)); #endif } diff --git a/csrc/libtorch_stable/torch_utils.h b/csrc/libtorch_stable/torch_utils.h index a615768a9..1bc744fee 100644 --- a/csrc/libtorch_stable/torch_utils.h +++ b/csrc/libtorch_stable/torch_utils.h @@ -1,11 +1,13 @@ #pragma once #include +#include + #include // Utility to get the current CUDA stream for a given device using stable APIs. // Returns a cudaStream_t for use in kernel launches. -inline cudaStream_t get_current_cuda_stream(int32_t device_index) { +inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) { void* stream_ptr = nullptr; TORCH_ERROR_CODE_CHECK( aoti_torch_get_current_cuda_stream(device_index, &stream_ptr)); diff --git a/csrc/ops.h b/csrc/ops.h index ceb8e021c..2e16ef877 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -306,25 +306,6 @@ void silu_and_mul_scaled_fp4_experts_quant( torch::Tensor const& input_offset_by_experts, torch::Tensor const& output_scale_offset_by_experts); -void per_token_group_quant_fp8(const torch::Tensor& input, - torch::Tensor& output_q, torch::Tensor& output_s, - int64_t group_size, double eps, double fp8_min, - double fp8_max, bool scale_ue8m0, - bool dummy_is_scale_transposed, - bool dummy_is_tma_aligned); - -void per_token_group_quant_int8(const torch::Tensor& input, - torch::Tensor& output_q, - torch::Tensor& output_s, int64_t group_size, - double eps, double int8_min, double int8_max); - -// Fused activation quantisation + DeepGEMM-compatible UE8M0-packed scales. -void per_token_group_quant_8bit_packed(const torch::Tensor& input, - torch::Tensor& output_q, - torch::Tensor& output_s_packed, - int64_t group_size, double eps, - double min_8bit, double max_8bit); - #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index 1f0d58352..48b615ebd 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -4,7 +4,7 @@ * __device__ layernorm utilities. */ -#include "quantization/vectorization.cuh" +#include "libtorch_stable/quantization/vectorization.cuh" #include "quantization/utils.cuh" #include "quant_conversions.cuh" diff --git a/csrc/quantization/fused_kernels/quant_conversions.cuh b/csrc/quantization/fused_kernels/quant_conversions.cuh index 2b1eb1d56..3711c47ed 100644 --- a/csrc/quantization/fused_kernels/quant_conversions.cuh +++ b/csrc/quantization/fused_kernels/quant_conversions.cuh @@ -4,7 +4,7 @@ * __device__ helper functions to deal with float -> quant datatype conversion */ -#include "quantization/vectorization.cuh" +#include "libtorch_stable/quantization/vectorization.cuh" // TODO(luka/varun):refactor common.cuh to use this file instead #include "quantization/w8a8/fp8/common.cuh" diff --git a/csrc/quantization/w8a8/fp8/common.cu b/csrc/quantization/w8a8/fp8/common.cu index d07cdd571..52e159d65 100644 --- a/csrc/quantization/w8a8/fp8/common.cu +++ b/csrc/quantization/w8a8/fp8/common.cu @@ -1,7 +1,7 @@ #include "common.cuh" #include "dispatch_utils.h" #include "cub_helpers.h" -#include "quantization/vectorization_utils.cuh" +#include "libtorch_stable/quantization/vectorization_utils.cuh" #include #include #include diff --git a/csrc/quantization/w8a8/fp8/common.cuh b/csrc/quantization/w8a8/fp8/common.cuh index 7838f211c..7a385f516 100644 --- a/csrc/quantization/w8a8/fp8/common.cuh +++ b/csrc/quantization/w8a8/fp8/common.cuh @@ -1,6 +1,6 @@ #pragma once -#include "quantization/vectorization.cuh" +#include "libtorch_stable/quantization/vectorization.cuh" #include "quantization/utils.cuh" #include diff --git a/csrc/quantization/w8a8/int8/per_token_group_quant.cu b/csrc/quantization/w8a8/int8/per_token_group_quant.cu deleted file mode 100644 index 9d808a176..000000000 --- a/csrc/quantization/w8a8/int8/per_token_group_quant.cu +++ /dev/null @@ -1,12 +0,0 @@ -#include -#include - -#include "quantization/w8a8/per_token_group_quant_8bit.h" - -void per_token_group_quant_int8(const torch::Tensor& input, - torch::Tensor& output_q, - torch::Tensor& output_s, int64_t group_size, - double eps, double int8_min, double int8_max) { - per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, - int8_min, int8_max); -} \ No newline at end of file diff --git a/csrc/quantization/w8a8/int8/scaled_quant.cu b/csrc/quantization/w8a8/int8/scaled_quant.cu index be8ecfeac..ae1395a36 100644 --- a/csrc/quantization/w8a8/int8/scaled_quant.cu +++ b/csrc/quantization/w8a8/int8/scaled_quant.cu @@ -5,7 +5,7 @@ #include #include "dispatch_utils.h" -#include "quantization/vectorization_utils.cuh" +#include "libtorch_stable/quantization/vectorization_utils.cuh" #include "cub_helpers.h" static inline __device__ int8_t float_to_int8_rn(float x) { diff --git a/csrc/quantization/w8a8/per_token_group_quant_8bit.h b/csrc/quantization/w8a8/per_token_group_quant_8bit.h deleted file mode 100644 index 25d4ecd11..000000000 --- a/csrc/quantization/w8a8/per_token_group_quant_8bit.h +++ /dev/null @@ -1,9 +0,0 @@ -#pragma once -#include - -// 8-bit per-token-group quantization helper used by both FP8 and INT8 -void per_token_group_quant_8bit(const torch::Tensor& input, - torch::Tensor& output_q, - torch::Tensor& output_s, int64_t group_size, - double eps, double min_8bit, double max_8bit, - bool scale_ue8m0 = false); \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 5892703a8..3bc69c7bb 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -653,34 +653,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor"); #ifndef USE_ROCM - // Compute per-token-group FP8 quantized tensor and scaling factor. - // The dummy arguments are here so we can correctly fuse with RMSNorm. - ops.def( - "per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! " - "output_s, " - "int group_size, float eps, float fp8_min, float fp8_max, bool " - "scale_ue8m0, bool dummy_is_scale_transposed, bool dummy_is_tma_aligned " - ") -> ()"); - ops.impl("per_token_group_fp8_quant", torch::kCUDA, - &per_token_group_quant_fp8); - - // Compute per-token-group 8-bit quantized tensor and UE8M0-packed, - // TMA-aligned scales for DeepGEMM. - ops.def( - "per_token_group_fp8_quant_packed(Tensor input, Tensor! output_q, " - "Tensor! output_s_packed, int group_size, float eps, float fp8_min, " - "float fp8_max) -> ()"); - ops.impl("per_token_group_fp8_quant_packed", torch::kCUDA, - &per_token_group_quant_8bit_packed); - - // Compute per-token-group INT8 quantized tensor and scaling factor. - ops.def( - "per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! " - "output_s, int group_size, float eps, float int8_min, float int8_max) -> " - "()"); - ops.impl("per_token_group_quant_int8", torch::kCUDA, - &per_token_group_quant_int8); - // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel ops.def( "rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "