fix(compile): apply partition wrapper when loading AOT cached functions (#31536)

Signed-off-by: Devbyteai <abud6673@gmail.com>
Signed-off-by: DevByteAI <161969603+devbyteai@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
DevByteAI
2026-01-08 11:27:26 +02:00
committed by GitHub
parent 8cbdc7eb94
commit 1f214290d6
2 changed files with 103 additions and 3 deletions

View File

@@ -371,9 +371,12 @@ def _support_torch_compile(
if self.do_not_compile or torch.compiler.is_compiling():
return self.forward(*args, **kwargs)
# if aot_compiled_fn is set, just call it.
# If aot_compiled_fn is set, call it inside the partition wrapper context.
# The partition wrapper must be active at runtime for CUDA graph
# capture to work correctly with inductor graph partitioning.
if getattr(self, "aot_compiled_fn", None) is not None:
return self.aot_compiled_fn(self, *args, **kwargs)
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
return self.aot_compiled_fn(self, *args, **kwargs)
ds_type = self.compilation_config.dynamic_shapes_config.type
cache_dir = None
@@ -432,7 +435,9 @@ def _support_torch_compile(
logger.info(
"Directly load AOT compilation from path %s", aot_compilation_path
)
return self.aot_compiled_fn(self, *args, **kwargs)
# Apply the partition wrapper context so that CUDA graph capture works correctly.
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
return self.aot_compiled_fn(self, *args, **kwargs)
if self.compiled:
assert (