[torch.compile] Use FakeTensors instead of real GPU tensors for single-size compilation (#36093)
Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
@@ -348,13 +348,39 @@ class InductorStandaloneAdaptor(CompilerInterface):
         # Can remove this after the following issue gets fixed
         # https://github.com/pytorch/pytorch/issues/174502
         if envs.VLLM_ENABLE_PREGRAD_PASSES:
-            ctx: Any = contextlib.nullcontext()
+            pregrad_ctx: Any = contextlib.nullcontext()
         else:
-            ctx = patch(
+            pregrad_ctx = patch(
                 "torch._inductor.compile_fx._recursive_pre_grad_passes",
                 lambda gm, _: gm,
             )
-        with ctx, _patch_constrain_to_fx_strides():
+
+        # When inputs are FakeTensors (from create_concrete_args),
+        # standalone_compile("from_example_inputs") would normally create
+        # a fresh FakeTensorMode, causing a mode mismatch assertion.
+        # Patch FakeTensorMode in standalone_compile so it reuses the
+        # mode already attached to our FakeTensors. This gives us both
+        # ignore_shape_env=True (from "from_example_inputs") and mode
+        # consistency (from reusing our mode).
+        # Can remove this after the following issue gets fixed:
+        # https://github.com/pytorch/pytorch/issues/176562
+        from torch._subclasses.fake_tensor import FakeTensor
+
+        input_fake_mode = None
+        for x in example_inputs:
+            if isinstance(x, FakeTensor):
+                input_fake_mode = x.fake_mode
+                break
+
+        if input_fake_mode is not None:
+            fake_mode_ctx: Any = patch(
+                "torch._inductor.standalone_compile.FakeTensorMode",
+                lambda *a, **kw: input_fake_mode,
+            )
+        else:
+            fake_mode_ctx = contextlib.nullcontext()
+
+        with pregrad_ctx, fake_mode_ctx, _patch_constrain_to_fx_strides():
             compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)

         if use_aot:
Reference in New Issue
Block a user