[Kernel] Support MulAndSilu (#11624)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2025-01-15 10:29:53 +08:00
committed by GitHub
parent a3a3ee4e6f
commit 42f5e7c52a
7 changed files with 83 additions and 36 deletions

View File

@@ -87,6 +87,41 @@ class SiluAndMul(CustomOp):
return out
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
"""An activation function for SwiGLU.
The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def __init__(self):
super().__init__()
if current_platform.is_cuda_alike() or current_platform.is_cpu():
self.op = torch.ops._C.mul_and_silu
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops
self.op = ipex_ops.silu_and_mul
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return x[..., :d] * F.silu(x[..., d:])
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
self.op(out, x)
return out
# TODO implement forward_xpu for MulAndSilu
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
"""An activation function for GeGLU.