Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -23,8 +23,7 @@ class LazyInitPass(InductorPass):
and then immediately invoke it.
"""
def __init__(self, pass_cls: type[VllmInductorPass],
vllm_config: VllmConfig):
def __init__(self, pass_cls: type[VllmInductorPass], vllm_config: VllmConfig):
self.pass_cls = pass_cls
self.vllm_config = weakref.proxy(vllm_config) # avoid cycle
@@ -45,20 +44,18 @@ class TestBackend:
Inductor config is default-initialized from VllmConfig.CompilationConfig.
"""
def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
None]]):
def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]):
self.custom_passes = list(passes)
compile_config = get_current_vllm_config().compilation_config
self.inductor_config = compile_config.inductor_compile_config
self.inductor_config['force_disable_caches'] = True
self.inductor_config['post_grad_custom_post_pass'] = self.post_pass
self.inductor_config["force_disable_caches"] = True
self.inductor_config["post_grad_custom_post_pass"] = self.post_pass
def __call__(self, graph: fx.GraphModule, example_inputs):
self.graph_pre_compile = deepcopy(graph)
from torch._inductor.compile_fx import compile_fx
return compile_fx(graph,
example_inputs,
config_patches=self.inductor_config)
return compile_fx(graph, example_inputs, config_patches=self.inductor_config)
@with_pattern_match_debug
def post_pass(self, graph: fx.Graph):
@@ -82,8 +79,7 @@ class TestBackend:
assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
assert num_pre > num_post, f"All nodes remain for op {op.name()}"
if fully_replaced:
assert num_post == 0, \
f"Unexpected op {op.name()} in post-pass graph"
assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"
def check_after_ops(self, ops: Sequence[OpOverload]):
for op in ops: