[Perf][Kernel] Fused SiLU+Mul+Quant kernel for NVFP4 cutlass_moe (#31832)
Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -31,8 +31,12 @@
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// NVFP4 quantization kernel for experts (low-latency path).
|
||||
// When FUSE_SILU_MUL=true, expects input with gate||up layout and fuses
|
||||
// SiLU(gate)*up before quantization.
|
||||
// Use UE4M3 by default.
|
||||
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
|
||||
template <class Type, bool FUSE_SILU_MUL = false, bool UE8M0_SF = false,
|
||||
bool SMALL_NUM_EXPERTS = false>
|
||||
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
||||
float const* SFScale, uint32_t* out, uint32_t* SFout,
|
||||
@@ -50,6 +54,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
|
||||
// When fusing SiLU+Mul, input has gate || up layout (doubled width)
|
||||
int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow;
|
||||
|
||||
// Each global thread processes one element
|
||||
for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
|
||||
@@ -58,13 +64,6 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
int rowIdx = globalIdx / colsPerRow;
|
||||
int colIdx = globalIdx % colsPerRow;
|
||||
|
||||
int64_t inOffset = rowIdx * colsPerRow + colIdx;
|
||||
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
||||
// Get the output tensor offset.
|
||||
// Same as inOffset because 8 elements are packed into one uint32_t.
|
||||
int64_t outOffset = inOffset;
|
||||
auto& out_pos = out[outOffset];
|
||||
|
||||
// Find index within the experts using different strategies based on expert
|
||||
// count
|
||||
int rowIdx_in_expert = 0;
|
||||
@@ -111,6 +110,23 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
}
|
||||
}
|
||||
|
||||
// Load input and optionally apply fused SiLU+Mul
|
||||
int64_t inOffset = rowIdx * inColsPerRow + colIdx;
|
||||
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
||||
PackedVec quant_input;
|
||||
if constexpr (FUSE_SILU_MUL) {
|
||||
PackedVec in_vec_up =
|
||||
reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
|
||||
quant_input = compute_silu_mul(in_vec, in_vec_up);
|
||||
} else {
|
||||
quant_input = in_vec;
|
||||
}
|
||||
|
||||
// Get the output tensor offset.
|
||||
// Same as inOffset because 8 elements are packed into one uint32_t.
|
||||
int64_t outOffset = rowIdx * colsPerRow + colIdx;
|
||||
auto& out_pos = out[outOffset];
|
||||
|
||||
// Get the global scaling factor, which will be applied to the SF.
|
||||
// Note SFScale is the same as next GEMM's alpha, which is
|
||||
// (448.f / (Alpha_A / 6.f)).
|
||||
@@ -124,12 +140,16 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
CVT_FP4_NUM_THREADS_PER_SF>(
|
||||
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
|
||||
|
||||
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
||||
out_pos =
|
||||
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(quant_input, SFScaleVal, sf_out);
|
||||
}
|
||||
}
|
||||
|
||||
// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
|
||||
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
|
||||
// NVFP4 quantization kernel for LARGE_M_TOPK = true (large m_topk optimized
|
||||
// version). When FUSE_SILU_MUL=true, expects input with gate||up layout and
|
||||
// fuses SiLU(gate)*up before quantization.
|
||||
template <class Type, bool FUSE_SILU_MUL = false, bool UE8M0_SF = false,
|
||||
bool SMALL_NUM_EXPERTS = false>
|
||||
__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
||||
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
||||
float const* SFScale, uint32_t* out, uint32_t* SFout,
|
||||
@@ -167,6 +187,8 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
||||
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
|
||||
// When fusing SiLU+Mul, input has gate || up layout (doubled width)
|
||||
int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow;
|
||||
|
||||
// Each global thread processes one element
|
||||
for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
|
||||
@@ -175,11 +197,6 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
||||
int rowIdx = globalIdx / colsPerRow;
|
||||
int colIdx = globalIdx % colsPerRow;
|
||||
|
||||
int64_t inOffset = rowIdx * colsPerRow + colIdx;
|
||||
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
||||
int64_t outOffset = inOffset;
|
||||
auto& out_pos = out[outOffset];
|
||||
|
||||
// Find expert using binary search for better performance with large m_topk
|
||||
int rowIdx_in_expert = 0;
|
||||
int expert_idx = 0;
|
||||
@@ -204,6 +221,21 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
||||
}
|
||||
}
|
||||
|
||||
// Load input and optionally apply fused SiLU+Mul
|
||||
int64_t inOffset = rowIdx * inColsPerRow + colIdx;
|
||||
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
||||
PackedVec quant_input;
|
||||
if constexpr (FUSE_SILU_MUL) {
|
||||
PackedVec in_vec_up =
|
||||
reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
|
||||
quant_input = compute_silu_mul(in_vec, in_vec_up);
|
||||
} else {
|
||||
quant_input = in_vec;
|
||||
}
|
||||
|
||||
int64_t outOffset = rowIdx * colsPerRow + colIdx;
|
||||
auto& out_pos = out[outOffset];
|
||||
|
||||
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
|
||||
|
||||
uint32_t* SFout_in_expert =
|
||||
@@ -214,11 +246,12 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
||||
CVT_FP4_NUM_THREADS_PER_SF>(
|
||||
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
|
||||
|
||||
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
||||
out_pos =
|
||||
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(quant_input, SFScaleVal, sf_out);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, bool FUSE_SILU_MUL = false>
|
||||
void quant_impl(void* output, void* output_scale, void* input,
|
||||
void* input_global_scale, void* input_offset_by_experts,
|
||||
void* output_scale_offset_by_experts, int m_topk, int k,
|
||||
@@ -246,7 +279,7 @@ void quant_impl(void* output, void* output_scale, void* input,
|
||||
if (blockRepeat > 1) {
|
||||
size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
|
||||
if (n_experts >= 4) {
|
||||
cvt_fp16_to_fp4<T, false, false>
|
||||
cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, false>
|
||||
<<<grid, block, shared_mem_size, stream>>>(
|
||||
m_topk, k, reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
@@ -256,34 +289,37 @@ void quant_impl(void* output, void* output_scale, void* input,
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||
n_experts);
|
||||
} else {
|
||||
cvt_fp16_to_fp4<T, false, true><<<grid, block, shared_mem_size, stream>>>(
|
||||
m_topk, k, reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(output_scale),
|
||||
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||
n_experts);
|
||||
cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, true>
|
||||
<<<grid, block, shared_mem_size, stream>>>(
|
||||
m_topk, k, reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(output_scale),
|
||||
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||
n_experts);
|
||||
}
|
||||
} else {
|
||||
if (n_experts >= 16) {
|
||||
cvt_fp16_to_fp4<T, false, false><<<grid, block, 0, stream>>>(
|
||||
m_topk, k, reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(output_scale),
|
||||
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||
n_experts, /* bool low_latency */ true);
|
||||
cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, false>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
m_topk, k, reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(output_scale),
|
||||
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||
n_experts, /* bool low_latency */ true);
|
||||
} else {
|
||||
cvt_fp16_to_fp4<T, false, true><<<grid, block, 0, stream>>>(
|
||||
m_topk, k, reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(output_scale),
|
||||
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||
n_experts, /* bool low_latency */ true);
|
||||
cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, true>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
m_topk, k, reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(output_scale),
|
||||
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||
n_experts, /* bool low_latency */ true);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -304,19 +340,19 @@ constexpr auto FLOAT = at::ScalarType::Float;
|
||||
constexpr auto INT = at::ScalarType::Int;
|
||||
constexpr auto UINT8 = at::ScalarType::Byte;
|
||||
|
||||
void scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
// Common validation for fp4 experts quantization entry points.
|
||||
static void validate_fp4_experts_quant_inputs(
|
||||
torch::Tensor const& output, torch::Tensor const& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
CHECK_INPUT(output, "output must be a CUDA tensor");
|
||||
CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
|
||||
CHECK_INPUT(input, "input must be a CUDA tensor");
|
||||
CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
|
||||
CHECK_INPUT(input_offset_by_experts,
|
||||
"input_offset_by_experts must be a CUDA tensor");
|
||||
CHECK_INPUT(output_scale_offset_by_experts,
|
||||
"output_scale_offset_by_experts must be a CUDA tensor");
|
||||
torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
|
||||
int64_t k) {
|
||||
CHECK_INPUT(output, "output");
|
||||
CHECK_INPUT(output_scale, "output_scale");
|
||||
CHECK_INPUT(input, "input");
|
||||
CHECK_INPUT(input_global_scale, "input_global_scale");
|
||||
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts");
|
||||
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts");
|
||||
|
||||
TORCH_CHECK(output.dim() == 2);
|
||||
TORCH_CHECK(output_scale.dim() == 2);
|
||||
@@ -335,8 +371,6 @@ void scaled_fp4_experts_quant_sm1xxa(
|
||||
TORCH_CHECK(output_scale.scalar_type() == INT);
|
||||
|
||||
const int BLOCK_SIZE = 16;
|
||||
auto m_topk = input.size(0);
|
||||
auto k = input.size(1);
|
||||
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
|
||||
auto n_experts = input_global_scale.size(0);
|
||||
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
|
||||
@@ -348,7 +382,21 @@ void scaled_fp4_experts_quant_sm1xxa(
|
||||
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
|
||||
// 4 means 4 fp8 values are packed into one int32
|
||||
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
|
||||
}
|
||||
|
||||
void scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
auto m_topk = input.size(0);
|
||||
auto k = input.size(1);
|
||||
|
||||
validate_fp4_experts_quant_inputs(output, output_scale, input,
|
||||
input_global_scale, input_offset_by_experts,
|
||||
output_scale_offset_by_experts, m_topk, k);
|
||||
|
||||
auto n_experts = input_global_scale.size(0);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
@@ -356,7 +404,38 @@ void scaled_fp4_experts_quant_sm1xxa(
|
||||
VLLM_DISPATCH_HALF_TYPES(
|
||||
input.scalar_type(), "nvfp4_experts_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
vllm::quant_impl<cuda_type>(
|
||||
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/false>(
|
||||
output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
|
||||
input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(),
|
||||
output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts,
|
||||
stream);
|
||||
});
|
||||
}
|
||||
|
||||
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
auto m_topk = input.size(0);
|
||||
// Input has gate || up layout, so k = input.size(1) / 2
|
||||
auto k_times_2 = input.size(1);
|
||||
TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
|
||||
auto k = k_times_2 / 2;
|
||||
|
||||
validate_fp4_experts_quant_inputs(output, output_scale, input,
|
||||
input_global_scale, input_offset_by_experts,
|
||||
output_scale_offset_by_experts, m_topk, k);
|
||||
|
||||
auto n_experts = input_global_scale.size(0);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(
|
||||
input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/true>(
|
||||
output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
|
||||
input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(),
|
||||
output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts,
|
||||
|
||||
Reference in New Issue
Block a user