Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -23,8 +23,7 @@ class LazyInitPass(InductorPass):
|
||||
and then immediately invoke it.
|
||||
"""
|
||||
|
||||
def __init__(self, pass_cls: type[VllmInductorPass],
|
||||
vllm_config: VllmConfig):
|
||||
def __init__(self, pass_cls: type[VllmInductorPass], vllm_config: VllmConfig):
|
||||
self.pass_cls = pass_cls
|
||||
self.vllm_config = weakref.proxy(vllm_config) # avoid cycle
|
||||
|
||||
@@ -45,20 +44,18 @@ class TestBackend:
|
||||
Inductor config is default-initialized from VllmConfig.CompilationConfig.
|
||||
"""
|
||||
|
||||
def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
|
||||
None]]):
|
||||
def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]):
|
||||
self.custom_passes = list(passes)
|
||||
compile_config = get_current_vllm_config().compilation_config
|
||||
self.inductor_config = compile_config.inductor_compile_config
|
||||
self.inductor_config['force_disable_caches'] = True
|
||||
self.inductor_config['post_grad_custom_post_pass'] = self.post_pass
|
||||
self.inductor_config["force_disable_caches"] = True
|
||||
self.inductor_config["post_grad_custom_post_pass"] = self.post_pass
|
||||
|
||||
def __call__(self, graph: fx.GraphModule, example_inputs):
|
||||
self.graph_pre_compile = deepcopy(graph)
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
return compile_fx(graph,
|
||||
example_inputs,
|
||||
config_patches=self.inductor_config)
|
||||
|
||||
return compile_fx(graph, example_inputs, config_patches=self.inductor_config)
|
||||
|
||||
@with_pattern_match_debug
|
||||
def post_pass(self, graph: fx.Graph):
|
||||
@@ -82,8 +79,7 @@ class TestBackend:
|
||||
assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
|
||||
assert num_pre > num_post, f"All nodes remain for op {op.name()}"
|
||||
if fully_replaced:
|
||||
assert num_post == 0, \
|
||||
f"Unexpected op {op.name()} in post-pass graph"
|
||||
assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"
|
||||
|
||||
def check_after_ops(self, ops: Sequence[OpOverload]):
|
||||
for op in ops:
|
||||
|
||||
Reference in New Issue
Block a user