[Perf][Kernel] Fused SiLU+Mul+Quant kernel for NVFP4 cutlass_moe (#31832)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2026-01-09 09:40:33 -05:00
committed by GitHub
parent 8e27663b6a
commit 34cd32fe30
8 changed files with 278 additions and 95 deletions

View File

@@ -239,4 +239,34 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
return e2m1Vec;
}
// silu in float32
__device__ __forceinline__ float silu(float x) {
return __fdividef(x, (1.f + __expf(-x)));
}
__device__ __forceinline__ float2 silu2(float2 x) {
return make_float2(silu(x.x), silu(x.y));
}
template <class Type>
__inline__ __device__ PackedVec<Type> compute_silu_mul(
const PackedVec<Type>& x_vec, const PackedVec<Type>& y_vec) {
PackedVec<Type> result;
#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
// silu_mul in float32
if constexpr (std::is_same_v<Type, half>) {
float2 silu_vec = silu2(__half22float2(x_vec.elts[i]));
result.elts[i] = __float22half2_rn(
__fmul2_rn(silu_vec, __half22float2(y_vec.elts[i])));
} else {
float2 silu_vec = silu2(__bfloat1622float2(x_vec.elts[i]));
result.elts[i] = __float22bfloat162_rn(
__fmul2_rn(silu_vec, __bfloat1622float2(y_vec.elts[i])));
}
}
return result;
}
} // namespace vllm