[torch.compile] support all attention backends (#10558)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -1573,6 +1573,7 @@ def direct_register_custom_op(
|
||||
mutates_args: List[str],
|
||||
fake_impl: Optional[Callable] = None,
|
||||
target_lib: Optional[Library] = None,
|
||||
dispatch_key: str = "CUDA",
|
||||
):
|
||||
"""
|
||||
`torch.library.custom_op` can have significant overhead because it
|
||||
@@ -1601,7 +1602,7 @@ def direct_register_custom_op(
|
||||
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
|
||||
my_lib = target_lib or vllm_lib
|
||||
my_lib.define(op_name + schema_str)
|
||||
my_lib.impl(op_name, op_func, "CUDA")
|
||||
my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
|
||||
if fake_impl is not None:
|
||||
my_lib._register_fake(op_name, fake_impl)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user