Allow oot custom compiler extension via CompilerInterface (#28623)

Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: Mengqing Cao <cmq0113@163.com>
Signed-off-by: Icey <1790571317@qq.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
Icey
2025-11-25 15:25:15 +08:00
committed by GitHub
parent fe3a4f5b34
commit 888152bf87
3 changed files with 42 additions and 24 deletions

View File

@@ -63,13 +63,14 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
else:
logger.debug("Using InductorAdaptor")
return InductorAdaptor()
else:
assert compilation_config.backend == "eager", (
"Custom backends not supported with CompilationMode.VLLM_COMPILE"
)
elif compilation_config.backend == "eager":
logger.debug("Using EagerAdaptor")
return EagerAdaptor()
else:
logger.debug("Using custom backend: %s", compilation_config.backend)
compiler = resolve_obj_by_qualname(current_platform.get_compile_backend())()
assert isinstance(compiler, CompilerInterface)
return compiler
class CompilerManager:
@@ -545,7 +546,10 @@ class VllmBackend:
self.prefix = prefix or model_tag
# Passes to run on the graph post-grad.
self.post_grad_pass_manager = PostGradPassManager()
self.pass_manager = resolve_obj_by_qualname(
current_platform.get_pass_manager_cls()
)()
self.pass_key = current_platform.pass_key
self.sym_tensor_indices = []
self.input_buffers = []
@@ -562,24 +566,20 @@ class VllmBackend:
def configure_post_pass(self):
config = self.compilation_config
self.post_grad_pass_manager.configure(self.vllm_config)
self.pass_manager.configure(self.vllm_config)
# Post-grad custom passes are run using the post_grad_custom_post_pass
# hook. If a pass for that hook exists, add it to the pass manager.
inductor_config = config.inductor_compile_config
PASS_KEY = "post_grad_custom_post_pass"
if PASS_KEY in inductor_config:
if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
if self.pass_key in inductor_config:
if isinstance(inductor_config[self.pass_key], PostGradPassManager):
# PassManager already added to config, make sure it's correct
assert (
inductor_config[PASS_KEY].uuid()
== self.post_grad_pass_manager.uuid()
)
assert inductor_config[self.pass_key].uuid() == self.pass_manager.uuid()
else:
# Config should automatically wrap all inductor passes
assert isinstance(inductor_config[PASS_KEY], InductorPass)
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
inductor_config[PASS_KEY] = self.post_grad_pass_manager
assert isinstance(inductor_config[self.pass_key], InductorPass)
self.pass_manager.add(inductor_config[self.pass_key])
inductor_config[self.pass_key] = self.pass_manager
def __call__(
self, graph: fx.GraphModule, example_inputs