fix(compile): apply partition wrapper when loading AOT cached functions (#31536)
Signed-off-by: Devbyteai <abud6673@gmail.com> Signed-off-by: DevByteAI <161969603+devbyteai@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -5,6 +5,7 @@ import functools
|
||||
import multiprocessing
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -24,6 +25,13 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
from ..utils import create_new_process_for_each_test
|
||||
|
||||
|
||||
@pytest.fixture
def vllm_tmp_cache(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
    """Redirect VLLM_CACHE_ROOT into a per-test temporary directory.

    Returns the pytest-provided ``tmp_path`` so the test can inspect
    whatever gets written under the redirected cache root.
    """
    cache_root = tmp_path / "vllm_cache"
    monkeypatch.setenv("VLLM_CACHE_ROOT", str(cache_root))
    return tmp_path
|
||||
|
||||
|
||||
def reference_fn(x: torch.Tensor):
|
||||
assert x.shape[0] <= 42
|
||||
assert x.shape[0] % 2 == 0
|
||||
@@ -148,6 +156,93 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
|
||||
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_partition_wrapper_applied_on_aot_load(
    monkeypatch: pytest.MonkeyPatch, vllm_tmp_cache: Path, mocker
):
    """
    Test that partition wrappers are applied when loading AOT cached functions.

    This test verifies the fix for GitHub issue #31439 where AOT compile
    caused 2x latency regression when use_inductor_graph_partition=True.
    The root cause was that partition wrapper context was bypassed when
    loading from AOT cache.
    """
    from vllm.config import CUDAGraphMode

    args = (torch.randn(10, 10),)
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")

    # Create config with partition enabled
    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            use_inductor_graph_partition=True,
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        )
    )

    # First compilation - save to cache
    with use_vllm_config(vllm_config):
        compiled_mod = CompiledMod(vllm_config=vllm_config)
        compiled_mod(*args)
    # NOTE(review): presumably clears cached env-var state so the second
    # VllmConfig below re-reads VLLM_FORCE_AOT_LOAD — confirm against helper.
    disable_envs_cache()

    # Second run - load from cache, verify partition wrapper applied
    monkeypatch.setenv("VLLM_FORCE_AOT_LOAD", "1")
    # A fresh config simulates a process restart with the same settings.
    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            use_inductor_graph_partition=True,
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        )
    )

    # Use mocker to spy on set_customized_partition_wrappers
    spy = mocker.spy(torch._inductor.utils, "set_customized_partition_wrappers")

    with use_vllm_config(vllm_config):
        compiled_mod = CompiledMod(vllm_config=vllm_config)

        # First call after restart: loads from AOT cache.
        # This tests the fix for the first call after a restart.
        compiled_mod(*args)

        # Verify partition wrapper was called on AOT load.
        # Expect at least two calls: one setting the wrapper, one clearing it
        # when the context manager exits.
        assert spy.call_count >= 2, (
            "Expected partition wrapper to be set and cleared on AOT load, "
            f"got {spy.call_count} calls"
        )
        # First call should set a wrapper, last call should clear it
        assert spy.call_args_list[0][0][0] is not None, (
            "First call on AOT load should set a wrapper function"
        )
        assert spy.call_args_list[-1][0][0] is None, (
            "Last call on AOT load should clear the wrapper"
        )

        # Reset for the next check.
        spy.reset_mock()

        # Subsequent call: uses the cached `aot_compiled_fn`.
        # This tests the fix for subsequent calls.
        compiled_mod(*args)

        # Verify partition wrapper was called on the subsequent call.
        assert spy.call_count >= 2, (
            "Expected partition wrapper set and cleared on subsequent "
            f"call, got {spy.call_count} calls"
        )
        assert spy.call_args_list[0][0][0] is not None, (
            "First call on subsequent call should set a wrapper function"
        )
        assert spy.call_args_list[-1][0][0] is None, (
            "Last call on subsequent call should clear the wrapper"
        )
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
|
||||
)
|
||||
|
||||
@@ -371,9 +371,12 @@ def _support_torch_compile(
|
||||
if self.do_not_compile or torch.compiler.is_compiling():
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
# if aot_compiled_fn is set, just call it.
|
||||
# if aot_compiled_fn is set, call it with partition wrapper context.
|
||||
# The partition wrapper must be active at runtime for CUDA graph
|
||||
# capture to work correctly with inductor graph partitioning.
|
||||
if getattr(self, "aot_compiled_fn", None) is not None:
|
||||
return self.aot_compiled_fn(self, *args, **kwargs)
|
||||
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
|
||||
return self.aot_compiled_fn(self, *args, **kwargs)
|
||||
|
||||
ds_type = self.compilation_config.dynamic_shapes_config.type
|
||||
cache_dir = None
|
||||
@@ -432,7 +435,9 @@ def _support_torch_compile(
|
||||
logger.info(
|
||||
"Directly load AOT compilation from path %s", aot_compilation_path
|
||||
)
|
||||
return self.aot_compiled_fn(self, *args, **kwargs)
|
||||
# Apply partition wrapper context for proper CUDA graph capture
|
||||
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
|
||||
return self.aot_compiled_fn(self, *args, **kwargs)
|
||||
|
||||
if self.compiled:
|
||||
assert (
|
||||
|
||||
Reference in New Issue
Block a user