Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -32,8 +32,11 @@ class CustomOp(nn.Module):
|
||||
op_cls_to_instantiate = cls
|
||||
else:
|
||||
op_cls_to_instantiate = cls.op_registry_oot[op_name]
|
||||
logger.debug("Instantiating custom op: %s using %s", op_name,
|
||||
str(op_cls_to_instantiate))
|
||||
logger.debug(
|
||||
"Instantiating custom op: %s using %s",
|
||||
op_name,
|
||||
str(op_cls_to_instantiate),
|
||||
)
|
||||
return super().__new__(op_cls_to_instantiate)
|
||||
|
||||
def __init__(self):
|
||||
@@ -86,8 +89,7 @@ class CustomOp(nn.Module):
|
||||
if enabled:
|
||||
compilation_config.enabled_custom_ops.update([self.__class__.name])
|
||||
else:
|
||||
compilation_config.disabled_custom_ops.update(
|
||||
[self.__class__.name])
|
||||
compilation_config.disabled_custom_ops.update([self.__class__.name])
|
||||
|
||||
if not enabled:
|
||||
return self.forward_native
|
||||
@@ -119,8 +121,7 @@ class CustomOp(nn.Module):
|
||||
|
||||
enabled = f"+{cls.name}" in custom_ops
|
||||
disabled = f"-{cls.name}" in custom_ops
|
||||
assert not (enabled
|
||||
and disabled), f"Cannot enable and disable {cls.name}"
|
||||
assert not (enabled and disabled), f"Cannot enable and disable {cls.name}"
|
||||
|
||||
return (CustomOp.default_on() or enabled) and not disabled
|
||||
|
||||
@@ -131,9 +132,12 @@ class CustomOp(nn.Module):
|
||||
Specifying 'all' or 'none' in custom_op takes precedence.
|
||||
"""
|
||||
from vllm.config import CompilationLevel
|
||||
|
||||
compilation_config = get_cached_compilation_config()
|
||||
default_on = (compilation_config.level < CompilationLevel.PIECEWISE
|
||||
or not compilation_config.use_inductor)
|
||||
default_on = (
|
||||
compilation_config.level < CompilationLevel.PIECEWISE
|
||||
or not compilation_config.use_inductor
|
||||
)
|
||||
count_none = compilation_config.custom_ops.count("none")
|
||||
count_all = compilation_config.custom_ops.count("all")
|
||||
return default_on and not count_none > 0 or count_all > 0
|
||||
@@ -143,13 +147,12 @@ class CustomOp(nn.Module):
|
||||
# Examples:
|
||||
# - MyOp.enabled()
|
||||
# - op_registry["my_op"].enabled()
|
||||
op_registry: dict[str, type['CustomOp']] = {}
|
||||
op_registry_oot: dict[str, type['CustomOp']] = {}
|
||||
op_registry: dict[str, type["CustomOp"]] = {}
|
||||
op_registry_oot: dict[str, type["CustomOp"]] = {}
|
||||
|
||||
# Decorator to register custom ops.
|
||||
@classmethod
|
||||
def register(cls, name: str):
|
||||
|
||||
def decorator(op_cls):
|
||||
assert name not in cls.op_registry, f"Duplicate op name: {name}"
|
||||
op_cls.name = name
|
||||
@@ -169,11 +172,9 @@ class CustomOp(nn.Module):
|
||||
# - @CustomOP.register_oot(name="UnquantizedFusedMoEMethod")
|
||||
@classmethod
|
||||
def register_oot(cls, _decorated_op_cls=None, name: Optional[str] = None):
|
||||
|
||||
def decorator(op_cls):
|
||||
reg_name = name if name is not None else cls.__name__
|
||||
assert reg_name not in cls.op_registry_oot, \
|
||||
f"Duplicate op name: {reg_name}"
|
||||
assert reg_name not in cls.op_registry_oot, f"Duplicate op name: {reg_name}"
|
||||
op_cls.name = reg_name
|
||||
cls.op_registry_oot[reg_name] = op_cls
|
||||
return op_cls
|
||||
|
||||
Reference in New Issue
Block a user