[BugFix] fix VLLM_USE_STANDALONE_COMPILE=0 (#38015)
Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
@@ -441,6 +441,37 @@ def test_partition_wrapper_applied_on_aot_load(
|
||||
)
|
||||
|
||||
|
||||
@create_new_process_for_each_test("spawn")
def test_standalone_compile_correctness():
    """Outputs must match regardless of VLLM_USE_STANDALONE_COMPILE."""
    import json

    from ..utils import compare_two_settings

    # Serialize the compilation config the way the CLI flag expects it.
    cfg_json = json.dumps({"mode": CompilationMode.VLLM_COMPILE})

    # Same CLI arguments for both runs; only the env var below differs.
    args = [
        "--dtype",
        "float16",
        "--max-model-len",
        "256",
        "--compilation_config",
        cfg_json,
    ]

    # Launch the model twice, toggling only VLLM_USE_STANDALONE_COMPILE,
    # and assert the two runs produce matching outputs.
    compare_two_settings(
        "facebook/opt-125m",
        args,
        args,
        env1={"VLLM_USE_STANDALONE_COMPILE": "1"},
        env2={"VLLM_USE_STANDALONE_COMPILE": "0"},
    )
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
@@ -632,6 +632,23 @@ class InductorAdaptor(CompilerInterface):
|
||||
)
|
||||
stack.enter_context(_patch_constrain_to_fx_strides())
|
||||
|
||||
# Clear the tracing context before calling compile_fx.
|
||||
# vLLM calls compile_fx from within a PiecewiseCompileInterpreter
|
||||
# that runs under Dynamo's tracing context. The tracing context
|
||||
# has a FakeTensorMode from Dynamo, but the example inputs for
|
||||
# this subgraph have fake tensors from a different FakeTensorMode.
|
||||
# compile_fx's _compile_fx_main calls detect_fake_mode() which
|
||||
# asserts all FakeTensorModes match, causing a crash.
|
||||
# Clearing the tracing context lets compile_fx create its own.
|
||||
saved_tracing_context = torch._guards.TracingContext.try_get()
|
||||
if saved_tracing_context is not None:
|
||||
torch._guards._TLS.tracing_context = None
|
||||
|
||||
def _restore_tracing_context():
|
||||
torch._guards._TLS.tracing_context = saved_tracing_context
|
||||
|
||||
stack.callback(_restore_tracing_context)
|
||||
|
||||
compiled_graph = compile_fx(
|
||||
graph,
|
||||
example_inputs,
|
||||
|
||||
Reference in New Issue
Block a user