[Perf][Kernel] Fused SiLU+Mul+Quant kernel for NVFP4 cutlass_moe (#31832)
Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -239,4 +239,34 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
|
||||
return e2m1Vec;
|
||||
}
|
||||
|
||||
// silu in float32
|
||||
__device__ __forceinline__ float silu(float x) {
|
||||
return __fdividef(x, (1.f + __expf(-x)));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float2 silu2(float2 x) {
|
||||
return make_float2(silu(x.x), silu(x.y));
|
||||
}
|
||||
|
||||
template <class Type>
|
||||
__inline__ __device__ PackedVec<Type> compute_silu_mul(
|
||||
const PackedVec<Type>& x_vec, const PackedVec<Type>& y_vec) {
|
||||
PackedVec<Type> result;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
|
||||
// silu_mul in float32
|
||||
if constexpr (std::is_same_v<Type, half>) {
|
||||
float2 silu_vec = silu2(__half22float2(x_vec.elts[i]));
|
||||
result.elts[i] = __float22half2_rn(
|
||||
__fmul2_rn(silu_vec, __half22float2(y_vec.elts[i])));
|
||||
} else {
|
||||
float2 silu_vec = silu2(__bfloat1622float2(x_vec.elts[i]));
|
||||
result.elts[i] = __float22bfloat162_rn(
|
||||
__fmul2_rn(silu_vec, __bfloat1622float2(y_vec.elts[i])));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
Reference in New Issue
Block a user