[torch.compile] directly register custom op (#9896)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -32,6 +32,7 @@ import torch
|
||||
import torch.types
|
||||
import yaml
|
||||
from packaging.version import Version
|
||||
from torch.library import Library
|
||||
from typing_extensions import ParamSpec, TypeIs, assert_never
|
||||
|
||||
import vllm.envs as envs
|
||||
@@ -1512,3 +1513,47 @@ def weak_ref_tensors(
|
||||
if isinstance(tensors, tuple):
|
||||
return tuple(weak_ref_tensor(t) for t in tensors)
|
||||
raise ValueError("Invalid type for tensors")
|
||||
|
||||
|
||||
def is_in_doc_build() -> bool:
    """Return True when running under a Sphinx documentation build.

    Sphinx's autodoc mocks heavyweight imports (like ``torch``) with
    ``_MockModule`` instances, so checking whether ``torch`` itself is a
    mock tells us we are inside a docs build rather than a real runtime.
    """
    try:
        from sphinx.ext.autodoc.mock import _MockModule
    except ModuleNotFoundError:
        # Sphinx is not installed at all -> definitely not a docs build.
        return False
    return isinstance(torch, _MockModule)
|
||||
|
||||
|
||||
# Module-level torch.library.Library that serves as the default registration
# target for vLLM's custom ops; "FRAGMENT" extends the existing "vllm"
# namespace instead of claiming exclusive ownership of it.
vllm_lib = Library("vllm", "FRAGMENT")  # noqa
|
||||
|
||||
|
||||
def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: List[str],
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
):
    """Register ``op_func`` as a custom op, bypassing ``torch.library.custom_op``.

    ``torch.library.custom_op`` carries significant dispatch overhead; this
    helper instead defines the op schema directly on a ``Library`` object and
    binds the implementation to the CUDA backend. See
    https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    Args:
        op_name: name of the op within the target library's namespace.
        op_func: the real (CUDA) implementation.
        mutates_args: names of arguments that ``op_func`` mutates in place.
        fake_impl: optional meta/fake implementation for tracing.
        target_lib: library to register into; defaults to the vLLM library.

    IMPORTANT: the operator's lifetime is tied to the lifetime of the
    library object. When binding into a different library, keep that
    library object alive for as long as the operator is used.
    """
    # Sphinx mocks torch during docs builds, so registration cannot work there.
    if is_in_doc_build():
        return
    schema = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    lib = target_lib or vllm_lib
    lib.define(op_name + schema)
    lib.impl(op_name, op_func, "CUDA")
    if fake_impl is not None:
        # NOTE(review): _register_fake is a private torch API — presumably
        # equivalent to torch.library.register_fake; confirm on torch upgrades.
        lib._register_fake(op_name, fake_impl)
|
||||
|
||||
Reference in New Issue
Block a user