[NVFP4][Perf] Tune NVFP4 input quant kernel for small batch size (#30897)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2025-12-21 12:41:57 -05:00
committed by GitHub
parent b471092d3a
commit 06d490282f
5 changed files with 243 additions and 97 deletions

View File

@@ -25,6 +25,7 @@
#include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "cuda_utils.h"
#include "nvfp4_utils.cuh"
#include "launch_bounds_utils.h"
@@ -44,6 +45,9 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
"Vec size is not matched.");
// Precompute SF layout parameter (constant for entire kernel).
int32_t const numKTiles = (numCols + 63) / 64;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
@@ -112,17 +116,13 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
// (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
int factor = CVT_FP4_SF_VEC_SIZE * 4;
// The actual output_scales dim is computed from the padded numCols.
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
uint32_t* SFout_in_expert =
SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
SFout + output_scale_offset_by_experts[expert_idx] * numKTiles;
auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
}
@@ -140,6 +140,10 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
"Vec size is not matched.");
// Precompute SF layout parameter (constant for entire kernel).
int32_t const numKTiles = (numCols + 63) / 64;
extern __shared__ uint32_t shared_input_offsets[];
// Load input offsets into shared memory.
@@ -202,16 +206,13 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
int factor = CVT_FP4_SF_VEC_SIZE * 4;
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
uint32_t* SFout_in_expert =
SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
SFout + output_scale_offset_by_experts[expert_idx] * numKTiles;
auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
}
@@ -222,12 +223,8 @@ void quant_impl(void* output, void* output_scale, void* input,
void* input_global_scale, void* input_offset_by_experts,
void* output_scale_offset_by_experts, int m_topk, int k,
int n_experts, cudaStream_t stream) {
// TODO: this multiProcessorCount should be cached.
int device;
cudaGetDevice(&device);
int multiProcessorCount;
cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount,
device);
int multiProcessorCount =
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
// Grid, Block size.
// Each thread converts 8 values.