[torch.compile] Use FakeTensors instead of real GPU tensors for single-size compilation (#36093)

Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
Richard Zou
2026-03-11 12:07:09 -04:00
committed by GitHub
parent bea02cdf93
commit 822e250ab7
3 changed files with 137 additions and 25 deletions

View File

@@ -127,6 +127,88 @@ def test_compile_config_get_compile_ranges():
]
class PostGradStaticShapeChecker(InductorPass):
    """Asserts that compile_sizes entries produce graphs with fully concrete
    (non-symbolic) shapes, and compile_ranges entries have symbolic shapes.

    Counts how often each case was observed so the test can assert on the
    number of static vs. dynamic compilations afterwards.
    """

    def __init__(self):
        # Number of graphs seen for a single-size (fully static) range.
        self.num_static_calls = 0
        # Number of graphs seen for a multi-size range that contained at
        # least one symbolic tensor dimension.
        self.num_dynamic_calls = 0

    def __call__(self, graph: fx.Graph):
        from torch.fx.experimental.symbolic_shapes import is_symbolic

        compile_range = get_pass_context().compile_range
        single_size = compile_range.is_single_size()
        for node in graph.nodes:
            # Prefer "val"; fall back to "example_value" when absent.
            val = node.meta.get("val")
            if val is None:
                val = node.meta.get("example_value")
            if not isinstance(val, torch.Tensor):
                continue
            if not any(is_symbolic(d) for d in val.shape):
                continue
            # This tensor has at least one symbolic dimension.
            if single_size:
                raise AssertionError(
                    f"compile_sizes entry {compile_range}: "
                    f"node '{node.name}' has symbolic shape "
                    f"{val.shape}"
                )
            # compile_ranges should have at least some symbolic shapes
            # (the batch dimension); one hit is enough for this graph.
            self.num_dynamic_calls += 1
            return
        if single_size:
            self.num_static_calls += 1

    def uuid(self) -> str:
        # This pass carries no configuration, so hash an empty state dict.
        state: dict[str, Any] = {}
        return InductorPass.hash_dict(state)
def test_compile_sizes_produce_static_shapes(use_fresh_inductor_cache):
    """Verify that compile_sizes entries are compiled with fully concrete
    shapes (no SymInts), while compile_ranges entries retain dynamic shapes."""
    checker = PostGradStaticShapeChecker()
    torch.set_default_device("cuda")

    # Route every post-grad graph through the shape checker so that it can
    # classify each compilation as static or dynamic.
    compilation_config = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        compile_ranges_endpoints=[8],
        compile_sizes=[16],
        inductor_compile_config={"post_grad_custom_post_pass": checker},
    )
    scheduler_config = SchedulerConfig(
        max_num_batched_tokens=8192,
        max_model_len=8192,
        is_encoder_decoder=False,
    )
    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        compilation_config=compilation_config,
    )

    with set_current_vllm_config(vllm_config):
        model = TestModel(vllm_config=vllm_config, prefix="").eval()
        # 3 compilations: Range(1,8), Range(9,8192), single-size 16
        with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=1,
            num_backend_compilations=3,
        ):
            run_model(vllm_config, model, [1, 16, 64])
        # compile_sizes=16 should produce static shapes
        assert checker.num_static_calls == 1, (
            f"Expected 1 static compilation, got {checker.num_static_calls}"
        )
        # compile_ranges should produce dynamic shapes
        assert checker.num_dynamic_calls == 2, (
            f"Expected 2 dynamic compilations, got {checker.num_dynamic_calls}"
        )
def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
# To force multiple compilations, we disable the compile cache
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

View File

@@ -348,13 +348,39 @@ class InductorStandaloneAdaptor(CompilerInterface):
# Can remove this after the following issue gets fixed
# https://github.com/pytorch/pytorch/issues/174502
if envs.VLLM_ENABLE_PREGRAD_PASSES:
ctx: Any = contextlib.nullcontext()
pregrad_ctx: Any = contextlib.nullcontext()
else:
ctx = patch(
pregrad_ctx = patch(
"torch._inductor.compile_fx._recursive_pre_grad_passes",
lambda gm, _: gm,
)
with ctx, _patch_constrain_to_fx_strides():
# When inputs are FakeTensors (from create_concrete_args),
# standalone_compile("from_example_inputs") would normally create
# a fresh FakeTensorMode, causing a mode mismatch assertion.
# Patch FakeTensorMode in standalone_compile so it reuses the
# mode already attached to our FakeTensors. This gives us both
# ignore_shape_env=True (from "from_example_inputs") and mode
# consistency (from reusing our mode).
# Can remove this after the following issue gets fixed:
# https://github.com/pytorch/pytorch/issues/176562
from torch._subclasses.fake_tensor import FakeTensor
input_fake_mode = None
for x in example_inputs:
if isinstance(x, FakeTensor):
input_fake_mode = x.fake_mode
break
if input_fake_mode is not None:
fake_mode_ctx: Any = patch(
"torch._inductor.standalone_compile.FakeTensorMode",
lambda *a, **kw: input_fake_mode,
)
else:
fake_mode_ctx = contextlib.nullcontext()
with pregrad_ctx, fake_mode_ctx, _patch_constrain_to_fx_strides():
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
if use_aot:

View File

@@ -34,13 +34,14 @@ def get_fake_args_from_graph(graph: fx.GraphModule) -> list[Any]:
def create_concrete_args(graph: fx.GraphModule, size: int) -> list[Any]:
"""Create example inputs with symbolic dims replaced by a concrete size.
"""Create Fake example inputs with symbolic dims replaced by a concrete size.
Used for single-size eager compilation where we need concrete-shaped
inputs but don't have real runtime tensors yet.
Used for single-size compilation where we need concrete-shaped inputs.
The Dynamo-captured graph gives us example inputs with SymInts in them.
"""
from torch._prims_common import compute_required_storage_length
from torch.fx.experimental.symbolic_shapes import is_symbolic
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv, is_symbolic
def concretize(sym_val: Any) -> int:
"""Replace all symbolic variables in a SymInt expression with size."""
@@ -49,25 +50,28 @@ def create_concrete_args(graph: fx.GraphModule, size: int) -> list[Any]:
expr = sym_val.node.expr
return int(expr.subs({s: size for s in expr.free_symbols}))
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
args: list[Any] = []
for node in graph.graph.nodes:
if node.op != "placeholder":
break
val = node.meta["example_value"]
if isinstance(val, torch.SymInt):
args.append(concretize(val))
elif isinstance(val, torch.Tensor):
new_shape = tuple(concretize(d) for d in val.shape)
new_strides = tuple(concretize(s) for s in val.stride())
new_storage_offset = concretize(val.storage_offset())
needed_size = compute_required_storage_length(
new_shape, new_strides, new_storage_offset
)
t = torch.empty(needed_size, dtype=val.dtype, device=val.device)
t = t.as_strided(new_shape, new_strides, new_storage_offset)
args.append(t)
else:
args.append(val)
with fake_mode:
for node in graph.graph.nodes:
if node.op != "placeholder":
break
val = node.meta["example_value"]
if isinstance(val, torch.SymInt):
args.append(concretize(val))
elif isinstance(val, torch.Tensor):
new_shape = tuple(concretize(d) for d in val.shape)
new_strides = tuple(concretize(s) for s in val.stride())
new_storage_offset = concretize(val.storage_offset())
needed_size = compute_required_storage_length(
new_shape, new_strides, new_storage_offset
)
t = torch.empty(needed_size, dtype=val.dtype, device=val.device)
t = t.as_strided(new_shape, new_strides, new_storage_offset)
args.append(t)
else:
args.append(val)
return args