[Perf][Kernel] Fused SiLU+Mul+Quant kernel for NVFP4 cutlass_moe (#31832)
Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -301,6 +301,12 @@ void scaled_fp4_experts_quant(
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
|
||||
void silu_and_mul_scaled_fp4_experts_quant(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
|
||||
void per_token_group_quant_fp8(const torch::Tensor& input,
|
||||
torch::Tensor& output_q, torch::Tensor& output_s,
|
||||
int64_t group_size, double eps, double fp8_min,
|
||||
|
||||
@@ -31,37 +31,6 @@
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// silu in float32
|
||||
__device__ __forceinline__ float silu(float x) {
|
||||
return __fdividef(x, (1.f + __expf(-x)));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float2 silu2(float2 x) {
|
||||
return make_float2(silu(x.x), silu(x.y));
|
||||
}
|
||||
|
||||
template <class Type>
|
||||
__inline__ __device__ PackedVec<Type> compute_silu_mul(PackedVec<Type>& vec,
|
||||
PackedVec<Type>& vec2) {
|
||||
PackedVec<Type> result;
|
||||
using packed_type = typename TypeConverter<Type>::Type;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
|
||||
// silu_mul in float32
|
||||
if constexpr (std::is_same_v<Type, half>) {
|
||||
float2 silu_vec = silu2(__half22float2(vec.elts[i]));
|
||||
result.elts[i] =
|
||||
__float22half2_rn(__fmul2_rn(silu_vec, __half22float2(vec2.elts[i])));
|
||||
} else {
|
||||
float2 silu_vec = silu2(__bfloat1622float2(vec.elts[i]));
|
||||
result.elts[i] = __float22bfloat162_rn(
|
||||
__fmul2_rn(silu_vec, __bfloat1622float2(vec2.elts[i])));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Use UE4M3 by default.
|
||||
template <class Type, bool UE8M0_SF = false>
|
||||
__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
||||
|
||||
@@ -31,8 +31,12 @@
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// NVFP4 quantization kernel for experts (low-latency path).
|
||||
// When FUSE_SILU_MUL=true, expects input with gate||up layout and fuses
|
||||
// SiLU(gate)*up before quantization.
|
||||
// Use UE4M3 by default.
|
||||
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
|
||||
template <class Type, bool FUSE_SILU_MUL = false, bool UE8M0_SF = false,
|
||||
bool SMALL_NUM_EXPERTS = false>
|
||||
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
||||
float const* SFScale, uint32_t* out, uint32_t* SFout,
|
||||
@@ -50,6 +54,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
|
||||
// When fusing SiLU+Mul, input has gate || up layout (doubled width)
|
||||
int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow;
|
||||
|
||||
// Each global thread processes one element
|
||||
for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
|
||||
@@ -58,13 +64,6 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
int rowIdx = globalIdx / colsPerRow;
|
||||
int colIdx = globalIdx % colsPerRow;
|
||||
|
||||
int64_t inOffset = rowIdx * colsPerRow + colIdx;
|
||||
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
||||
// Get the output tensor offset.
|
||||
// Same as inOffset because 8 elements are packed into one uint32_t.
|
||||
int64_t outOffset = inOffset;
|
||||
auto& out_pos = out[outOffset];
|
||||
|
||||
// Find index within the experts using different strategies based on expert
|
||||
// count
|
||||
int rowIdx_in_expert = 0;
|
||||
@@ -111,6 +110,23 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
}
|
||||
}
|
||||
|
||||
// Load input and optionally apply fused SiLU+Mul
|
||||
int64_t inOffset = rowIdx * inColsPerRow + colIdx;
|
||||
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
||||
PackedVec quant_input;
|
||||
if constexpr (FUSE_SILU_MUL) {
|
||||
PackedVec in_vec_up =
|
||||
reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
|
||||
quant_input = compute_silu_mul(in_vec, in_vec_up);
|
||||
} else {
|
||||
quant_input = in_vec;
|
||||
}
|
||||
|
||||
// Get the output tensor offset.
|
||||
// Same as inOffset because 8 elements are packed into one uint32_t.
|
||||
int64_t outOffset = rowIdx * colsPerRow + colIdx;
|
||||
auto& out_pos = out[outOffset];
|
||||
|
||||
// Get the global scaling factor, which will be applied to the SF.
|
||||
// Note SFScale is the same as next GEMM's alpha, which is
|
||||
// (448.f / (Alpha_A / 6.f)).
|
||||
@@ -124,12 +140,16 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
CVT_FP4_NUM_THREADS_PER_SF>(
|
||||
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
|
||||
|
||||
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
||||
out_pos =
|
||||
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(quant_input, SFScaleVal, sf_out);
|
||||
}
|
||||
}
|
||||
|
||||
// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
|
||||
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
|
||||
// NVFP4 quantization kernel for LARGE_M_TOPK = true (large m_topk optimized
|
||||
// version). When FUSE_SILU_MUL=true, expects input with gate||up layout and
|
||||
// fuses SiLU(gate)*up before quantization.
|
||||
template <class Type, bool FUSE_SILU_MUL = false, bool UE8M0_SF = false,
|
||||
bool SMALL_NUM_EXPERTS = false>
|
||||
__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
||||
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
||||
float const* SFScale, uint32_t* out, uint32_t* SFout,
|
||||
@@ -167,6 +187,8 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
||||
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
|
||||
// When fusing SiLU+Mul, input has gate || up layout (doubled width)
|
||||
int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow;
|
||||
|
||||
// Each global thread processes one element
|
||||
for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
|
||||
@@ -175,11 +197,6 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
||||
int rowIdx = globalIdx / colsPerRow;
|
||||
int colIdx = globalIdx % colsPerRow;
|
||||
|
||||
int64_t inOffset = rowIdx * colsPerRow + colIdx;
|
||||
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
||||
int64_t outOffset = inOffset;
|
||||
auto& out_pos = out[outOffset];
|
||||
|
||||
// Find expert using binary search for better performance with large m_topk
|
||||
int rowIdx_in_expert = 0;
|
||||
int expert_idx = 0;
|
||||
@@ -204,6 +221,21 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
||||
}
|
||||
}
|
||||
|
||||
// Load input and optionally apply fused SiLU+Mul
|
||||
int64_t inOffset = rowIdx * inColsPerRow + colIdx;
|
||||
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
||||
PackedVec quant_input;
|
||||
if constexpr (FUSE_SILU_MUL) {
|
||||
PackedVec in_vec_up =
|
||||
reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
|
||||
quant_input = compute_silu_mul(in_vec, in_vec_up);
|
||||
} else {
|
||||
quant_input = in_vec;
|
||||
}
|
||||
|
||||
int64_t outOffset = rowIdx * colsPerRow + colIdx;
|
||||
auto& out_pos = out[outOffset];
|
||||
|
||||
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
|
||||
|
||||
uint32_t* SFout_in_expert =
|
||||
@@ -214,11 +246,12 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
||||
CVT_FP4_NUM_THREADS_PER_SF>(
|
||||
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
|
||||
|
||||
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
||||
out_pos =
|
||||
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(quant_input, SFScaleVal, sf_out);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, bool FUSE_SILU_MUL = false>
|
||||
void quant_impl(void* output, void* output_scale, void* input,
|
||||
void* input_global_scale, void* input_offset_by_experts,
|
||||
void* output_scale_offset_by_experts, int m_topk, int k,
|
||||
@@ -246,7 +279,7 @@ void quant_impl(void* output, void* output_scale, void* input,
|
||||
if (blockRepeat > 1) {
|
||||
size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
|
||||
if (n_experts >= 4) {
|
||||
cvt_fp16_to_fp4<T, false, false>
|
||||
cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, false>
|
||||
<<<grid, block, shared_mem_size, stream>>>(
|
||||
m_topk, k, reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
@@ -256,34 +289,37 @@ void quant_impl(void* output, void* output_scale, void* input,
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||
n_experts);
|
||||
} else {
|
||||
cvt_fp16_to_fp4<T, false, true><<<grid, block, shared_mem_size, stream>>>(
|
||||
m_topk, k, reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(output_scale),
|
||||
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||
n_experts);
|
||||
cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, true>
|
||||
<<<grid, block, shared_mem_size, stream>>>(
|
||||
m_topk, k, reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(output_scale),
|
||||
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||
n_experts);
|
||||
}
|
||||
} else {
|
||||
if (n_experts >= 16) {
|
||||
cvt_fp16_to_fp4<T, false, false><<<grid, block, 0, stream>>>(
|
||||
m_topk, k, reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(output_scale),
|
||||
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||
n_experts, /* bool low_latency */ true);
|
||||
cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, false>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
m_topk, k, reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(output_scale),
|
||||
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||
n_experts, /* bool low_latency */ true);
|
||||
} else {
|
||||
cvt_fp16_to_fp4<T, false, true><<<grid, block, 0, stream>>>(
|
||||
m_topk, k, reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(output_scale),
|
||||
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||
n_experts, /* bool low_latency */ true);
|
||||
cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, true>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
m_topk, k, reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(output_scale),
|
||||
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||
n_experts, /* bool low_latency */ true);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -304,19 +340,19 @@ constexpr auto FLOAT = at::ScalarType::Float;
|
||||
constexpr auto INT = at::ScalarType::Int;
|
||||
constexpr auto UINT8 = at::ScalarType::Byte;
|
||||
|
||||
void scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
// Common validation for fp4 experts quantization entry points.
|
||||
static void validate_fp4_experts_quant_inputs(
|
||||
torch::Tensor const& output, torch::Tensor const& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
CHECK_INPUT(output, "output must be a CUDA tensor");
|
||||
CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
|
||||
CHECK_INPUT(input, "input must be a CUDA tensor");
|
||||
CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
|
||||
CHECK_INPUT(input_offset_by_experts,
|
||||
"input_offset_by_experts must be a CUDA tensor");
|
||||
CHECK_INPUT(output_scale_offset_by_experts,
|
||||
"output_scale_offset_by_experts must be a CUDA tensor");
|
||||
torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
|
||||
int64_t k) {
|
||||
CHECK_INPUT(output, "output");
|
||||
CHECK_INPUT(output_scale, "output_scale");
|
||||
CHECK_INPUT(input, "input");
|
||||
CHECK_INPUT(input_global_scale, "input_global_scale");
|
||||
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts");
|
||||
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts");
|
||||
|
||||
TORCH_CHECK(output.dim() == 2);
|
||||
TORCH_CHECK(output_scale.dim() == 2);
|
||||
@@ -335,8 +371,6 @@ void scaled_fp4_experts_quant_sm1xxa(
|
||||
TORCH_CHECK(output_scale.scalar_type() == INT);
|
||||
|
||||
const int BLOCK_SIZE = 16;
|
||||
auto m_topk = input.size(0);
|
||||
auto k = input.size(1);
|
||||
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
|
||||
auto n_experts = input_global_scale.size(0);
|
||||
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
|
||||
@@ -348,7 +382,21 @@ void scaled_fp4_experts_quant_sm1xxa(
|
||||
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
|
||||
// 4 means 4 fp8 values are packed into one int32
|
||||
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
|
||||
}
|
||||
|
||||
void scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
auto m_topk = input.size(0);
|
||||
auto k = input.size(1);
|
||||
|
||||
validate_fp4_experts_quant_inputs(output, output_scale, input,
|
||||
input_global_scale, input_offset_by_experts,
|
||||
output_scale_offset_by_experts, m_topk, k);
|
||||
|
||||
auto n_experts = input_global_scale.size(0);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
@@ -356,7 +404,38 @@ void scaled_fp4_experts_quant_sm1xxa(
|
||||
VLLM_DISPATCH_HALF_TYPES(
|
||||
input.scalar_type(), "nvfp4_experts_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
vllm::quant_impl<cuda_type>(
|
||||
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/false>(
|
||||
output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
|
||||
input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(),
|
||||
output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts,
|
||||
stream);
|
||||
});
|
||||
}
|
||||
|
||||
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
auto m_topk = input.size(0);
|
||||
// Input has gate || up layout, so k = input.size(1) / 2
|
||||
auto k_times_2 = input.size(1);
|
||||
TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
|
||||
auto k = k_times_2 / 2;
|
||||
|
||||
validate_fp4_experts_quant_inputs(output, output_scale, input,
|
||||
input_global_scale, input_offset_by_experts,
|
||||
output_scale_offset_by_experts, m_topk, k);
|
||||
|
||||
auto n_experts = input_global_scale.size(0);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(
|
||||
input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/true>(
|
||||
output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
|
||||
input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(),
|
||||
output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts,
|
||||
|
||||
@@ -41,6 +41,15 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output,
|
||||
torch::Tensor& input_sf);
|
||||
#endif
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
#endif
|
||||
|
||||
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
|
||||
torch::Tensor& output_sf, torch::Tensor const& input_sf) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
@@ -74,3 +83,18 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf,
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "No compiled silu_and_mul nvfp4 quantization kernel");
|
||||
}
|
||||
|
||||
void silu_and_mul_scaled_fp4_experts_quant(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
||||
output_scale_offset_by_experts);
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "No compiled silu_and_mul nvfp4 experts quantization kernel");
|
||||
}
|
||||
|
||||
@@ -239,4 +239,34 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
|
||||
return e2m1Vec;
|
||||
}
|
||||
|
||||
// silu in float32
|
||||
__device__ __forceinline__ float silu(float x) {
|
||||
return __fdividef(x, (1.f + __expf(-x)));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float2 silu2(float2 x) {
|
||||
return make_float2(silu(x.x), silu(x.y));
|
||||
}
|
||||
|
||||
template <class Type>
|
||||
__inline__ __device__ PackedVec<Type> compute_silu_mul(
|
||||
const PackedVec<Type>& x_vec, const PackedVec<Type>& y_vec) {
|
||||
PackedVec<Type> result;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
|
||||
// silu_mul in float32
|
||||
if constexpr (std::is_same_v<Type, half>) {
|
||||
float2 silu_vec = silu2(__half22float2(x_vec.elts[i]));
|
||||
result.elts[i] = __float22half2_rn(
|
||||
__fmul2_rn(silu_vec, __half22float2(y_vec.elts[i])));
|
||||
} else {
|
||||
float2 silu_vec = silu2(__bfloat1622float2(x_vec.elts[i]));
|
||||
result.elts[i] = __float22bfloat162_rn(
|
||||
__fmul2_rn(silu_vec, __bfloat1622float2(y_vec.elts[i])));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
@@ -558,6 +558,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"Tensor output_scale_offset_by_experts) -> ()");
|
||||
ops.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
|
||||
|
||||
// Fused SiLU+Mul+NVFP4 experts quantization.
|
||||
ops.def(
|
||||
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
|
||||
"output_scale,"
|
||||
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
|
||||
"Tensor output_scale_offset_by_experts) -> ()");
|
||||
ops.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA,
|
||||
&silu_and_mul_scaled_fp4_experts_quant);
|
||||
|
||||
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
|
||||
// of the given capability
|
||||
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
|
||||
|
||||
@@ -1606,15 +1606,15 @@ def scaled_fp4_experts_quant(
|
||||
topk: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Quantize input tensor to FP4 and return quantized tensor and scale, for
|
||||
Quantize input tensor to NVFP4 and return quantized tensor and scale, for
|
||||
packed MoE Inputs.
|
||||
Args:
|
||||
input_tensor: The input tensor to be quantized to FP4
|
||||
input_tensor: The input tensor to be quantized to NVFP4
|
||||
input_global_scale: A scalar scaling factor for the entire tensor.
|
||||
expert_offsets: The expert offsets tensor
|
||||
blockscale_offsets: The blockscale offsets tensor
|
||||
Outputs:
|
||||
output: The quantized tensor in FP4
|
||||
output: The quantized tensor in NVFP4
|
||||
output_scales: The blockscale tensor in FP8-E4M3
|
||||
"""
|
||||
assert not current_platform.is_rocm()
|
||||
@@ -1660,6 +1660,71 @@ def scaled_fp4_experts_quant(
|
||||
return output, output_scales
|
||||
|
||||
|
||||
def silu_and_mul_scaled_fp4_experts_quant(
|
||||
input_tensor: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
expert_offsets: torch.Tensor,
|
||||
blockscale_offsets: torch.Tensor,
|
||||
topk: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Fused SiLU+Mul+NVFP4 quantization for MoE intermediate activations.
|
||||
|
||||
Args:
|
||||
input_tensor: The input tensor with gate || up layout [m_topk, k*2]
|
||||
input_global_scale: A per-expert scaling factor [n_experts]
|
||||
expert_offsets: The expert offsets tensor [n_experts+1]
|
||||
blockscale_offsets: The blockscale offsets tensor [n_experts+1]
|
||||
topk: Number of top-k experts selected
|
||||
Outputs:
|
||||
output: The quantized tensor in NVFP4 [m_topk, k/2]
|
||||
output_scales: The blockscale tensor in FP8-E4M3
|
||||
"""
|
||||
assert not current_platform.is_rocm()
|
||||
assert input_tensor.ndim == 2, (
|
||||
f"input.ndim needs to be == 2, but got {input_tensor.ndim}."
|
||||
)
|
||||
|
||||
# Control the maximum number of tokens per expert supported by the
|
||||
# NVFP4 MoE Expert Quantization. This is used to prevent the kernel
|
||||
# from running out of memory. This value can also be increased to support
|
||||
# larger models.
|
||||
MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
|
||||
m_numtopk, k_times_2 = input_tensor.shape
|
||||
assert k_times_2 % 2 == 0, "input width must be even (gate || up layout)"
|
||||
k = k_times_2 // 2
|
||||
|
||||
assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, (
|
||||
f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
|
||||
f"{MAX_TOKENS_PER_EXPERT})"
|
||||
f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
|
||||
f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value."
|
||||
)
|
||||
scales_k = k // 16
|
||||
padded_k = (scales_k + (4 - 1)) // 4
|
||||
|
||||
# output is uint8 and packed fp4 values
|
||||
output = torch.empty(
|
||||
m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
|
||||
)
|
||||
output_scales = torch.empty(
|
||||
MAX_TOKENS_PER_EXPERT * topk,
|
||||
padded_k,
|
||||
dtype=torch.int32,
|
||||
device=input_tensor.device,
|
||||
)
|
||||
torch.ops._C.silu_and_mul_scaled_fp4_experts_quant(
|
||||
output,
|
||||
output_scales,
|
||||
input_tensor,
|
||||
input_global_scale,
|
||||
expert_offsets,
|
||||
blockscale_offsets,
|
||||
)
|
||||
output_scales = output_scales.view(torch.float8_e4m3fn)
|
||||
return output, output_scales
|
||||
|
||||
|
||||
# fp8
|
||||
def scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
|
||||
@@ -549,7 +549,8 @@ def run_cutlass_moe_fp4(
|
||||
num_topk,
|
||||
)
|
||||
c1 = _resize_cache(workspace13, (m * topk, n * 2))
|
||||
c2 = _resize_cache(workspace2, (m * topk, n))
|
||||
# Note: c2 workspace is no longer needed since SiLU is fused with quantization.
|
||||
# c3 reuses workspace13 after c1 is consumed.
|
||||
c3 = _resize_cache(workspace13, (m * topk, k))
|
||||
ops.cutlass_fp4_moe_mm(
|
||||
c1,
|
||||
@@ -563,9 +564,9 @@ def run_cutlass_moe_fp4(
|
||||
blockscale_offsets[:-1],
|
||||
)
|
||||
del rep_a_fp4, rep_a_blockscale
|
||||
torch.ops._C.silu_and_mul(c2, c1)
|
||||
int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
|
||||
c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk
|
||||
# Fused SiLU+Mul+NVFP4 quantization
|
||||
int_fp4, int_blockscale = ops.silu_and_mul_scaled_fp4_experts_quant(
|
||||
c1, a2_gscale, expert_offsets, blockscale_offsets, num_topk
|
||||
)
|
||||
|
||||
ops.cutlass_fp4_moe_mm(
|
||||
|
||||
Reference in New Issue
Block a user