fix(compile): apply partition wrapper when loading AOT cached functions (#31536)
Signed-off-by: Devbyteai <abud6673@gmail.com> Signed-off-by: DevByteAI <161969603+devbyteai@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -371,9 +371,12 @@ def _support_torch_compile(
|
||||
if self.do_not_compile or torch.compiler.is_compiling():
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
# if aot_compiled_fn is set, just call it.
|
||||
# if aot_compiled_fn is set, call it with partition wrapper context.
|
||||
# The partition wrapper must be active at runtime for CUDA graph
|
||||
# capture to work correctly with inductor graph partitioning.
|
||||
if getattr(self, "aot_compiled_fn", None) is not None:
|
||||
return self.aot_compiled_fn(self, *args, **kwargs)
|
||||
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
|
||||
return self.aot_compiled_fn(self, *args, **kwargs)
|
||||
|
||||
ds_type = self.compilation_config.dynamic_shapes_config.type
|
||||
cache_dir = None
|
||||
@@ -432,7 +435,9 @@ def _support_torch_compile(
|
||||
logger.info(
|
||||
"Directly load AOT compilation from path %s", aot_compilation_path
|
||||
)
|
||||
return self.aot_compiled_fn(self, *args, **kwargs)
|
||||
# Apply partition wrapper context for proper CUDA graph capture
|
||||
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
|
||||
return self.aot_compiled_fn(self, *args, **kwargs)
|
||||
|
||||
if self.compiled:
|
||||
assert (
|
||||
|
||||
Reference in New Issue
Block a user