[NVFP4][Perf] Tune NVFP4 input quant kernel for small batch size (#30897)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2025-12-21 12:41:57 -05:00
committed by GitHub
parent b471092d3a
commit 06d490282f
5 changed files with 243 additions and 97 deletions

View File

@@ -38,6 +38,12 @@ __host__ __device__ inline Int round_up(Int x, Int y) {
return (x + y - 1) / y * y;
}
// Compute effective rows for grid configuration with swizzled SF layouts.
inline int computeEffectiveRows(int m) {
constexpr int ROW_TILE = 128;
return round_up(m, ROW_TILE);
}
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
@@ -49,6 +55,9 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
"Vec size is not matched.");
// Precompute SF layout parameter (constant for entire kernel).
int32_t const numKTiles = (numCols + 63) / 64;
int sf_m = round_up<int>(numRows, 128);
int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
int sf_n_int = round_up<int>(sf_n_unpadded, 4) / 4;
@@ -79,7 +88,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx, colIdx, numCols, SFout);
rowIdx, colIdx, numKTiles, SFout);
out_pos =
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, global_scale, sf_out);
@@ -87,43 +96,6 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
}
}
template <typename T>
void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale,
int64_t* output, int32_t* SFOuput, bool useUE8M0,
int multiProcessorCount, cudaStream_t stream) {
// Grid, Block size.
// Each thread converts 8 values.
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
// Get number of blocks per SM
int const numBlocksPerSM =
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
// Launch the cvt kernel.
if (useUE8M0) {
cvt_fp16_to_fp4<T, true><<<grid, block, 0, stream>>>(
m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(SFOuput));
} else {
cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(SFOuput));
}
}
// Instantiate the function.
template void invokeFP4Quantization(int m, int n, half const* input,
float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0,
int multiProcessorCount,
cudaStream_t stream);
template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input,
float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0,
int multiProcessorCount,
cudaStream_t stream);
} // namespace vllm
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
@@ -147,13 +119,19 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
// We don't support e8m0 scales at this moment.
bool useUE8M0 = false;
// Grid, Block size. Each thread converts 8 values.
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
int const numBlocksPerSM =
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
int effectiveRows = vllm::computeEffectiveRows(m);
dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM));
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
vllm::invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr,
sf_out, useUE8M0, multiProcessorCount, stream);
// NOTE: We don't support e8m0 scales at this moment.
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
m, n, input_ptr, input_sf_ptr, reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
}