[Kernel] add kernel for FATReLU (#9610)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2024-10-24 16:18:27 +08:00
committed by GitHub
parent 8a02cd045a
commit 295a061fb3
6 changed files with 78 additions and 8 deletions

View File

@@ -39,7 +39,13 @@ class FatreluAndMul(CustomOp):
return x1 * x2
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
    """CUDA path: compute FATReLU-and-mul via the fused custom kernel.

    The input's last dimension is split in half (d = x.shape[-1] // 2);
    the output keeps all leading dimensions and has last dimension d,
    matching what `forward_native` produces from the two halves.

    Args:
        x: Input tensor whose last dimension is 2*d. Presumably the two
           halves are the gate/up projections — confirm against callers.

    Returns:
        Tensor of shape x.shape[:-1] + (d,), same dtype/device as x.
    """
    # Imported lazily to avoid a circular import at module load time.
    from vllm import _custom_ops as ops

    d = x.shape[-1] // 2
    output_shape = x.shape[:-1] + (d, )
    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
    # Fused kernel writes the result into `out` in place;
    # `self.threshold` is the FATReLU cut-off set at construction.
    ops.fatrelu_and_mul(out, x, self.threshold)
    return out
@CustomOp.register("silu_and_mul")