[NVFP4][Perf] Tune NVFP4 input quant kernel for small batch size (#30897)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2025-12-21 12:41:57 -05:00
committed by GitHub
parent b471092d3a
commit 06d490282f
5 changed files with 243 additions and 97 deletions

View File

@@ -74,6 +74,9 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
"Vec size is not matched.");
// Precompute SF layout parameter (constant for entire kernel).
int32_t const numKTiles = (numCols + 63) / 64;
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
@@ -101,7 +104,7 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx, colIdx, numCols, SFout);
rowIdx, colIdx, numKTiles, SFout);
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(out_silu_mul, SFScaleVal,
sf_out);

View File

@@ -25,6 +25,7 @@
#include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "cuda_utils.h"
#include "nvfp4_utils.cuh"
#include "launch_bounds_utils.h"
@@ -44,6 +45,9 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
"Vec size is not matched.");
// Precompute SF layout parameter (constant for entire kernel).
int32_t const numKTiles = (numCols + 63) / 64;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
@@ -112,17 +116,13 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
// (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
int factor = CVT_FP4_SF_VEC_SIZE * 4;
// The actual output_scales dim is computed from the padded numCols.
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
uint32_t* SFout_in_expert =
SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
SFout + output_scale_offset_by_experts[expert_idx] * numKTiles;
auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
}
@@ -140,6 +140,10 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
"Vec size is not matched.");
// Precompute SF layout parameter (constant for entire kernel).
int32_t const numKTiles = (numCols + 63) / 64;
extern __shared__ uint32_t shared_input_offsets[];
// Load input offsets into shared memory.
@@ -202,16 +206,13 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
int factor = CVT_FP4_SF_VEC_SIZE * 4;
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
uint32_t* SFout_in_expert =
SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
SFout + output_scale_offset_by_experts[expert_idx] * numKTiles;
auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
}
@@ -222,12 +223,8 @@ void quant_impl(void* output, void* output_scale, void* input,
void* input_global_scale, void* input_offset_by_experts,
void* output_scale_offset_by_experts, int m_topk, int k,
int n_experts, cudaStream_t stream) {
// TODO: this multiProcessorCount should be cached.
int device;
cudaGetDevice(&device);
int multiProcessorCount;
cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount,
device);
int multiProcessorCount =
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
// Grid, Block size.
// Each thread converts 8 values.

View File

@@ -38,6 +38,12 @@ __host__ __device__ inline Int round_up(Int x, Int y) {
return (x + y - 1) / y * y;
}
// Compute effective rows for grid configuration with swizzled SF layouts.
inline int computeEffectiveRows(int m) {
constexpr int ROW_TILE = 128;
return round_up(m, ROW_TILE);
}
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
@@ -49,6 +55,9 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
"Vec size is not matched.");
// Precompute SF layout parameter (constant for entire kernel).
int32_t const numKTiles = (numCols + 63) / 64;
int sf_m = round_up<int>(numRows, 128);
int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
int sf_n_int = round_up<int>(sf_n_unpadded, 4) / 4;
@@ -79,7 +88,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx, colIdx, numCols, SFout);
rowIdx, colIdx, numKTiles, SFout);
out_pos =
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, global_scale, sf_out);
@@ -87,43 +96,6 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
}
}
template <typename T>
void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale,
int64_t* output, int32_t* SFOuput, bool useUE8M0,
int multiProcessorCount, cudaStream_t stream) {
// Grid, Block size.
// Each thread converts 8 values.
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
// Get number of blocks per SM
int const numBlocksPerSM =
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
// Launch the cvt kernel.
if (useUE8M0) {
cvt_fp16_to_fp4<T, true><<<grid, block, 0, stream>>>(
m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(SFOuput));
} else {
cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(SFOuput));
}
}
// Instantiate the function.
template void invokeFP4Quantization(int m, int n, half const* input,
float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0,
int multiProcessorCount,
cudaStream_t stream);
template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input,
float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0,
int multiProcessorCount,
cudaStream_t stream);
} // namespace vllm
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
@@ -147,13 +119,19 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
// We don't support e8m0 scales at this moment.
bool useUE8M0 = false;
// Grid, Block size. Each thread converts 8 values.
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
int const numBlocksPerSM =
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
int effectiveRows = vllm::computeEffectiveRows(m);
dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM));
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
vllm::invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr,
sf_out, useUE8M0, multiProcessorCount, stream);
// NOTE: We don't support e8m0 scales at this moment.
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
m, n, input_ptr, input_sf_ptr, reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
}

View File

@@ -128,51 +128,42 @@ inline __device__ float reciprocal_approximate_ftz(float a) {
return b;
}
// Compute SF output offset for swizzled tensor core layout.
// SF layout: [numMTiles, numKTiles, 32, 4, 4]
// Caller must precompute: numKTiles = (numCols + 63) / 64
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
int numCols,
SFType* SFout) {
__device__ __forceinline__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(
int rowIdx, int colIdx, int32_t numKTiles, SFType* SFout) {
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
CVT_FP4_NUM_THREADS_PER_SF == 2);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
// SF vector index (16 elements share one SF in the K dimension).
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
int32_t mIdx = rowIdx;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t mTileIdx = mIdx / (32 * 4);
// SF vector size 16.
int factor = CVT_FP4_SF_VEC_SIZE * 4;
int32_t numKTiles = (numCols + factor - 1) / factor;
int64_t mTileStride = numKTiles * 32 * 4 * 4;
int32_t kTileIdx = (kIdx / 4);
int64_t kTileStride = 32 * 4 * 4;
// M tile layout [32, 4] is column-major.
int32_t outerMIdx = (mIdx % 32);
int64_t outerMStride = 4 * 4;
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
int64_t innerMStride = 4;
int32_t innerKIdx = (kIdx % 4);
int64_t innerKStride = 1;
// Compute the global offset.
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
outerMIdx * outerMStride + innerMIdx * innerMStride +
innerKIdx * innerKStride;
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF != 0) {
return nullptr;
}
return nullptr;
// SF vector index (16 elements share one SF in the K dimension).
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
int32_t mIdx = rowIdx;
// Decompose indices using bitwise ops (all divisors are powers of 2).
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
int32_t mTileIdx = mIdx >> 7; // mIdx / 128
int32_t outerMIdx = mIdx & 31; // mIdx % 32
int32_t innerMIdx = (mIdx >> 5) & 3; // (mIdx / 32) % 4
int32_t kTileIdx = kIdx >> 2; // kIdx / 4
int32_t innerKIdx = kIdx & 3; // kIdx % 4
// Compute global SF offset: mTileIdx * (numKTiles * 512) + kTileIdx * 512 +
// outerMIdx * 16 + innerMIdx * 4 + innerKIdx
// Use bitwise OR for non-overlapping lower bits.
int64_t SFOffset = (static_cast<int64_t>(mTileIdx) * numKTiles + kTileIdx)
<< 9 |
(outerMIdx << 4) | (innerMIdx << 2) | innerKIdx;
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
}
// Quantizes the provided PackedVec into the uint32_t output