[2/n] Migrate per_token_group_quant to torch stable ABI (#36058)

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
This commit is contained in:
mikaylagawarecki
2026-03-25 13:15:13 -04:00
committed by GitHub
parent 1ac2ef2e53
commit bf4cc9ed2d
22 changed files with 207 additions and 133 deletions

View File

@@ -343,9 +343,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu" "csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp" "csrc/cutlass_extensions/common.cpp")
"csrc/quantization/w8a8/fp8/per_token_group_quant.cu"
"csrc/quantization/w8a8/int8/per_token_group_quant.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}" SRCS "${VLLM_EXT_SRC}"
@@ -969,7 +967,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/libtorch_stable/torch_bindings.cpp") "csrc/libtorch_stable/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC "csrc/libtorch_stable/permute_cols.cu") list(APPEND VLLM_STABLE_EXT_SRC
"csrc/libtorch_stable/permute_cols.cu"
"csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu"
"csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu")
endif() endif()
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")

View File

@@ -7,7 +7,8 @@
#include "cuda_utils.h" #include "cuda_utils.h"
#include "cuda_compat.h" #include "cuda_compat.h"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include "quantization/vectorization_utils.cuh"
#include "libtorch_stable/quantization/vectorization_utils.cuh"
#include "concat_mla_q.cuh" #include "concat_mla_q.cuh"
#ifdef USE_ROCM #ifdef USE_ROCM

View File

@@ -2,7 +2,7 @@
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include "cub_helpers.h" #include "cub_helpers.h"
#include "core/batch_invariant.hpp" #include "core/batch_invariant.hpp"
#include "quantization/vectorization_utils.cuh" #include "libtorch_stable/quantization/vectorization_utils.cuh"
#include <torch/cuda.h> #include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>

View File

@@ -10,7 +10,7 @@
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include "cub_helpers.h" #include "cub_helpers.h"
#include "core/batch_invariant.hpp" #include "core/batch_invariant.hpp"
#include "quantization/vectorization_utils.cuh" #include "libtorch_stable/quantization/vectorization_utils.cuh"
#include <torch/cuda.h> #include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>

View File

@@ -0,0 +1,60 @@
/*
* Stable ABI compatible dispatch utilities for vLLM.
* Adapted from dispatch_utils.h to use PyTorch's header-only (THO_*) macros
* instead of the ATen (AT_*) macros.
*
* These macros use:
* - THO_DISPATCH_SWITCH instead of AT_DISPATCH_SWITCH
* - THO_DISPATCH_CASE instead of AT_DISPATCH_CASE
* - torch::headeronly::ScalarType instead of at::ScalarType
*
* Add more macros here as needed when migrating additional kernels.
*/
#pragma once
#include <torch/headeronly/core/Dispatch.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/util/Exception.h>
// Need a special dispatch case macro since we will nest the FP8 dispatch.
// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'.
#define VLLM_STABLE_DISPATCH_FP8_CASE(enum_type, ...) \
THO_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)
// Case list covering the floating-point input dtypes dispatched here
// (fp32, fp16, bf16); the dispatched C++ type is named 'scalar_t'.
#define VLLM_STABLE_DISPATCH_CASE_FLOATING_TYPES(...) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::Float, __VA_ARGS__) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::Half, __VA_ARGS__) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::BFloat16, __VA_ARGS__)
// Stable-ABI analogue of VLLM_DISPATCH_FLOATING_TYPES: switch on TYPE and
// invoke __VA_ARGS__ with 'scalar_t' bound to the matching C++ type.
#define VLLM_STABLE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
THO_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_STABLE_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
// FP8 type dispatch - ROCm uses FNUZ format, CUDA uses OCP format
#ifdef USE_ROCM
// ROCm builds accept both the OCP (e4m3fn) and FNUZ (e4m3fnuz) encodings.
#define VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(...) \
VLLM_STABLE_DISPATCH_FP8_CASE( \
torch::headeronly::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
VLLM_STABLE_DISPATCH_FP8_CASE( \
torch::headeronly::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
#else
#define VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(...) \
VLLM_STABLE_DISPATCH_FP8_CASE( \
torch::headeronly::ScalarType::Float8_e4m3fn, __VA_ARGS__)
#endif
// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
// See VLLM_STABLE_DISPATCH_FP8_CASE above.
#define VLLM_STABLE_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
THO_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
// Boolean dispatch: invokes __VA_ARGS__() with `const_expr` bound to a
// compile-time constant matching the runtime value of `expr`, so the
// callee can branch at compile time (e.g. via `if constexpr`).
// Wrapped in do { } while (0) so the macro expands to exactly one
// statement: the bare if/else form breaks under an enclosing un-braced
// if/else (the caller's trailing ';' ends the inner if/else early and
// leaves a dangling `else`).
#define VLLM_STABLE_DISPATCH_BOOL(expr, const_expr, ...) \
  do {                                                   \
    if (expr) {                                          \
      constexpr bool const_expr = true;                  \
      __VA_ARGS__();                                     \
    } else {                                             \
      constexpr bool const_expr = false;                 \
      __VA_ARGS__();                                     \
    }                                                    \
  } while (0)

View File

@@ -6,4 +6,25 @@
#ifndef USE_ROCM #ifndef USE_ROCM
torch::stable::Tensor permute_cols(torch::stable::Tensor const& A, torch::stable::Tensor permute_cols(torch::stable::Tensor const& A,
torch::stable::Tensor const& perm); torch::stable::Tensor const& perm);
// Per-token-group FP8 quantization: writes quantized values into output_q
// and one scale per group into output_s. The dummy_* flags exist only so
// the op schema lines up with the fused RMSNorm+quant path; the current
// implementation forwards to per_token_group_quant_8bit without them.
void per_token_group_quant_fp8(const torch::stable::Tensor& input,
torch::stable::Tensor& output_q,
torch::stable::Tensor& output_s,
int64_t group_size, double eps, double fp8_min,
double fp8_max, bool scale_ue8m0,
bool dummy_is_scale_transposed,
bool dummy_is_tma_aligned);
// Fused activation quantization + DeepGEMM-compatible UE8M0-packed scales.
// output_s_packed must be int32, shaped [mn, ceil_div(groups_per_row, 4)],
// with each int32 accumulating the packed group-scale exponents.
void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
torch::stable::Tensor& output_q,
torch::stable::Tensor& output_s_packed,
int64_t group_size, double eps,
double min_8bit, double max_8bit);
// Per-token-group INT8 quantization: same contract as the FP8 variant but
// clamps to [int8_min, int8_max] and never uses UE8M0 scales.
void per_token_group_quant_int8(const torch::stable::Tensor& input,
torch::stable::Tensor& output_q,
torch::stable::Tensor& output_s,
int64_t group_size, double eps, double int8_min,
double int8_max);
#endif #endif

View File

@@ -4,8 +4,8 @@
*/ */
// Include both AMD and NVIDIA fp8 types to avoid circular import // Include both AMD and NVIDIA fp8 types to avoid circular import
#include <c10/util/Float8_e4m3fnuz.h> #include <torch/headeronly/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e4m3fn.h> #include <torch/headeronly/util/Float8_e4m3fn.h>
namespace vllm { namespace vllm {

View File

@@ -1,16 +1,18 @@
#include <ATen/cuda/CUDAContext.h> #include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/core/ScalarType.h>
#include "quantization/w8a8/per_token_group_quant_8bit.h" #include "libtorch_stable/quantization/w8a8/per_token_group_quant_8bit.h"
#include <cmath> #include <cmath>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include <torch/all.h> #include "libtorch_stable/quantization/vectorization.cuh"
#include "libtorch_stable/quantization/vectorization_utils.cuh"
#include "quantization/vectorization.cuh" #include "libtorch_stable/dispatch_utils.h"
#include "quantization/vectorization_utils.cuh" #include "libtorch_stable/torch_utils.h"
#include "dispatch_utils.h"
__device__ __forceinline__ float GroupReduceMax(float val) { __device__ __forceinline__ float GroupReduceMax(float val) {
unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff; unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;
@@ -154,20 +156,20 @@ inline int GetGroupsPerBlock(int64_t num_groups) {
return 1; return 1;
} }
void per_token_group_quant_8bit(const torch::Tensor& input, void per_token_group_quant_8bit(const torch::stable::Tensor& input,
torch::Tensor& output_q, torch::stable::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size, torch::stable::Tensor& output_s,
double eps, double min_8bit, double max_8bit, int64_t group_size, double eps, double min_8bit,
bool scale_ue8m0) { double max_8bit, bool scale_ue8m0) {
TORCH_CHECK(input.is_contiguous()); STD_TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(output_q.is_contiguous()); STD_TORCH_CHECK(output_q.is_contiguous());
const int num_groups = input.numel() / group_size; const int num_groups = input.numel() / group_size;
TORCH_CHECK(input.numel() % group_size == 0); STD_TORCH_CHECK(input.numel() % group_size == 0);
TORCH_CHECK(output_s.dim() == 2); STD_TORCH_CHECK(output_s.dim() == 2);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = get_current_cuda_stream();
constexpr int THREADS_PER_GROUP = 16; constexpr int THREADS_PER_GROUP = 16;
@@ -222,11 +224,11 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
} \ } \
} while (0) } while (0)
VLLM_DISPATCH_FLOATING_TYPES( VLLM_STABLE_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "per_token_group_quant_8bit", ([&] { input.scalar_type(), "per_token_group_quant_8bit", ([&] {
if (dst_type == at::ScalarType::Float8_e4m3fn) { if (dst_type == torch::headeronly::ScalarType::Float8_e4m3fn) {
LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3); LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3);
} else if (dst_type == at::ScalarType::Char) { } else if (dst_type == torch::headeronly::ScalarType::Char) {
LAUNCH_KERNEL(scalar_t, int8_t); LAUNCH_KERNEL(scalar_t, int8_t);
} }
})); }));
@@ -294,41 +296,42 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
threads_per_group, y_s, min_8bit, max_8bit); threads_per_group, y_s, min_8bit, max_8bit);
} }
void per_token_group_quant_8bit_packed(const torch::Tensor& input, void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
torch::Tensor& output_q, torch::stable::Tensor& output_q,
torch::Tensor& output_s_packed, torch::stable::Tensor& output_s_packed,
int64_t group_size, double eps, int64_t group_size, double eps,
double min_8bit, double max_8bit) { double min_8bit, double max_8bit) {
TORCH_CHECK(input.is_contiguous()); STD_TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(output_q.is_contiguous()); STD_TORCH_CHECK(output_q.is_contiguous());
const int64_t k = input.size(-1); const int64_t k = input.size(-1);
TORCH_CHECK(k % group_size == 0, "Last dimension (", k, STD_TORCH_CHECK(k % group_size == 0, "Last dimension (", k,
") must be divisible by group_size (", group_size, ")."); ") must be divisible by group_size (", group_size, ").");
const int64_t mn = input.numel() / k; const int64_t mn = input.numel() / k;
const int64_t groups_per_row = k / group_size; const int64_t groups_per_row = k / group_size;
const int64_t num_groups = mn * groups_per_row; const int64_t num_groups = mn * groups_per_row;
TORCH_CHECK(output_s_packed.dim() == 2, STD_TORCH_CHECK(output_s_packed.dim() == 2,
"output_s_packed must be 2D, got dim=", output_s_packed.dim(), "output_s_packed must be 2D, got dim=", output_s_packed.dim(),
"."); ".");
const int64_t k_num_packed_sfk = (groups_per_row + 3) / 4; const int64_t k_num_packed_sfk = (groups_per_row + 3) / 4;
const int64_t tma_aligned_mn = ((mn + 3) / 4) * 4; const int64_t tma_aligned_mn = ((mn + 3) / 4) * 4;
TORCH_CHECK(output_s_packed.scalar_type() == at::ScalarType::Int, STD_TORCH_CHECK(
"output_s_packed must have dtype int32 for UE8M0-packed scales."); output_s_packed.scalar_type() == torch::headeronly::ScalarType::Int,
"output_s_packed must have dtype int32 for UE8M0-packed scales.");
// DeepGEMM expects SFA scales in MN-major form with shape // DeepGEMM expects SFA scales in MN-major form with shape
// [mn, ceil_div(K, 128 * 4)] and TMA-aligned stride on the last // [mn, ceil_div(K, 128 * 4)] and TMA-aligned stride on the last
// dimension. // dimension.
TORCH_CHECK(output_s_packed.size(0) == mn && STD_TORCH_CHECK(output_s_packed.size(0) == mn &&
output_s_packed.size(1) == k_num_packed_sfk, output_s_packed.size(1) == k_num_packed_sfk,
"output_s_packed shape must be [", mn, ", ", k_num_packed_sfk, "output_s_packed shape must be [", mn, ", ", k_num_packed_sfk,
"], but got [", output_s_packed.size(0), ", ", "], but got [", output_s_packed.size(0), ", ",
output_s_packed.size(1), "]."); output_s_packed.size(1), "].");
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = get_current_cuda_stream();
constexpr int THREADS_PER_GROUP = 16; constexpr int THREADS_PER_GROUP = 16;
@@ -340,7 +343,7 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input,
// zero-initialize packed scales, since we use atomicOr to accumulate // zero-initialize packed scales, since we use atomicOr to accumulate
// exponents from different groups. // exponents from different groups.
output_s_packed.zero_(); torch::stable::zero_(output_s_packed);
#define LAUNCH_PACKED_KERNEL(T, DST_DTYPE) \ #define LAUNCH_PACKED_KERNEL(T, DST_DTYPE) \
do { \ do { \
@@ -359,14 +362,14 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input,
static_cast<float>(max_8bit)); \ static_cast<float>(max_8bit)); \
} while (0) } while (0)
VLLM_DISPATCH_FLOATING_TYPES( VLLM_STABLE_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "per_token_group_quant_8bit_packed", ([&] { input.scalar_type(), "per_token_group_quant_8bit_packed", ([&] {
if (dst_type == at::ScalarType::Float8_e4m3fn) { if (dst_type == torch::headeronly::ScalarType::Float8_e4m3fn) {
LAUNCH_PACKED_KERNEL(scalar_t, __nv_fp8_e4m3); LAUNCH_PACKED_KERNEL(scalar_t, __nv_fp8_e4m3);
} else if (dst_type == at::ScalarType::Char) { } else if (dst_type == torch::headeronly::ScalarType::Char) {
LAUNCH_PACKED_KERNEL(scalar_t, int8_t); LAUNCH_PACKED_KERNEL(scalar_t, int8_t);
} else { } else {
TORCH_CHECK( STD_TORCH_CHECK(
false, false,
"per_token_group_quant_8bit_packed only supports FP8/INT8 " "per_token_group_quant_8bit_packed only supports FP8/INT8 "
"outputs."); "outputs.");
@@ -376,12 +379,13 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input,
#undef LAUNCH_PACKED_KERNEL #undef LAUNCH_PACKED_KERNEL
} }
void per_token_group_quant_fp8(const torch::Tensor& input, void per_token_group_quant_fp8(const torch::stable::Tensor& input,
torch::Tensor& output_q, torch::Tensor& output_s, torch::stable::Tensor& output_q,
torch::stable::Tensor& output_s,
int64_t group_size, double eps, double fp8_min, int64_t group_size, double eps, double fp8_min,
double fp8_max, bool scale_ue8m0, double fp8_max, bool scale_ue8m0,
bool dummy_is_scale_transposed = false, bool dummy_is_scale_transposed = false,
bool dummy_is_tma_aligned = false) { bool dummy_is_tma_aligned = false) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
fp8_min, fp8_max, scale_ue8m0); fp8_min, fp8_max, scale_ue8m0);
} }

View File

@@ -0,0 +1,12 @@
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/quantization/w8a8/per_token_group_quant_8bit.h"
// Quantize `input` into per-token-group INT8 values (one scale per group).
// Thin wrapper over the shared 8-bit implementation; INT8 never uses
// UE8M0 scales, so that flag is pinned to false explicitly.
void per_token_group_quant_int8(const torch::stable::Tensor& input,
                                torch::stable::Tensor& output_q,
                                torch::stable::Tensor& output_s,
                                int64_t group_size, double eps, double int8_min,
                                double int8_max) {
  per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
                             int8_min, int8_max, /*scale_ue8m0=*/false);
}

View File

@@ -0,0 +1,10 @@
#pragma once
#include <torch/csrc/stable/tensor.h>
// 8-bit per-token-group quantization helper used by both FP8 and INT8
// entry points. Requires `input` and `output_q` to be contiguous, with
// input.numel() divisible by group_size, and `output_s` to be 2-D; writes
// one scale per group of `group_size` consecutive elements.
//   eps                - floor applied when deriving the per-group scale
//   min_8bit/max_8bit  - clamp range of the quantized output values
//   scale_ue8m0        - when true, scales appear to use the UE8M0
//                        exponent encoding (DeepGEMM path) — confirm
//                        against the kernel before relying on this.
void per_token_group_quant_8bit(const torch::stable::Tensor& input,
torch::stable::Tensor& output_q,
torch::stable::Tensor& output_s,
int64_t group_size, double eps, double min_8bit,
double max_8bit, bool scale_ue8m0 = false);

View File

@@ -6,15 +6,46 @@
// Register ops with STABLE_TORCH_LIBRARY for libtorch stable ABI compatibility. // Register ops with STABLE_TORCH_LIBRARY for libtorch stable ABI compatibility.
// Note: We register under namespace "_C" so ops are accessible as // Note: We register under namespace "_C" so ops are accessible as
// torch.ops._C.<op_name> for compatibility with existing code. // torch.ops._C.<op_name> for compatibility with existing code.
STABLE_TORCH_LIBRARY_FRAGMENT(_C, m) { STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
#ifndef USE_ROCM #ifndef USE_ROCM
m.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
#endif
#ifndef USE_ROCM
// Compute per-token-group FP8 quantized tensor and scaling factor.
// The dummy arguments are here so we can correctly fuse with RMSNorm.
ops.def(
"per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
"output_s, "
"int group_size, float eps, float fp8_min, float fp8_max, bool "
"scale_ue8m0, bool dummy_is_scale_transposed, bool dummy_is_tma_aligned "
") -> ()");
// Compute per-token-group 8-bit quantized tensor and UE8M0-packed,
// TMA-aligned scales for DeepGEMM.
ops.def(
"per_token_group_fp8_quant_packed(Tensor input, Tensor! output_q, "
"Tensor! output_s_packed, int group_size, float eps, float fp8_min, "
"float fp8_max) -> ()");
// Compute per-token-group INT8 quantized tensor and scaling factor.
ops.def(
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
"()");
#endif #endif
} }
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
#ifndef USE_ROCM #ifndef USE_ROCM
m.impl("permute_cols", TORCH_BOX(&permute_cols)); ops.impl("permute_cols", TORCH_BOX(&permute_cols));
#endif
#ifndef USE_ROCM
// Per-token group quantization
ops.impl("per_token_group_fp8_quant", TORCH_BOX(&per_token_group_quant_fp8));
ops.impl("per_token_group_fp8_quant_packed",
TORCH_BOX(&per_token_group_quant_8bit_packed));
ops.impl("per_token_group_quant_int8",
TORCH_BOX(&per_token_group_quant_int8));
#endif #endif
} }

View File

@@ -1,11 +1,13 @@
#pragma once #pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim.h> #include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/headeronly/util/shim_utils.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
// Utility to get the current CUDA stream for a given device using stable APIs. // Utility to get the current CUDA stream for a given device using stable APIs.
// Returns a cudaStream_t for use in kernel launches. // Returns a cudaStream_t for use in kernel launches.
inline cudaStream_t get_current_cuda_stream(int32_t device_index) { inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) {
void* stream_ptr = nullptr; void* stream_ptr = nullptr;
TORCH_ERROR_CODE_CHECK( TORCH_ERROR_CODE_CHECK(
aoti_torch_get_current_cuda_stream(device_index, &stream_ptr)); aoti_torch_get_current_cuda_stream(device_index, &stream_ptr));

View File

@@ -306,25 +306,6 @@ void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor const& input_offset_by_experts, torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts); torch::Tensor const& output_scale_offset_by_experts);
void per_token_group_quant_fp8(const torch::Tensor& input,
torch::Tensor& output_q, torch::Tensor& output_s,
int64_t group_size, double eps, double fp8_min,
double fp8_max, bool scale_ue8m0,
bool dummy_is_scale_transposed,
bool dummy_is_tma_aligned);
void per_token_group_quant_int8(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double int8_min, double int8_max);
// Fused activation quantisation + DeepGEMM-compatible UE8M0-packed scales.
void per_token_group_quant_8bit_packed(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s_packed,
int64_t group_size, double eps,
double min_8bit, double max_8bit);
#endif #endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,

View File

@@ -4,7 +4,7 @@
* __device__ layernorm utilities. * __device__ layernorm utilities.
*/ */
#include "quantization/vectorization.cuh" #include "libtorch_stable/quantization/vectorization.cuh"
#include "quantization/utils.cuh" #include "quantization/utils.cuh"
#include "quant_conversions.cuh" #include "quant_conversions.cuh"

View File

@@ -4,7 +4,7 @@
* __device__ helper functions to deal with float -> quant datatype conversion * __device__ helper functions to deal with float -> quant datatype conversion
*/ */
#include "quantization/vectorization.cuh" #include "libtorch_stable/quantization/vectorization.cuh"
// TODO(luka/varun):refactor common.cuh to use this file instead // TODO(luka/varun):refactor common.cuh to use this file instead
#include "quantization/w8a8/fp8/common.cuh" #include "quantization/w8a8/fp8/common.cuh"

View File

@@ -1,7 +1,7 @@
#include "common.cuh" #include "common.cuh"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include "cub_helpers.h" #include "cub_helpers.h"
#include "quantization/vectorization_utils.cuh" #include "libtorch_stable/quantization/vectorization_utils.cuh"
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/Exceptions.h> #include <ATen/cuda/Exceptions.h>
#include <tuple> #include <tuple>

View File

@@ -1,6 +1,6 @@
#pragma once #pragma once
#include "quantization/vectorization.cuh" #include "libtorch_stable/quantization/vectorization.cuh"
#include "quantization/utils.cuh" #include "quantization/utils.cuh"
#include <cmath> #include <cmath>

View File

@@ -1,12 +0,0 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include "quantization/w8a8/per_token_group_quant_8bit.h"
void per_token_group_quant_int8(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double int8_min, double int8_max) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
int8_min, int8_max);
}

View File

@@ -5,7 +5,7 @@
#include <cmath> #include <cmath>
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include "quantization/vectorization_utils.cuh" #include "libtorch_stable/quantization/vectorization_utils.cuh"
#include "cub_helpers.h" #include "cub_helpers.h"
static inline __device__ int8_t float_to_int8_rn(float x) { static inline __device__ int8_t float_to_int8_rn(float x) {

View File

@@ -1,9 +0,0 @@
#pragma once
#include <torch/all.h>
// 8-bit per-token-group quantization helper used by both FP8 and INT8
void per_token_group_quant_8bit(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double min_8bit, double max_8bit,
bool scale_ue8m0 = false);

View File

@@ -653,34 +653,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor"); ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor");
#ifndef USE_ROCM #ifndef USE_ROCM
// Compute per-token-group FP8 quantized tensor and scaling factor.
// The dummy arguments are here so we can correctly fuse with RMSNorm.
ops.def(
"per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
"output_s, "
"int group_size, float eps, float fp8_min, float fp8_max, bool "
"scale_ue8m0, bool dummy_is_scale_transposed, bool dummy_is_tma_aligned "
") -> ()");
ops.impl("per_token_group_fp8_quant", torch::kCUDA,
&per_token_group_quant_fp8);
// Compute per-token-group 8-bit quantized tensor and UE8M0-packed,
// TMA-aligned scales for DeepGEMM.
ops.def(
"per_token_group_fp8_quant_packed(Tensor input, Tensor! output_q, "
"Tensor! output_s_packed, int group_size, float eps, float fp8_min, "
"float fp8_max) -> ()");
ops.impl("per_token_group_fp8_quant_packed", torch::kCUDA,
&per_token_group_quant_8bit_packed);
// Compute per-token-group INT8 quantized tensor and scaling factor.
ops.def(
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
"()");
ops.impl("per_token_group_quant_int8", torch::kCUDA,
&per_token_group_quant_int8);
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
ops.def( ops.def(
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, " "rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "