[optimization] remove python function call for custom op (#11750)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -35,10 +35,6 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
# activation ops
|
# activation ops
|
||||||
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
|
||||||
torch.ops._C.silu_and_mul(out, x)
|
|
||||||
|
|
||||||
|
|
||||||
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
||||||
torch.ops._C.gelu_and_mul(out, x)
|
torch.ops._C.gelu_and_mul(out, x)
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
|||||||
get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import LazyDict
|
from vllm.utils import LazyDict
|
||||||
|
|
||||||
|
|
||||||
@@ -58,27 +59,31 @@ class SiluAndMul(CustomOp):
|
|||||||
return: (num_tokens, d) or (batch_size, seq_len, d)
|
return: (num_tokens, d) or (batch_size, seq_len, d)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
if current_platform.is_cuda_alike():
|
||||||
|
self.op = torch.ops._C.silu_and_mul
|
||||||
|
elif current_platform.is_xpu():
|
||||||
|
import intel_extension_for_pytorch as ipex
|
||||||
|
self.op = ipex.llm.functional.silu_and_mul
|
||||||
|
|
||||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""PyTorch-native implementation equivalent to forward()."""
|
"""PyTorch-native implementation equivalent to forward()."""
|
||||||
d = x.shape[-1] // 2
|
d = x.shape[-1] // 2
|
||||||
return F.silu(x[..., :d]) * x[..., d:]
|
return F.silu(x[..., :d]) * x[..., d:]
|
||||||
|
|
||||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
|
|
||||||
d = x.shape[-1] // 2
|
d = x.shape[-1] // 2
|
||||||
output_shape = (x.shape[:-1] + (d, ))
|
output_shape = (x.shape[:-1] + (d, ))
|
||||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||||
ops.silu_and_mul(out, x)
|
self.op(out, x)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
from vllm._ipex_ops import ipex_ops as ops
|
|
||||||
|
|
||||||
d = x.shape[-1] // 2
|
d = x.shape[-1] // 2
|
||||||
output_shape = (x.shape[:-1] + (d, ))
|
output_shape = (x.shape[:-1] + (d, ))
|
||||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||||
ops.silu_and_mul(out, x)
|
self.op(out, x)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
|
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
|
||||||
from vllm.scalar_type import scalar_types
|
from vllm.scalar_type import scalar_types
|
||||||
@@ -301,7 +300,8 @@ def fused_marlin_moe(
|
|||||||
False,
|
False,
|
||||||
)
|
)
|
||||||
|
|
||||||
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
|
torch.ops._C.silu_and_mul(intermediate_cache2,
|
||||||
|
intermediate_cache1.view(-1, 2 * N))
|
||||||
|
|
||||||
intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
|
intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
|
||||||
intermediate_cache2,
|
intermediate_cache2,
|
||||||
|
|||||||
@@ -753,7 +753,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
|||||||
use_int8_w8a16=use_int8_w8a16,
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
block_shape=block_shape)
|
block_shape=block_shape)
|
||||||
|
|
||||||
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
|
torch.ops._C.silu_and_mul(intermediate_cache2,
|
||||||
|
intermediate_cache1.view(-1, N))
|
||||||
|
|
||||||
invoke_fused_moe_kernel(intermediate_cache2,
|
invoke_fused_moe_kernel(intermediate_cache2,
|
||||||
w2,
|
w2,
|
||||||
|
|||||||
Reference in New Issue
Block a user