[Perf][Kernel] Optimize FP4 quantization kernels (SM100F) (#32520)

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Committed by: GitHub
Parent: 1ebdff412a
Commit: fcb9df99bd
@@ -27,29 +27,23 @@
#include "cuda_utils.h"
#include "launch_bounds_utils.h"

// Define before including nvfp4_utils.cuh so the header
// can use this macro during compilation.
#define NVFP4_ENABLE_ELTS16 1
#include "nvfp4_utils.cuh"

namespace vllm {

template <typename Int>
__host__ __device__ inline Int round_up(Int x, Int y) {
  static_assert(std::is_integral_v<Int>,
                "round_up argument must be integral type");
  return ((x + y - 1) / y) * y;
}

// Compute effective rows for grid configuration with swizzled SF layouts.
inline int computeEffectiveRows(int m) {
  constexpr int ROW_TILE = 128;
  return round_up(m, ROW_TILE);
}
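For clarity, a tiny host-side sketch (not part of this commit) of what these helpers compute: the row count is rounded up to the 128-row tile of the swizzled scale-factor layout, so the grid also covers padded scale-factor addresses. It assumes the helpers above are visible, e.g. from a small test translation unit.

// Illustrative only: pad rows to the 128-row SF tile.
#include <cassert>

int main() {
  assert(vllm::round_up(200, 128) == 256);         // ((200 + 127) / 128) * 128
  assert(vllm::computeEffectiveRows(200) == 256);  // 200 rows -> 256 effective rows
  assert(vllm::computeEffectiveRows(128) == 128);  // already tile-aligned
}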

// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
    cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
                    float const* SFScale, uint32_t* out, uint32_t* SFout) {
  using PackedVec = PackedVec<Type>;
    cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, int32_t num_padded_cols,
                    Type const* __restrict__ in,
                    float const* __restrict__ SFScale,
                    uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) {
  using PackedVec = vllm::PackedVec<Type>;

  static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
      (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
@@ -59,33 +53,31 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
  int32_t const numKTiles = (numCols + 63) / 64;

  int sf_m = round_up<int>(numRows, 128);
  int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
  int sf_n_int = round_up<int>(sf_n_unpadded, 4) / 4;
  int num_padded_cols = sf_n_int * 4 * CVT_FP4_SF_VEC_SIZE;
  int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x;
  int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD;

  // Get the global scaling factor, which will be applied to the SF.
  // Note SFScale is the same as next GEMM's alpha, which is
  // (448.f / (Alpha_A / 6.f)).
  float const global_scale = SFScale == nullptr ? 1.0f : SFScale[0];
  float const global_scale = (SFScale == nullptr) ? 1.0f : SFScale[0];

  // Iterate over all rows and cols including padded ones -
  // ensures we visit every single scale factor address to initialize it.
  for (int rowIdx = blockIdx.x; rowIdx < sf_m; rowIdx += gridDim.x) {
    for (int colIdx = threadIdx.x;
         colIdx < num_padded_cols / CVT_FP4_ELTS_PER_THREAD;
         colIdx += blockDim.x) {
      int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD;

    if (colIdx < num_padded_cols) {
      PackedVec in_vec;
      int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;

      // If we are outside valid rows OR outside valid columns -> Use Zeros
      if (rowIdx >= numRows || elem_idx >= numCols) {
        memset(&in_vec, 0, sizeof(PackedVec));

      bool valid = (rowIdx < numRows) && (elem_idx < numCols);
      if constexpr (CVT_FP4_PACK16) {
        ld256_or_zero_cg_u32<Type>(
            in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
            valid);
      } else {
        // Valid Region: Load actual data
        in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
        ld128_or_zero_cg_u32<Type>(
            in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
            valid);
      }

      auto sf_out =
@@ -94,13 +86,85 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
          rowIdx, colIdx, numKTiles, SFout);

      auto out_val =
          cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, global_scale, sf_out);
          cvt_warp_fp16_to_fp4<Type, CVT_FP4_NUM_THREADS_PER_SF, UE8M0_SF>(
              in_vec, global_scale, sf_out);

      // We do NOT write output for padding because the 'out' tensor is not
      // padded.
      if (rowIdx < numRows && elem_idx < numCols) {
        // Same as inOffset because 8 elements are packed into one uint32_t.
        out[inOffset] = out_val;
      if (valid) {
        if constexpr (CVT_FP4_PACK16) {
          int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2;
          uint64_t packed64 =
              (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo);
          reinterpret_cast<uint64_t*>(out)[outOffset >> 1] = packed64;
        } else {
          out[inOffset] = out_val;
        }
      }
    }
  }
}
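As a side note, here is a minimal standalone sketch (not from this commit) of the arithmetic behind the CVT_FP4_PACK16 store path above, where each thread produces 16 FP4 values (64 bits) per iteration:

// Illustrative only: mirror of the PACK16 store math in cvt_fp16_to_fp4.
#include <cstdint>

// Two 32-bit words, each holding eight 4-bit FP4 values, become one 64-bit store.
inline uint64_t pack_fp4x16(uint32_t lo, uint32_t hi) {
  return (uint64_t(hi) << 32) | uint64_t(lo);
}

// A row of numCols FP4 values occupies numCols / 8 uint32_t words; each thread
// owns the two consecutive words starting at colIdx * 2, and >> 1 turns that
// uint32_t offset into a uint64_t index for the single wide store.
inline int64_t packed_store_index(int rowIdx, int colIdx, int numCols) {
  int64_t outOffset = int64_t(rowIdx) * (numCols / 8) + colIdx * 2;
  return outOffset >> 1;
}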

// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
    cvt_fp16_to_fp4_sf_major(int32_t numRows, int32_t numCols,
                             int32_t sf_n_unpadded, Type const* __restrict__ in,
                             float const* __restrict__ SFScale,
                             uint32_t* __restrict__ out,
                             uint32_t* __restrict__ SFout) {
  using PackedVec = PackedVec<Type>;

  static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
      (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                "Vec size is not matched.");

  int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x;
  int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD;

  // Get the global scaling factor, which will be applied to the SF.
  // Note SFScale is the same as next GEMM's alpha, which is
  // (448.f / (Alpha_A / 6.f)).
  float const global_scale = (SFScale == nullptr) ? 1.0f : SFScale[0];

  // Iterate over all rows and cols including padded ones -
  // ensures we visit every single scale factor address to initialize it.
  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
    if (colIdx < sf_n_unpadded) {
      PackedVec in_vec;
      int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;

      // If we are outside valid rows OR outside valid columns -> Use Zeros
      bool valid = (rowIdx < numRows) && (elem_idx < numCols);
      if constexpr (CVT_FP4_PACK16) {
        ld256_or_zero_cg_u32<Type>(
            in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
            valid);
      } else {
        ld128_or_zero_cg_u32<Type>(
            in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
            valid);
      }

      auto sf_out =
          sf_out_rowmajor_u8<uint32_t>(rowIdx, colIdx, sf_n_unpadded, SFout);

      auto out_val =
          cvt_warp_fp16_to_fp4<Type, CVT_FP4_NUM_THREADS_PER_SF, UE8M0_SF>(
              in_vec, global_scale, sf_out);

      // We do NOT write output for padding because the 'out' tensor is not
      // padded.
      if (valid) {
        if constexpr (CVT_FP4_PACK16) {
          int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2;
          uint64_t packed64 =
              (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo);
          reinterpret_cast<uint64_t*>(out)[outOffset >> 1] = packed64;
        } else {
          out[inOffset] = out_val;
        }
      }
    }
  }
@@ -111,7 +175,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
                             torch::Tensor const& input,
                             torch::Tensor const& output_sf,
                             torch::Tensor const& input_sf) {
                             torch::Tensor const& input_sf,
                             bool is_sf_swizzled_layout) {
  int32_t m = input.size(0);
  int32_t n = input.size(1);

@@ -129,19 +194,48 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  auto stream = at::cuda::getCurrentCUDAStream(input.get_device());

  int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE);

  // Grid, Block size. Each thread converts 8 values.
  dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
  int const numBlocksPerSM =
      vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
  int effectiveRows = vllm::computeEffectiveRows(m);
  dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM));

  VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
    using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
    auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
    // NOTE: We don't support e8m0 scales at this moment.
    vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
        m, n, input_ptr, input_sf_ptr, reinterpret_cast<uint32_t*>(output_ptr),
        reinterpret_cast<uint32_t*>(sf_out));
  });
}
  if (is_sf_swizzled_layout) {
    int sf_n_int = int(vllm::round_up(sf_n_unpadded, 4) / 4);
    int32_t num_padded_cols =
        sf_n_int * 4 * CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD;

    int grid_y = vllm::div_round_up(num_padded_cols, static_cast<int>(block.x));
    int grid_x =
        std::min(vllm::computeEffectiveRows(m),
                 std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
    dim3 grid(grid_x, grid_y);

    VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
      using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
      auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
      // NOTE: We don't support e8m0 scales at this moment.
      vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
          m, n, num_padded_cols, input_ptr, input_sf_ptr,
          reinterpret_cast<uint32_t*>(output_ptr),
          reinterpret_cast<uint32_t*>(sf_out));
    });
  } else {
    int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast<int>(block.x));
    int grid_x = std::min(
        m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
    dim3 grid(grid_x, grid_y);

    VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
      using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
      auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
      // NOTE: We don't support e8m0 scales at this moment.
      vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
          <<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, input_ptr,
                                       input_sf_ptr,
                                       reinterpret_cast<uint32_t*>(output_ptr),
                                       reinterpret_cast<uint32_t*>(sf_out));
    });
  }
}
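For reference, a standalone host-side sketch (not part of this commit) of the grid sizing used by the is_sf_swizzled_layout branch above; sf_vec_size and elts_per_thread stand in for the CVT_FP4_SF_VEC_SIZE and CVT_FP4_ELTS_PER_THREAD macros defined in nvfp4_utils.cuh:

// Illustrative only: grid math for the swizzled scale-factor layout.
#include <algorithm>

struct GridShape { int x; int y; };

inline GridShape swizzled_grid(int m, int n, int sf_vec_size, int elts_per_thread,
                               int block_x, int sm_count, int blocks_per_sm) {
  int sf_n_unpadded = n / sf_vec_size;
  // Pad the per-row SF count up to a multiple of 4 (SFs are packed into 32-bit
  // words), then express the padded width in per-thread column units.
  int sf_n_int = (sf_n_unpadded + 3) / 4;
  int num_padded_cols = sf_n_int * 4 * sf_vec_size / elts_per_thread;

  int grid_y = (num_padded_cols + block_x - 1) / block_x;           // div_round_up
  int effective_rows = ((m + 127) / 128) * 128;                     // 128-row SF tile
  int grid_x = std::min(effective_rows,
                        std::max(1, (sm_count * blocks_per_sm) / grid_y));
  return {grid_x, grid_y};
}

The row-major (non-swizzled) branch follows the same pattern, except grid_y is sized from sf_n_unpadded directly and grid_x is capped at m rather than the padded row count.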