[torch.compile] directly register custom op (#9896)

Signed-off-by: youkaichao <youkaichao@gmail.com>

Author:    youkaichao
Date:      2024-10-31 21:56:09 -07:00
Committer: GitHub
Parent:    031a7995f3
Commit:    96e0c9cbbd

9 changed files with 192 additions and 67 deletions
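Summary of the file diff below: instead of decorating each fused-MoE op with @torch.library.custom_op and attaching its meta function via .register_fake, the ops are now registered by calling vLLM's direct_register_custom_op helper with the real implementation, the list of mutated arguments, and an explicitly named fake implementation. As a rough, self-contained sketch of the kind of registration such a helper wraps (toy namespace and op; the torch.library calls are assumptions about the mechanism, this is not vLLM's helper):

    # Sketch only: assumes a recent PyTorch with torch.library.Library and
    # torch.library.register_fake; not vLLM's direct_register_custom_op.
    import torch
    from torch.library import Library

    demo_lib = Library("demo", "FRAGMENT")  # hypothetical op namespace

    def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
        return x * factor

    def scale_fake(x: torch.Tensor, factor: float) -> torch.Tensor:
        # Shape/dtype-only path used when torch.compile traces the graph.
        return torch.empty_like(x)

    # "Direct" registration: define the schema, bind the eager and fake impls.
    # vLLM would presumably bind its kernels to the device backend it targets;
    # "CPU" keeps this toy example runnable anywhere.
    demo_lib.define("scale(Tensor x, float factor) -> Tensor")
    demo_lib.impl("scale", scale, "CPU")
    torch.library.register_fake("demo::scale", scale_fake)

    out = torch.ops.demo.scale(torch.ones(4), 2.0)  # eager call through torch.ops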


@@ -12,6 +12,7 @@ import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 logger = init_logger(__name__)
@@ -466,8 +467,6 @@ def get_config_dtype_str(dtype: torch.dtype,
     return None
 
 
-@torch.library.custom_op("vllm::inplace_fused_experts",
-                         mutates_args=["hidden_states"])
 def inplace_fused_experts(hidden_states: torch.Tensor,
                           w1: torch.Tensor,
                           w2: torch.Tensor,
@@ -484,22 +483,29 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                        a1_scale, a2_scale)
 
 
-@inplace_fused_experts.register_fake
-def _(hidden_states: torch.Tensor,
-      w1: torch.Tensor,
-      w2: torch.Tensor,
-      topk_weights: torch.Tensor,
-      topk_ids: torch.Tensor,
-      use_fp8_w8a8: bool = False,
-      use_int8_w8a16: bool = False,
-      w1_scale: Optional[torch.Tensor] = None,
-      w2_scale: Optional[torch.Tensor] = None,
-      a1_scale: Optional[torch.Tensor] = None,
-      a2_scale: Optional[torch.Tensor] = None) -> None:
+def inplace_fused_experts_fake(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        use_fp8_w8a8: bool = False,
+        use_int8_w8a16: bool = False,
+        w1_scale: Optional[torch.Tensor] = None,
+        w2_scale: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None) -> None:
     pass
 
 
-@torch.library.custom_op("vllm::outplace_fused_experts", mutates_args=[])
+direct_register_custom_op(
+    op_name="inplace_fused_experts",
+    op_func=inplace_fused_experts,
+    mutates_args=["hidden_states"],
+    fake_impl=inplace_fused_experts_fake,
+)
+
+
 def outplace_fused_experts(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
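Aside (not part of the diff): the fake implementations above and below exist so that torch.compile can trace through these ops without running the real kernels. The in-place op mutates hidden_states and returns nothing, so its fake is just pass, while the out-of-place op must advertise an output shaped like hidden_states, hence torch.empty_like. A minimal illustration of the fake path with a toy op, using the decorator API this commit is moving away from (toy names; assumes a PyTorch version that provides torch.library.custom_op):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    # Toy op, registered the "old" way shown on the removed lines above.
    @torch.library.custom_op("demo::double", mutates_args=[])
    def double(x: torch.Tensor) -> torch.Tensor:
        return x * 2

    @double.register_fake
    def _(x: torch.Tensor) -> torch.Tensor:
        # Only shapes/dtypes matter here; no real data is touched.
        return torch.empty_like(x)

    with FakeTensorMode():
        y = torch.ops.demo.double(torch.empty(8, 16))  # dispatches to the fake
        print(y.shape)  # torch.Size([8, 16])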
@@ -517,21 +523,29 @@ def outplace_fused_experts(
                               w2_scale, a1_scale, a2_scale)
 
 
-@outplace_fused_experts.register_fake
-def _(hidden_states: torch.Tensor,
-      w1: torch.Tensor,
-      w2: torch.Tensor,
-      topk_weights: torch.Tensor,
-      topk_ids: torch.Tensor,
-      use_fp8_w8a8: bool = False,
-      use_int8_w8a16: bool = False,
-      w1_scale: Optional[torch.Tensor] = None,
-      w2_scale: Optional[torch.Tensor] = None,
-      a1_scale: Optional[torch.Tensor] = None,
-      a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
+def outplace_fused_experts_fake(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        use_fp8_w8a8: bool = False,
+        use_int8_w8a16: bool = False,
+        w1_scale: Optional[torch.Tensor] = None,
+        w2_scale: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
 
+direct_register_custom_op(
+    op_name="outplace_fused_experts",
+    op_func=outplace_fused_experts,
+    mutates_args=[],
+    fake_impl=outplace_fused_experts_fake,
+)
+
+
 def fused_experts(hidden_states: torch.Tensor,
                   w1: torch.Tensor,
                   w2: torch.Tensor,
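Usage note (an assumption, not shown in this diff): once the patched module has been imported, the registered ops are reachable through the torch.ops.vllm namespace, which is what callers and torch.compile graphs refer to instead of the plain Python functions. The module path below is inferred from the imports shown above:

    import torch
    # Assumed path of the file being patched; importing it runs the
    # direct_register_custom_op calls above.
    import vllm.model_executor.layers.fused_moe.fused_moe  # noqa: F401

    print(torch.ops.vllm.inplace_fused_experts)   # registered in-place op
    print(torch.ops.vllm.outplace_fused_experts)  # registered out-of-place op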