[2/n] Migrate per_token_group_quant to torch stable ABI (#36058)
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
This commit is contained in:
@@ -343,9 +343,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
"csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
|
"csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
|
||||||
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
|
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
|
||||||
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
||||||
"csrc/cutlass_extensions/common.cpp"
|
"csrc/cutlass_extensions/common.cpp")
|
||||||
"csrc/quantization/w8a8/fp8/per_token_group_quant.cu"
|
|
||||||
"csrc/quantization/w8a8/int8/per_token_group_quant.cu")
|
|
||||||
|
|
||||||
set_gencode_flags_for_srcs(
|
set_gencode_flags_for_srcs(
|
||||||
SRCS "${VLLM_EXT_SRC}"
|
SRCS "${VLLM_EXT_SRC}"
|
||||||
@@ -969,7 +967,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
"csrc/libtorch_stable/torch_bindings.cpp")
|
"csrc/libtorch_stable/torch_bindings.cpp")
|
||||||
|
|
||||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||||
list(APPEND VLLM_STABLE_EXT_SRC "csrc/libtorch_stable/permute_cols.cu")
|
list(APPEND VLLM_STABLE_EXT_SRC
|
||||||
|
"csrc/libtorch_stable/permute_cols.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||||
|
|||||||
@@ -7,7 +7,8 @@
|
|||||||
#include "cuda_utils.h"
|
#include "cuda_utils.h"
|
||||||
#include "cuda_compat.h"
|
#include "cuda_compat.h"
|
||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
#include "quantization/vectorization_utils.cuh"
|
|
||||||
|
#include "libtorch_stable/quantization/vectorization_utils.cuh"
|
||||||
#include "concat_mla_q.cuh"
|
#include "concat_mla_q.cuh"
|
||||||
|
|
||||||
#ifdef USE_ROCM
|
#ifdef USE_ROCM
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
#include "cub_helpers.h"
|
#include "cub_helpers.h"
|
||||||
#include "core/batch_invariant.hpp"
|
#include "core/batch_invariant.hpp"
|
||||||
#include "quantization/vectorization_utils.cuh"
|
#include "libtorch_stable/quantization/vectorization_utils.cuh"
|
||||||
|
|
||||||
#include <torch/cuda.h>
|
#include <torch/cuda.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|||||||
@@ -10,7 +10,7 @@
|
|||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
#include "cub_helpers.h"
|
#include "cub_helpers.h"
|
||||||
#include "core/batch_invariant.hpp"
|
#include "core/batch_invariant.hpp"
|
||||||
#include "quantization/vectorization_utils.cuh"
|
#include "libtorch_stable/quantization/vectorization_utils.cuh"
|
||||||
|
|
||||||
#include <torch/cuda.h>
|
#include <torch/cuda.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|||||||
60
csrc/libtorch_stable/dispatch_utils.h
Normal file
60
csrc/libtorch_stable/dispatch_utils.h
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
/*
|
||||||
|
* Stable ABI compatible dispatch utilities for vLLM.
|
||||||
|
* Adapted from dispatch_utils.h to use PyTorch's header-only (THO_*) macros
|
||||||
|
* instead of the ATen (AT_*) macros.
|
||||||
|
*
|
||||||
|
* These macros use:
|
||||||
|
* - THO_DISPATCH_SWITCH instead of AT_DISPATCH_SWITCH
|
||||||
|
* - THO_DISPATCH_CASE instead of AT_DISPATCH_CASE
|
||||||
|
* - torch::headeronly::ScalarType instead of at::ScalarType
|
||||||
|
*
|
||||||
|
* Add more macros here as needed when migrating additional kernels.
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/headeronly/core/Dispatch.h>
|
||||||
|
#include <torch/headeronly/core/ScalarType.h>
|
||||||
|
#include <torch/headeronly/util/Exception.h>
|
||||||
|
|
||||||
|
// Need a special dispatch case macro since we will nest the FP8 dispatch.
|
||||||
|
// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'.
|
||||||
|
#define VLLM_STABLE_DISPATCH_FP8_CASE(enum_type, ...) \
|
||||||
|
THO_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define VLLM_STABLE_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||||
|
THO_DISPATCH_CASE(torch::headeronly::ScalarType::Float, __VA_ARGS__) \
|
||||||
|
THO_DISPATCH_CASE(torch::headeronly::ScalarType::Half, __VA_ARGS__) \
|
||||||
|
THO_DISPATCH_CASE(torch::headeronly::ScalarType::BFloat16, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define VLLM_STABLE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||||
|
THO_DISPATCH_SWITCH(TYPE, NAME, \
|
||||||
|
VLLM_STABLE_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||||
|
|
||||||
|
// FP8 type dispatch - ROCm uses FNUZ format, CUDA uses OCP format
|
||||||
|
#ifdef USE_ROCM
|
||||||
|
#define VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(...) \
|
||||||
|
VLLM_STABLE_DISPATCH_FP8_CASE( \
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
|
||||||
|
VLLM_STABLE_DISPATCH_FP8_CASE( \
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
|
||||||
|
#else
|
||||||
|
#define VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(...) \
|
||||||
|
VLLM_STABLE_DISPATCH_FP8_CASE( \
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fn, __VA_ARGS__)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
|
||||||
|
// See VLLM_STABLE_DISPATCH_FP8_CASE above.
|
||||||
|
#define VLLM_STABLE_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
|
||||||
|
THO_DISPATCH_SWITCH(TYPE, NAME, \
|
||||||
|
VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
|
||||||
|
|
||||||
|
// Boolean dispatch
|
||||||
|
#define VLLM_STABLE_DISPATCH_BOOL(expr, const_expr, ...) \
|
||||||
|
if (expr) { \
|
||||||
|
constexpr bool const_expr = true; \
|
||||||
|
__VA_ARGS__(); \
|
||||||
|
} else { \
|
||||||
|
constexpr bool const_expr = false; \
|
||||||
|
__VA_ARGS__(); \
|
||||||
|
}
|
||||||
@@ -6,4 +6,25 @@
|
|||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
torch::stable::Tensor permute_cols(torch::stable::Tensor const& A,
|
torch::stable::Tensor permute_cols(torch::stable::Tensor const& A,
|
||||||
torch::stable::Tensor const& perm);
|
torch::stable::Tensor const& perm);
|
||||||
|
|
||||||
|
void per_token_group_quant_fp8(const torch::stable::Tensor& input,
|
||||||
|
torch::stable::Tensor& output_q,
|
||||||
|
torch::stable::Tensor& output_s,
|
||||||
|
int64_t group_size, double eps, double fp8_min,
|
||||||
|
double fp8_max, bool scale_ue8m0,
|
||||||
|
bool dummy_is_scale_transposed,
|
||||||
|
bool dummy_is_tma_aligned);
|
||||||
|
|
||||||
|
// Fused activation quantisation + DeepGEMM-compatible UE8M0-packed scales.
|
||||||
|
void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
|
||||||
|
torch::stable::Tensor& output_q,
|
||||||
|
torch::stable::Tensor& output_s_packed,
|
||||||
|
int64_t group_size, double eps,
|
||||||
|
double min_8bit, double max_8bit);
|
||||||
|
|
||||||
|
void per_token_group_quant_int8(const torch::stable::Tensor& input,
|
||||||
|
torch::stable::Tensor& output_q,
|
||||||
|
torch::stable::Tensor& output_s,
|
||||||
|
int64_t group_size, double eps, double int8_min,
|
||||||
|
double int8_max);
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -4,8 +4,8 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
// Include both AMD and NVIDIA fp8 types to avoid circular import
|
// Include both AMD and NVIDIA fp8 types to avoid circular import
|
||||||
#include <c10/util/Float8_e4m3fnuz.h>
|
#include <torch/headeronly/util/Float8_e4m3fnuz.h>
|
||||||
#include <c10/util/Float8_e4m3fn.h>
|
#include <torch/headeronly/util/Float8_e4m3fn.h>
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
@@ -1,16 +1,18 @@
|
|||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
#include <torch/csrc/stable/ops.h>
|
||||||
|
#include <torch/headeronly/util/Exception.h>
|
||||||
|
#include <torch/headeronly/core/ScalarType.h>
|
||||||
|
|
||||||
#include "quantization/w8a8/per_token_group_quant_8bit.h"
|
#include "libtorch_stable/quantization/w8a8/per_token_group_quant_8bit.h"
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
#include <cuda_fp8.h>
|
#include <cuda_fp8.h>
|
||||||
|
|
||||||
#include <torch/all.h>
|
#include "libtorch_stable/quantization/vectorization.cuh"
|
||||||
|
#include "libtorch_stable/quantization/vectorization_utils.cuh"
|
||||||
#include "quantization/vectorization.cuh"
|
#include "libtorch_stable/dispatch_utils.h"
|
||||||
#include "quantization/vectorization_utils.cuh"
|
#include "libtorch_stable/torch_utils.h"
|
||||||
#include "dispatch_utils.h"
|
|
||||||
|
|
||||||
__device__ __forceinline__ float GroupReduceMax(float val) {
|
__device__ __forceinline__ float GroupReduceMax(float val) {
|
||||||
unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;
|
unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;
|
||||||
@@ -154,20 +156,20 @@ inline int GetGroupsPerBlock(int64_t num_groups) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
void per_token_group_quant_8bit(const torch::Tensor& input,
|
void per_token_group_quant_8bit(const torch::stable::Tensor& input,
|
||||||
torch::Tensor& output_q,
|
torch::stable::Tensor& output_q,
|
||||||
torch::Tensor& output_s, int64_t group_size,
|
torch::stable::Tensor& output_s,
|
||||||
double eps, double min_8bit, double max_8bit,
|
int64_t group_size, double eps, double min_8bit,
|
||||||
bool scale_ue8m0) {
|
double max_8bit, bool scale_ue8m0) {
|
||||||
TORCH_CHECK(input.is_contiguous());
|
STD_TORCH_CHECK(input.is_contiguous());
|
||||||
TORCH_CHECK(output_q.is_contiguous());
|
STD_TORCH_CHECK(output_q.is_contiguous());
|
||||||
|
|
||||||
const int num_groups = input.numel() / group_size;
|
const int num_groups = input.numel() / group_size;
|
||||||
|
|
||||||
TORCH_CHECK(input.numel() % group_size == 0);
|
STD_TORCH_CHECK(input.numel() % group_size == 0);
|
||||||
TORCH_CHECK(output_s.dim() == 2);
|
STD_TORCH_CHECK(output_s.dim() == 2);
|
||||||
|
|
||||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
cudaStream_t stream = get_current_cuda_stream();
|
||||||
|
|
||||||
constexpr int THREADS_PER_GROUP = 16;
|
constexpr int THREADS_PER_GROUP = 16;
|
||||||
|
|
||||||
@@ -222,11 +224,11 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
|
|||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(
|
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
|
||||||
input.scalar_type(), "per_token_group_quant_8bit", ([&] {
|
input.scalar_type(), "per_token_group_quant_8bit", ([&] {
|
||||||
if (dst_type == at::ScalarType::Float8_e4m3fn) {
|
if (dst_type == torch::headeronly::ScalarType::Float8_e4m3fn) {
|
||||||
LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3);
|
LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3);
|
||||||
} else if (dst_type == at::ScalarType::Char) {
|
} else if (dst_type == torch::headeronly::ScalarType::Char) {
|
||||||
LAUNCH_KERNEL(scalar_t, int8_t);
|
LAUNCH_KERNEL(scalar_t, int8_t);
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
@@ -294,41 +296,42 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
|
|||||||
threads_per_group, y_s, min_8bit, max_8bit);
|
threads_per_group, y_s, min_8bit, max_8bit);
|
||||||
}
|
}
|
||||||
|
|
||||||
void per_token_group_quant_8bit_packed(const torch::Tensor& input,
|
void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
|
||||||
torch::Tensor& output_q,
|
torch::stable::Tensor& output_q,
|
||||||
torch::Tensor& output_s_packed,
|
torch::stable::Tensor& output_s_packed,
|
||||||
int64_t group_size, double eps,
|
int64_t group_size, double eps,
|
||||||
double min_8bit, double max_8bit) {
|
double min_8bit, double max_8bit) {
|
||||||
TORCH_CHECK(input.is_contiguous());
|
STD_TORCH_CHECK(input.is_contiguous());
|
||||||
TORCH_CHECK(output_q.is_contiguous());
|
STD_TORCH_CHECK(output_q.is_contiguous());
|
||||||
|
|
||||||
const int64_t k = input.size(-1);
|
const int64_t k = input.size(-1);
|
||||||
TORCH_CHECK(k % group_size == 0, "Last dimension (", k,
|
STD_TORCH_CHECK(k % group_size == 0, "Last dimension (", k,
|
||||||
") must be divisible by group_size (", group_size, ").");
|
") must be divisible by group_size (", group_size, ").");
|
||||||
|
|
||||||
const int64_t mn = input.numel() / k;
|
const int64_t mn = input.numel() / k;
|
||||||
const int64_t groups_per_row = k / group_size;
|
const int64_t groups_per_row = k / group_size;
|
||||||
const int64_t num_groups = mn * groups_per_row;
|
const int64_t num_groups = mn * groups_per_row;
|
||||||
|
|
||||||
TORCH_CHECK(output_s_packed.dim() == 2,
|
STD_TORCH_CHECK(output_s_packed.dim() == 2,
|
||||||
"output_s_packed must be 2D, got dim=", output_s_packed.dim(),
|
"output_s_packed must be 2D, got dim=", output_s_packed.dim(),
|
||||||
".");
|
".");
|
||||||
|
|
||||||
const int64_t k_num_packed_sfk = (groups_per_row + 3) / 4;
|
const int64_t k_num_packed_sfk = (groups_per_row + 3) / 4;
|
||||||
const int64_t tma_aligned_mn = ((mn + 3) / 4) * 4;
|
const int64_t tma_aligned_mn = ((mn + 3) / 4) * 4;
|
||||||
|
|
||||||
TORCH_CHECK(output_s_packed.scalar_type() == at::ScalarType::Int,
|
STD_TORCH_CHECK(
|
||||||
"output_s_packed must have dtype int32 for UE8M0-packed scales.");
|
output_s_packed.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||||
|
"output_s_packed must have dtype int32 for UE8M0-packed scales.");
|
||||||
// DeepGEMM expects SFA scales in MN-major form with shape
|
// DeepGEMM expects SFA scales in MN-major form with shape
|
||||||
// [mn, ceil_div(K, 128 * 4)] and TMA-aligned stride on the last
|
// [mn, ceil_div(K, 128 * 4)] and TMA-aligned stride on the last
|
||||||
// dimension.
|
// dimension.
|
||||||
TORCH_CHECK(output_s_packed.size(0) == mn &&
|
STD_TORCH_CHECK(output_s_packed.size(0) == mn &&
|
||||||
output_s_packed.size(1) == k_num_packed_sfk,
|
output_s_packed.size(1) == k_num_packed_sfk,
|
||||||
"output_s_packed shape must be [", mn, ", ", k_num_packed_sfk,
|
"output_s_packed shape must be [", mn, ", ", k_num_packed_sfk,
|
||||||
"], but got [", output_s_packed.size(0), ", ",
|
"], but got [", output_s_packed.size(0), ", ",
|
||||||
output_s_packed.size(1), "].");
|
output_s_packed.size(1), "].");
|
||||||
|
|
||||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
cudaStream_t stream = get_current_cuda_stream();
|
||||||
|
|
||||||
constexpr int THREADS_PER_GROUP = 16;
|
constexpr int THREADS_PER_GROUP = 16;
|
||||||
|
|
||||||
@@ -340,7 +343,7 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input,
|
|||||||
|
|
||||||
// zero-initialize packed scales, since we use atomicOr to accumulate
|
// zero-initialize packed scales, since we use atomicOr to accumulate
|
||||||
// exponents from different groups.
|
// exponents from different groups.
|
||||||
output_s_packed.zero_();
|
torch::stable::zero_(output_s_packed);
|
||||||
|
|
||||||
#define LAUNCH_PACKED_KERNEL(T, DST_DTYPE) \
|
#define LAUNCH_PACKED_KERNEL(T, DST_DTYPE) \
|
||||||
do { \
|
do { \
|
||||||
@@ -359,14 +362,14 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input,
|
|||||||
static_cast<float>(max_8bit)); \
|
static_cast<float>(max_8bit)); \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(
|
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
|
||||||
input.scalar_type(), "per_token_group_quant_8bit_packed", ([&] {
|
input.scalar_type(), "per_token_group_quant_8bit_packed", ([&] {
|
||||||
if (dst_type == at::ScalarType::Float8_e4m3fn) {
|
if (dst_type == torch::headeronly::ScalarType::Float8_e4m3fn) {
|
||||||
LAUNCH_PACKED_KERNEL(scalar_t, __nv_fp8_e4m3);
|
LAUNCH_PACKED_KERNEL(scalar_t, __nv_fp8_e4m3);
|
||||||
} else if (dst_type == at::ScalarType::Char) {
|
} else if (dst_type == torch::headeronly::ScalarType::Char) {
|
||||||
LAUNCH_PACKED_KERNEL(scalar_t, int8_t);
|
LAUNCH_PACKED_KERNEL(scalar_t, int8_t);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(
|
STD_TORCH_CHECK(
|
||||||
false,
|
false,
|
||||||
"per_token_group_quant_8bit_packed only supports FP8/INT8 "
|
"per_token_group_quant_8bit_packed only supports FP8/INT8 "
|
||||||
"outputs.");
|
"outputs.");
|
||||||
@@ -376,12 +379,13 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input,
|
|||||||
#undef LAUNCH_PACKED_KERNEL
|
#undef LAUNCH_PACKED_KERNEL
|
||||||
}
|
}
|
||||||
|
|
||||||
void per_token_group_quant_fp8(const torch::Tensor& input,
|
void per_token_group_quant_fp8(const torch::stable::Tensor& input,
|
||||||
torch::Tensor& output_q, torch::Tensor& output_s,
|
torch::stable::Tensor& output_q,
|
||||||
|
torch::stable::Tensor& output_s,
|
||||||
int64_t group_size, double eps, double fp8_min,
|
int64_t group_size, double eps, double fp8_min,
|
||||||
double fp8_max, bool scale_ue8m0,
|
double fp8_max, bool scale_ue8m0,
|
||||||
bool dummy_is_scale_transposed = false,
|
bool dummy_is_scale_transposed = false,
|
||||||
bool dummy_is_tma_aligned = false) {
|
bool dummy_is_tma_aligned = false) {
|
||||||
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
|
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
|
||||||
fp8_min, fp8_max, scale_ue8m0);
|
fp8_min, fp8_max, scale_ue8m0);
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
|
||||||
|
#include "libtorch_stable/quantization/w8a8/per_token_group_quant_8bit.h"
|
||||||
|
|
||||||
|
void per_token_group_quant_int8(const torch::stable::Tensor& input,
|
||||||
|
torch::stable::Tensor& output_q,
|
||||||
|
torch::stable::Tensor& output_s,
|
||||||
|
int64_t group_size, double eps, double int8_min,
|
||||||
|
double int8_max) {
|
||||||
|
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
|
||||||
|
int8_min, int8_max);
|
||||||
|
}
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
|
||||||
|
// 8-bit per-token-group quantization helper used by both FP8 and INT8
|
||||||
|
void per_token_group_quant_8bit(const torch::stable::Tensor& input,
|
||||||
|
torch::stable::Tensor& output_q,
|
||||||
|
torch::stable::Tensor& output_s,
|
||||||
|
int64_t group_size, double eps, double min_8bit,
|
||||||
|
double max_8bit, bool scale_ue8m0 = false);
|
||||||
@@ -6,15 +6,46 @@
|
|||||||
// Register ops with STABLE_TORCH_LIBRARY for libtorch stable ABI compatibility.
|
// Register ops with STABLE_TORCH_LIBRARY for libtorch stable ABI compatibility.
|
||||||
// Note: We register under namespace "_C" so ops are accessible as
|
// Note: We register under namespace "_C" so ops are accessible as
|
||||||
// torch.ops._C.<op_name> for compatibility with existing code.
|
// torch.ops._C.<op_name> for compatibility with existing code.
|
||||||
STABLE_TORCH_LIBRARY_FRAGMENT(_C, m) {
|
STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
m.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
// Compute per-token-group FP8 quantized tensor and scaling factor.
|
||||||
|
// The dummy arguments are here so we can correctly fuse with RMSNorm.
|
||||||
|
ops.def(
|
||||||
|
"per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
|
||||||
|
"output_s, "
|
||||||
|
"int group_size, float eps, float fp8_min, float fp8_max, bool "
|
||||||
|
"scale_ue8m0, bool dummy_is_scale_transposed, bool dummy_is_tma_aligned "
|
||||||
|
") -> ()");
|
||||||
|
// Compute per-token-group 8-bit quantized tensor and UE8M0-packed,
|
||||||
|
// TMA-aligned scales for DeepGEMM.
|
||||||
|
ops.def(
|
||||||
|
"per_token_group_fp8_quant_packed(Tensor input, Tensor! output_q, "
|
||||||
|
"Tensor! output_s_packed, int group_size, float eps, float fp8_min, "
|
||||||
|
"float fp8_max) -> ()");
|
||||||
|
// Compute per-token-group INT8 quantized tensor and scaling factor.
|
||||||
|
ops.def(
|
||||||
|
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
|
||||||
|
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
|
||||||
|
"()");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
|
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
m.impl("permute_cols", TORCH_BOX(&permute_cols));
|
ops.impl("permute_cols", TORCH_BOX(&permute_cols));
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
// Per-token group quantization
|
||||||
|
ops.impl("per_token_group_fp8_quant", TORCH_BOX(&per_token_group_quant_fp8));
|
||||||
|
ops.impl("per_token_group_fp8_quant_packed",
|
||||||
|
TORCH_BOX(&per_token_group_quant_8bit_packed));
|
||||||
|
ops.impl("per_token_group_quant_int8",
|
||||||
|
TORCH_BOX(&per_token_group_quant_int8));
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||||
|
#include <torch/headeronly/util/shim_utils.h>
|
||||||
|
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
// Utility to get the current CUDA stream for a given device using stable APIs.
|
// Utility to get the current CUDA stream for a given device using stable APIs.
|
||||||
// Returns a cudaStream_t for use in kernel launches.
|
// Returns a cudaStream_t for use in kernel launches.
|
||||||
inline cudaStream_t get_current_cuda_stream(int32_t device_index) {
|
inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) {
|
||||||
void* stream_ptr = nullptr;
|
void* stream_ptr = nullptr;
|
||||||
TORCH_ERROR_CODE_CHECK(
|
TORCH_ERROR_CODE_CHECK(
|
||||||
aoti_torch_get_current_cuda_stream(device_index, &stream_ptr));
|
aoti_torch_get_current_cuda_stream(device_index, &stream_ptr));
|
||||||
|
|||||||
19
csrc/ops.h
19
csrc/ops.h
@@ -306,25 +306,6 @@ void silu_and_mul_scaled_fp4_experts_quant(
|
|||||||
torch::Tensor const& input_offset_by_experts,
|
torch::Tensor const& input_offset_by_experts,
|
||||||
torch::Tensor const& output_scale_offset_by_experts);
|
torch::Tensor const& output_scale_offset_by_experts);
|
||||||
|
|
||||||
void per_token_group_quant_fp8(const torch::Tensor& input,
|
|
||||||
torch::Tensor& output_q, torch::Tensor& output_s,
|
|
||||||
int64_t group_size, double eps, double fp8_min,
|
|
||||||
double fp8_max, bool scale_ue8m0,
|
|
||||||
bool dummy_is_scale_transposed,
|
|
||||||
bool dummy_is_tma_aligned);
|
|
||||||
|
|
||||||
void per_token_group_quant_int8(const torch::Tensor& input,
|
|
||||||
torch::Tensor& output_q,
|
|
||||||
torch::Tensor& output_s, int64_t group_size,
|
|
||||||
double eps, double int8_min, double int8_max);
|
|
||||||
|
|
||||||
// Fused activation quantisation + DeepGEMM-compatible UE8M0-packed scales.
|
|
||||||
void per_token_group_quant_8bit_packed(const torch::Tensor& input,
|
|
||||||
torch::Tensor& output_q,
|
|
||||||
torch::Tensor& output_s_packed,
|
|
||||||
int64_t group_size, double eps,
|
|
||||||
double min_8bit, double max_8bit);
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
* __device__ layernorm utilities.
|
* __device__ layernorm utilities.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "quantization/vectorization.cuh"
|
#include "libtorch_stable/quantization/vectorization.cuh"
|
||||||
#include "quantization/utils.cuh"
|
#include "quantization/utils.cuh"
|
||||||
#include "quant_conversions.cuh"
|
#include "quant_conversions.cuh"
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
* __device__ helper functions to deal with float -> quant datatype conversion
|
* __device__ helper functions to deal with float -> quant datatype conversion
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "quantization/vectorization.cuh"
|
#include "libtorch_stable/quantization/vectorization.cuh"
|
||||||
// TODO(luka/varun):refactor common.cuh to use this file instead
|
// TODO(luka/varun):refactor common.cuh to use this file instead
|
||||||
#include "quantization/w8a8/fp8/common.cuh"
|
#include "quantization/w8a8/fp8/common.cuh"
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#include "common.cuh"
|
#include "common.cuh"
|
||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
#include "cub_helpers.h"
|
#include "cub_helpers.h"
|
||||||
#include "quantization/vectorization_utils.cuh"
|
#include "libtorch_stable/quantization/vectorization_utils.cuh"
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
#include <ATen/cuda/Exceptions.h>
|
#include <ATen/cuda/Exceptions.h>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "quantization/vectorization.cuh"
|
#include "libtorch_stable/quantization/vectorization.cuh"
|
||||||
#include "quantization/utils.cuh"
|
#include "quantization/utils.cuh"
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
#include <ATen/cuda/CUDAContext.h>
|
|
||||||
#include <torch/all.h>
|
|
||||||
|
|
||||||
#include "quantization/w8a8/per_token_group_quant_8bit.h"
|
|
||||||
|
|
||||||
void per_token_group_quant_int8(const torch::Tensor& input,
|
|
||||||
torch::Tensor& output_q,
|
|
||||||
torch::Tensor& output_s, int64_t group_size,
|
|
||||||
double eps, double int8_min, double int8_max) {
|
|
||||||
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
|
|
||||||
int8_min, int8_max);
|
|
||||||
}
|
|
||||||
@@ -5,7 +5,7 @@
|
|||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
#include "quantization/vectorization_utils.cuh"
|
#include "libtorch_stable/quantization/vectorization_utils.cuh"
|
||||||
#include "cub_helpers.h"
|
#include "cub_helpers.h"
|
||||||
|
|
||||||
static inline __device__ int8_t float_to_int8_rn(float x) {
|
static inline __device__ int8_t float_to_int8_rn(float x) {
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
#include <torch/all.h>
|
|
||||||
|
|
||||||
// 8-bit per-token-group quantization helper used by both FP8 and INT8
|
|
||||||
void per_token_group_quant_8bit(const torch::Tensor& input,
|
|
||||||
torch::Tensor& output_q,
|
|
||||||
torch::Tensor& output_s, int64_t group_size,
|
|
||||||
double eps, double min_8bit, double max_8bit,
|
|
||||||
bool scale_ue8m0 = false);
|
|
||||||
@@ -653,34 +653,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor");
|
ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor");
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
// Compute per-token-group FP8 quantized tensor and scaling factor.
|
|
||||||
// The dummy arguments are here so we can correctly fuse with RMSNorm.
|
|
||||||
ops.def(
|
|
||||||
"per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
|
|
||||||
"output_s, "
|
|
||||||
"int group_size, float eps, float fp8_min, float fp8_max, bool "
|
|
||||||
"scale_ue8m0, bool dummy_is_scale_transposed, bool dummy_is_tma_aligned "
|
|
||||||
") -> ()");
|
|
||||||
ops.impl("per_token_group_fp8_quant", torch::kCUDA,
|
|
||||||
&per_token_group_quant_fp8);
|
|
||||||
|
|
||||||
// Compute per-token-group 8-bit quantized tensor and UE8M0-packed,
|
|
||||||
// TMA-aligned scales for DeepGEMM.
|
|
||||||
ops.def(
|
|
||||||
"per_token_group_fp8_quant_packed(Tensor input, Tensor! output_q, "
|
|
||||||
"Tensor! output_s_packed, int group_size, float eps, float fp8_min, "
|
|
||||||
"float fp8_max) -> ()");
|
|
||||||
ops.impl("per_token_group_fp8_quant_packed", torch::kCUDA,
|
|
||||||
&per_token_group_quant_8bit_packed);
|
|
||||||
|
|
||||||
// Compute per-token-group INT8 quantized tensor and scaling factor.
|
|
||||||
ops.def(
|
|
||||||
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
|
|
||||||
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
|
|
||||||
"()");
|
|
||||||
ops.impl("per_token_group_quant_int8", torch::kCUDA,
|
|
||||||
&per_token_group_quant_int8);
|
|
||||||
|
|
||||||
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
|
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
|
||||||
ops.def(
|
ops.def(
|
||||||
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
|
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
|
||||||
|
|||||||
Reference in New Issue
Block a user