[Kernel][CPU] Add Quick gelu to CPU (#5717)
This commit is contained in:
@@ -59,6 +59,13 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
|
||||
return w3 * x * (ones + t);
|
||||
}
|
||||
|
||||
FORCE_INLINE vec_op::FP32Vec8 gelu_quick_act(const vec_op::FP32Vec8& x) {
|
||||
const vec_op::FP32Vec8 zeros(0.0);
|
||||
const vec_op::FP32Vec8 ones(1.0);
|
||||
const vec_op::FP32Vec8 w1(1.702f);
|
||||
return x / (ones + (zeros - w1 * x).exp());
|
||||
}
|
||||
|
||||
FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) {
|
||||
const vec_op::FP32Vec8 ones(1.0);
|
||||
const vec_op::FP32Vec8 w1(M_SQRT1_2);
|
||||
@@ -142,3 +149,15 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
|
||||
CPU_KERNEL_GUARD_OUT(gelu_fast_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void gelu_quick(torch::Tensor& out, torch::Tensor& input) {
|
||||
int num_tokens = input.numel() / input.size(-1);
|
||||
int d = input.size(-1);
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_quick_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(gelu_quick_impl)
|
||||
activation_kernel<scalar_t, gelu_quick_act, false>(
|
||||
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
|
||||
CPU_KERNEL_GUARD_OUT(gelu_quick_impl)
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user