[CustomOp] Support object-level enable for CustomOp (#30547)
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -38,8 +38,9 @@ class CustomOp(nn.Module):
|
|||||||
)
|
)
|
||||||
return super().__new__(op_cls_to_instantiate)
|
return super().__new__(op_cls_to_instantiate)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, enforce_enable: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self._enforce_enable = enforce_enable
|
||||||
self._forward_method = self.dispatch_forward()
|
self._forward_method = self.dispatch_forward()
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
@@ -84,7 +85,11 @@ class CustomOp(nn.Module):
|
|||||||
# NOTE(woosuk): Here we assume that vLLM was built for only one
|
# NOTE(woosuk): Here we assume that vLLM was built for only one
|
||||||
# specific backend. Currently, we do not support dynamic dispatching.
|
# specific backend. Currently, we do not support dynamic dispatching.
|
||||||
compilation_config = get_cached_compilation_config()
|
compilation_config = get_cached_compilation_config()
|
||||||
enabled = self.enabled()
|
|
||||||
|
# CustomOp object can be enforce enabled, e.g., enable device-specific
|
||||||
|
# kernels in ViT models when enabling graph mode. By default, it will
|
||||||
|
# follow the compilation_config to determine whether enable itself.
|
||||||
|
enabled = self._enforce_enable or self.enabled()
|
||||||
if enabled:
|
if enabled:
|
||||||
compilation_config.enabled_custom_ops.update([self.__class__.name])
|
compilation_config.enabled_custom_ops.update([self.__class__.name])
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user