[compile] Fix CI for test_gpt2_cache_hit (#30902)
Signed-off-by: zhxchen17 <zhxchen17@fb.com>
This commit is contained in:
@@ -9,6 +9,7 @@ from contextlib import contextmanager
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.activation
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
@@ -16,9 +17,12 @@ from vllm.config import (
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.envs import disable_envs_cache
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from ..utils import create_new_process_for_each_test
|
||||
|
||||
|
||||
def reference_fn(x: torch.Tensor):
|
||||
assert x.shape[0] <= 42
|
||||
@@ -66,6 +70,7 @@ def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
|
||||
torch.compiler.set_stance("fail_on_recompile"),
|
||||
):
|
||||
CompiledMod(vllm_config=vllm_config)(*args)
|
||||
disable_envs_cache()
|
||||
|
||||
m.setenv("VLLM_USE_AOT_COMPILE", "1")
|
||||
torch._dynamo.reset()
|
||||
@@ -101,6 +106,7 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
|
||||
vllm_config = make_vllm_config()
|
||||
with use_vllm_config(vllm_config):
|
||||
expected = CompiledMod(vllm_config=vllm_config)(*args)
|
||||
disable_envs_cache()
|
||||
|
||||
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
||||
vllm_config = make_vllm_config()
|
||||
@@ -130,6 +136,7 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
|
||||
artifacts = compiled_mod.aot_compiled_fn._artifacts
|
||||
guards_string = artifacts.compiled_fn.shape_env.format_guards()
|
||||
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
|
||||
disable_envs_cache()
|
||||
|
||||
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
||||
vllm_config = make_vllm_config()
|
||||
@@ -144,7 +151,7 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
|
||||
@pytest.mark.skipif(
|
||||
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
|
||||
)
|
||||
@use_vllm_config(make_vllm_config())
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Test that compiling gpt2 twice results in a cache hit and
|
||||
@@ -186,6 +193,8 @@ def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
# Clean up first model
|
||||
del llm_model
|
||||
disable_envs_cache()
|
||||
vllm.model_executor.layers.activation._ACTIVATION_REGISTRY._dict.clear()
|
||||
|
||||
# Second compilation - should hit cache
|
||||
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
||||
|
||||
Reference in New Issue
Block a user