fix(compile): apply partition wrapper when loading AOT cached functions (#31536)

Signed-off-by: Devbyteai <abud6673@gmail.com>
Signed-off-by: DevByteAI <161969603+devbyteai@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
DevByteAI
2026-01-08 11:27:26 +02:00
committed by GitHub
parent 8cbdc7eb94
commit 1f214290d6
2 changed files with 103 additions and 3 deletions

View File

@@ -5,6 +5,7 @@ import functools
import multiprocessing
import tempfile
from contextlib import contextmanager
from pathlib import Path
import pytest
import torch
@@ -24,6 +25,13 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..utils import create_new_process_for_each_test
@pytest.fixture
def vllm_tmp_cache(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
    """Redirect VLLM_CACHE_ROOT to a per-test temporary directory.

    Returns the temporary base path so the test can inspect what was
    written under it.
    """
    cache_root = tmp_path / "vllm_cache"
    monkeypatch.setenv("VLLM_CACHE_ROOT", str(cache_root))
    return tmp_path
def reference_fn(x: torch.Tensor):
assert x.shape[0] <= 42
assert x.shape[0] % 2 == 0
@@ -148,6 +156,93 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_partition_wrapper_applied_on_aot_load(
    monkeypatch: pytest.MonkeyPatch, vllm_tmp_cache: Path, mocker
):
    """
    Test that partition wrappers are applied when loading AOT cached functions.

    This test verifies the fix for GitHub issue #31439 where AOT compile
    caused 2x latency regression when use_inductor_graph_partition=True.
    The root cause was that partition wrapper context was bypassed when
    loading from AOT cache.
    """
    from vllm.config import CUDAGraphMode

    def _partition_config() -> VllmConfig:
        # Build a fresh config for each phase so the cache-load phase does
        # not reuse any state attached to the first phase's config object.
        return VllmConfig(
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                use_inductor_graph_partition=True,
                cudagraph_mode=CUDAGraphMode.PIECEWISE,
            )
        )

    def _assert_wrapper_set_and_cleared(spy, phase: str) -> None:
        # The partition-wrapper context should call
        # set_customized_partition_wrappers at least twice per run: once to
        # install a wrapper on entry and once with None to clear it on exit.
        assert spy.call_count >= 2, (
            f"Expected partition wrapper to be set and cleared on {phase}, "
            f"got {spy.call_count} calls"
        )
        assert spy.call_args_list[0][0][0] is not None, (
            f"First call on {phase} should set a wrapper function"
        )
        assert spy.call_args_list[-1][0][0] is None, (
            f"Last call on {phase} should clear the wrapper"
        )

    args = (torch.randn(10, 10),)
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")

    # Phase 1: first compilation with partition enabled - save to cache.
    vllm_config = _partition_config()
    with use_vllm_config(vllm_config):
        compiled_mod = CompiledMod(vllm_config=vllm_config)
        compiled_mod(*args)

    disable_envs_cache()

    # Phase 2: force a load from the AOT cache and spy on the wrapper setter
    # to verify the partition wrapper context is applied.
    monkeypatch.setenv("VLLM_FORCE_AOT_LOAD", "1")
    vllm_config = _partition_config()
    spy = mocker.spy(torch._inductor.utils, "set_customized_partition_wrappers")
    with use_vllm_config(vllm_config):
        compiled_mod = CompiledMod(vllm_config=vllm_config)

        # First call after restart: loads from AOT cache.
        # This tests the fix for the first call after a restart.
        compiled_mod(*args)
        _assert_wrapper_set_and_cleared(spy, "AOT load")

        # Reset, then exercise the subsequent-call path, which uses the
        # cached `aot_compiled_fn` directly.
        spy.reset_mock()
        compiled_mod(*args)
        _assert_wrapper_set_and_cleared(spy, "subsequent call")
@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)

View File

@@ -371,9 +371,12 @@ def _support_torch_compile(
if self.do_not_compile or torch.compiler.is_compiling():
return self.forward(*args, **kwargs)
# if aot_compiled_fn is set, just call it.
# if aot_compiled_fn is set, call it with partition wrapper context.
# The partition wrapper must be active at runtime for CUDA graph
# capture to work correctly with inductor graph partitioning.
if getattr(self, "aot_compiled_fn", None) is not None:
return self.aot_compiled_fn(self, *args, **kwargs)
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
return self.aot_compiled_fn(self, *args, **kwargs)
ds_type = self.compilation_config.dynamic_shapes_config.type
cache_dir = None
@@ -432,7 +435,9 @@ def _support_torch_compile(
logger.info(
"Directly load AOT compilation from path %s", aot_compilation_path
)
return self.aot_compiled_fn(self, *args, **kwargs)
# Apply partition wrapper context for proper CUDA graph capture
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
return self.aot_compiled_fn(self, *args, **kwargs)
if self.compiled:
assert (