[BugFix] fix VLLM_USE_STANDALONE_COMPILE=0 (#38015)
Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
@@ -632,6 +632,23 @@ class InductorAdaptor(CompilerInterface):
|
||||
)
|
||||
stack.enter_context(_patch_constrain_to_fx_strides())
|
||||
|
||||
# Clear the tracing context before calling compile_fx.
|
||||
# vLLM calls compile_fx from within a PiecewiseCompileInterpreter
|
||||
# that runs under Dynamo's tracing context. The tracing context
|
||||
# has a FakeTensorMode from Dynamo, but the example inputs for
|
||||
# this subgraph have fake tensors from a different FakeTensorMode.
|
||||
# compile_fx's _compile_fx_main calls detect_fake_mode() which
|
||||
# asserts all FakeTensorModes match, causing a crash.
|
||||
# Clearing the tracing context lets compile_fx create its own.
|
||||
saved_tracing_context = torch._guards.TracingContext.try_get()
|
||||
if saved_tracing_context is not None:
|
||||
torch._guards._TLS.tracing_context = None
|
||||
|
||||
def _restore_tracing_context():
|
||||
torch._guards._TLS.tracing_context = saved_tracing_context
|
||||
|
||||
stack.callback(_restore_tracing_context)
|
||||
|
||||
compiled_graph = compile_fx(
|
||||
graph,
|
||||
example_inputs,
|
||||
|
||||
Reference in New Issue
Block a user