Allow oot custom compiler extension via CompilerInterface (#28623)
Signed-off-by: wxsIcey <1790571317@qq.com> Signed-off-by: Mengqing Cao <cmq0113@163.com> Signed-off-by: Icey <1790571317@qq.com> Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
@@ -63,13 +63,14 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
|
||||
else:
|
||||
logger.debug("Using InductorAdaptor")
|
||||
return InductorAdaptor()
|
||||
else:
|
||||
assert compilation_config.backend == "eager", (
|
||||
"Custom backends not supported with CompilationMode.VLLM_COMPILE"
|
||||
)
|
||||
|
||||
elif compilation_config.backend == "eager":
|
||||
logger.debug("Using EagerAdaptor")
|
||||
return EagerAdaptor()
|
||||
else:
|
||||
logger.debug("Using custom backend: %s", compilation_config.backend)
|
||||
compiler = resolve_obj_by_qualname(current_platform.get_compile_backend())()
|
||||
assert isinstance(compiler, CompilerInterface)
|
||||
return compiler
|
||||
|
||||
|
||||
class CompilerManager:
|
||||
@@ -545,7 +546,10 @@ class VllmBackend:
|
||||
self.prefix = prefix or model_tag
|
||||
|
||||
# Passes to run on the graph post-grad.
|
||||
self.post_grad_pass_manager = PostGradPassManager()
|
||||
self.pass_manager = resolve_obj_by_qualname(
|
||||
current_platform.get_pass_manager_cls()
|
||||
)()
|
||||
self.pass_key = current_platform.pass_key
|
||||
|
||||
self.sym_tensor_indices = []
|
||||
self.input_buffers = []
|
||||
@@ -562,24 +566,20 @@ class VllmBackend:
|
||||
|
||||
def configure_post_pass(self):
|
||||
config = self.compilation_config
|
||||
self.post_grad_pass_manager.configure(self.vllm_config)
|
||||
self.pass_manager.configure(self.vllm_config)
|
||||
|
||||
# Post-grad custom passes are run using the post_grad_custom_post_pass
|
||||
# hook. If a pass for that hook exists, add it to the pass manager.
|
||||
inductor_config = config.inductor_compile_config
|
||||
PASS_KEY = "post_grad_custom_post_pass"
|
||||
if PASS_KEY in inductor_config:
|
||||
if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
|
||||
if self.pass_key in inductor_config:
|
||||
if isinstance(inductor_config[self.pass_key], PostGradPassManager):
|
||||
# PassManager already added to config, make sure it's correct
|
||||
assert (
|
||||
inductor_config[PASS_KEY].uuid()
|
||||
== self.post_grad_pass_manager.uuid()
|
||||
)
|
||||
assert inductor_config[self.pass_key].uuid() == self.pass_manager.uuid()
|
||||
else:
|
||||
# Config should automatically wrap all inductor passes
|
||||
assert isinstance(inductor_config[PASS_KEY], InductorPass)
|
||||
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
|
||||
inductor_config[PASS_KEY] = self.post_grad_pass_manager
|
||||
assert isinstance(inductor_config[self.pass_key], InductorPass)
|
||||
self.pass_manager.add(inductor_config[self.pass_key])
|
||||
inductor_config[self.pass_key] = self.pass_manager
|
||||
|
||||
def __call__(
|
||||
self, graph: fx.GraphModule, example_inputs
|
||||
|
||||
Reference in New Issue
Block a user