[optimization] remove python function call for custom op (#11750)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2025-01-08 01:04:28 +08:00
Committed by: GitHub
Parent: c0efe92d8b
Commit: 869579a702

4 changed files with 15 additions and 13 deletions
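
What changed: the hot path used to reach the compiled kernels through thin Python wrappers in vllm._custom_ops (e.g. ops.silu_and_mul), which only forwarded to torch.ops._C.silu_and_mul, so every call paid for an extra Python frame. After this commit the resolved op is called directly, either bound once in __init__ or invoked as torch.ops._C.silu_and_mul at the call site. A minimal sketch of the difference, not vLLM code (torch.relu stands in for the compiled op so the snippet runs without a built extension):

    import timeit

    import torch

    _op = torch.relu  # stand-in for torch.ops._C.silu_and_mul

    def wrapped(x):
        # Extra Python frame, like the removed vllm._custom_ops wrappers.
        return _op(x)

    x = torch.randn(8)
    print(timeit.timeit(lambda: wrapped(x), number=100_000))  # wrapper path
    print(timeit.timeit(lambda: _op(x), number=100_000))      # direct path (this commit)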


@@ -35,10 +35,6 @@ else:
 # activation ops
-def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.silu_and_mul(out, x)
 def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
     torch.ops._C.gelu_and_mul(out, x)
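
The wrapper around silu_and_mul is deleted here; gelu_and_mul is left untouched as context. Callers now invoke the registered custom op directly as torch.ops._C.silu_and_mul(out, x), as in the fused-MoE hunks below. A sketch of how such a torch.ops call resolves, using a hypothetical _demo namespace registered from Python purely for illustration (vLLM's real _C ops are registered from the compiled extension):

    import torch
    import torch.nn.functional as F

    # Hypothetical namespace "_demo"; vLLM's ops live under torch.ops._C.
    lib = torch.library.Library("_demo", "DEF")
    lib.define("silu_and_mul(Tensor(a!) out, Tensor x) -> ()")

    def _silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        d = x.shape[-1] // 2
        out.copy_(F.silu(x[..., :d]) * x[..., d:])

    lib.impl("silu_and_mul", _silu_and_mul, "CompositeExplicitAutograd")

    x = torch.randn(4, 8)
    out = torch.empty(4, 4)
    # Same call shape as torch.ops._C.silu_and_mul(out, x).
    torch.ops._demo.silu_and_mul(out, x)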


@@ -10,6 +10,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.utils import LazyDict
@@ -58,27 +59,31 @@ class SiluAndMul(CustomOp):
         return: (num_tokens, d) or (batch_size, seq_len, d)
     """
 
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike():
+            self.op = torch.ops._C.silu_and_mul
+        elif current_platform.is_xpu():
+            import intel_extension_for_pytorch as ipex
+            self.op = ipex.llm.functional.silu_and_mul
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.silu_and_mul(out, x)
+        self.op(out, x)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.silu_and_mul(out, x)
+        self.op(out, x)
         return out
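
The SiluAndMul change moves op selection from per-call (a local import plus a wrapper call inside forward_cuda / forward_xpu) to per-layer: __init__ resolves the platform-specific kernel once and keeps it as self.op, so the forward path is a single direct call. A minimal sketch of the same pattern, assuming plain PyTorch only (the platform branches from the diff are replaced by a native fallback; this is not the vLLM class):

    import torch
    import torch.nn.functional as F

    class SiluAndMulSketch(torch.nn.Module):

        def __init__(self):
            super().__init__()
            # vLLM binds torch.ops._C.silu_and_mul (CUDA/ROCm) or the IPEX
            # kernel (XPU) here; this sketch binds a pure-PyTorch fallback.
            self.op = self._native

        @staticmethod
        def _native(out: torch.Tensor, x: torch.Tensor) -> None:
            d = x.shape[-1] // 2
            out.copy_(F.silu(x[..., :d]) * x[..., d:])

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            d = x.shape[-1] // 2
            out = torch.empty(x.shape[:-1] + (d, ), dtype=x.dtype, device=x.device)
            self.op(out, x)  # no wrapper, no per-call import
            return out

    layer = SiluAndMulSketch()
    y = layer(torch.randn(2, 16))  # y.shape == (2, 8)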


@@ -4,7 +4,6 @@ from typing import Optional
 
 import torch
 
-from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size, try_get_optimal_moe_config)
 from vllm.scalar_type import scalar_types
@@ -301,7 +300,8 @@ def fused_marlin_moe(
         False,
     )
 
-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
+    torch.ops._C.silu_and_mul(intermediate_cache2,
+                              intermediate_cache1.view(-1, 2 * N))
 
     intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
         intermediate_cache2,


@@ -753,7 +753,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                 use_int8_w8a16=use_int8_w8a16,
                                 block_shape=block_shape)
 
-        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        torch.ops._C.silu_and_mul(intermediate_cache2,
+                                  intermediate_cache1.view(-1, N))
 
         invoke_fused_moe_kernel(intermediate_cache2,
                                 w2,
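
Both fused-MoE call sites make the same substitution: the vllm._custom_ops wrapper is replaced by a direct torch.ops._C.silu_and_mul call. The op halves the last dimension: the input holds the gate and up projections concatenated, the output their SiLU-gated product. (In fused_experts_impl the variable N already denotes the concatenated width, hence view(-1, N) there versus view(-1, 2 * N) in the Marlin path.) A pure-PyTorch stand-in showing the shapes at these call sites, with illustrative sizes:

    import torch
    import torch.nn.functional as F

    def silu_and_mul_ref(out: torch.Tensor, x: torch.Tensor) -> None:
        # Reference semantics of the custom op: out = silu(x[..., :d]) * x[..., d:]
        d = x.shape[-1] // 2
        out.copy_(F.silu(x[..., :d]) * x[..., d:])

    M, N = 6, 4                                   # illustrative sizes
    intermediate_cache1 = torch.randn(M, 2 * N)   # gate and up projections, concatenated
    intermediate_cache2 = torch.empty(M, N)
    silu_and_mul_ref(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))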