[torch.compile] Use FakeTensors instead of real GPU tensors for single-size compilation (#36093)
Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
@@ -127,6 +127,88 @@ def test_compile_config_get_compile_ranges():
|
||||
]
|
||||
|
||||
|
||||
class PostGradStaticShapeChecker(InductorPass):
    """Asserts that compile_sizes entries produce graphs with fully concrete
    (non-symbolic) shapes, and compile_ranges entries have symbolic shapes."""

    def __init__(self):
        # Tallies of how many compiled graphs fell into each category.
        self.num_static_calls = 0
        self.num_dynamic_calls = 0

    def __call__(self, graph: fx.Graph):
        from torch.fx.experimental.symbolic_shapes import is_symbolic

        compile_range = get_pass_context().compile_range
        is_single = compile_range.is_single_size()

        for node in graph.nodes:
            # Prefer "val"; fall back to Dynamo's "example_value" metadata.
            meta_val = node.meta.get("val")
            if meta_val is None:
                meta_val = node.meta.get("example_value")
            if not isinstance(meta_val, torch.Tensor):
                continue
            has_symbolic = any(is_symbolic(d) for d in meta_val.shape)
            if is_single:
                # A single compile size must never leave a SymInt behind.
                assert not has_symbolic, (
                    f"compile_sizes entry {compile_range}: "
                    f"node '{node.name}' has symbolic shape "
                    f"{meta_val.shape}"
                )
            elif has_symbolic:
                # compile_ranges should have at least some
                # symbolic shapes (the batch dimension)
                self.num_dynamic_calls += 1
                return

        if is_single:
            self.num_static_calls += 1

    def uuid(self) -> str:
        # No configuration state distinguishes instances of this pass.
        state: dict[str, Any] = {}
        return InductorPass.hash_dict(state)
|
||||
|
||||
|
||||
def test_compile_sizes_produce_static_shapes(use_fresh_inductor_cache):
    """Verify that compile_sizes entries are compiled with fully concrete
    shapes (no SymInts), while compile_ranges entries retain dynamic shapes."""
    torch.set_default_device("cuda")
    checker = PostGradStaticShapeChecker()

    scheduler_config = SchedulerConfig(
        max_num_batched_tokens=8192,
        max_model_len=8192,
        is_encoder_decoder=False,
    )
    compilation_config = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        compile_ranges_endpoints=[8],
        compile_sizes=[16],
        # Run the shape checker on every post-grad graph Inductor produces.
        inductor_compile_config={"post_grad_custom_post_pass": checker},
    )
    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        compilation_config=compilation_config,
    )

    with set_current_vllm_config(vllm_config):
        model = TestModel(vllm_config=vllm_config, prefix="").eval()
        # 3 compilations: Range(1,8), Range(9,8192), single-size 16
        with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=1,
            num_backend_compilations=3,
        ):
            run_model(vllm_config, model, [1, 16, 64])

    # compile_sizes=16 should produce static shapes
    assert checker.num_static_calls == 1, (
        f"Expected 1 static compilation, got {checker.num_static_calls}"
    )
    # compile_ranges should produce dynamic shapes
    assert checker.num_dynamic_calls == 2, (
        f"Expected 2 dynamic compilations, got {checker.num_dynamic_calls}"
    )
||||
|
||||
def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
|
||||
# To force multiple compilations, we disable the compile cache
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
|
||||
|
||||
@@ -348,13 +348,39 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
# Can remove this after the following issue gets fixed
|
||||
# https://github.com/pytorch/pytorch/issues/174502
|
||||
if envs.VLLM_ENABLE_PREGRAD_PASSES:
|
||||
ctx: Any = contextlib.nullcontext()
|
||||
pregrad_ctx: Any = contextlib.nullcontext()
|
||||
else:
|
||||
ctx = patch(
|
||||
pregrad_ctx = patch(
|
||||
"torch._inductor.compile_fx._recursive_pre_grad_passes",
|
||||
lambda gm, _: gm,
|
||||
)
|
||||
with ctx, _patch_constrain_to_fx_strides():
|
||||
|
||||
# When inputs are FakeTensors (from create_concrete_args),
|
||||
# standalone_compile("from_example_inputs") would normally create
|
||||
# a fresh FakeTensorMode, causing a mode mismatch assertion.
|
||||
# Patch FakeTensorMode in standalone_compile so it reuses the
|
||||
# mode already attached to our FakeTensors. This gives us both
|
||||
# ignore_shape_env=True (from "from_example_inputs") and mode
|
||||
# consistency (from reusing our mode).
|
||||
# Can remove this after the following issue gets fixed:
|
||||
# https://github.com/pytorch/pytorch/issues/176562
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
|
||||
input_fake_mode = None
|
||||
for x in example_inputs:
|
||||
if isinstance(x, FakeTensor):
|
||||
input_fake_mode = x.fake_mode
|
||||
break
|
||||
|
||||
if input_fake_mode is not None:
|
||||
fake_mode_ctx: Any = patch(
|
||||
"torch._inductor.standalone_compile.FakeTensorMode",
|
||||
lambda *a, **kw: input_fake_mode,
|
||||
)
|
||||
else:
|
||||
fake_mode_ctx = contextlib.nullcontext()
|
||||
|
||||
with pregrad_ctx, fake_mode_ctx, _patch_constrain_to_fx_strides():
|
||||
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
|
||||
|
||||
if use_aot:
|
||||
|
||||
@@ -34,13 +34,14 @@ def get_fake_args_from_graph(graph: fx.GraphModule) -> list[Any]:
|
||||
|
||||
|
||||
def create_concrete_args(graph: fx.GraphModule, size: int) -> list[Any]:
|
||||
"""Create example inputs with symbolic dims replaced by a concrete size.
|
||||
"""Create Fake example inputs with symbolic dims replaced by a concrete size.
|
||||
|
||||
Used for single-size eager compilation where we need concrete-shaped
|
||||
inputs but don't have real runtime tensors yet.
|
||||
Used for single-size compilation where we need concrete-shaped inputs.
|
||||
The Dynamo-captured graph gives us example inputs with SymInts in them.
|
||||
"""
|
||||
from torch._prims_common import compute_required_storage_length
|
||||
from torch.fx.experimental.symbolic_shapes import is_symbolic
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv, is_symbolic
|
||||
|
||||
def concretize(sym_val: Any) -> int:
|
||||
"""Replace all symbolic variables in a SymInt expression with size."""
|
||||
@@ -49,25 +50,28 @@ def create_concrete_args(graph: fx.GraphModule, size: int) -> list[Any]:
|
||||
expr = sym_val.node.expr
|
||||
return int(expr.subs({s: size for s in expr.free_symbols}))
|
||||
|
||||
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
|
||||
|
||||
args: list[Any] = []
|
||||
for node in graph.graph.nodes:
|
||||
if node.op != "placeholder":
|
||||
break
|
||||
val = node.meta["example_value"]
|
||||
if isinstance(val, torch.SymInt):
|
||||
args.append(concretize(val))
|
||||
elif isinstance(val, torch.Tensor):
|
||||
new_shape = tuple(concretize(d) for d in val.shape)
|
||||
new_strides = tuple(concretize(s) for s in val.stride())
|
||||
new_storage_offset = concretize(val.storage_offset())
|
||||
needed_size = compute_required_storage_length(
|
||||
new_shape, new_strides, new_storage_offset
|
||||
)
|
||||
t = torch.empty(needed_size, dtype=val.dtype, device=val.device)
|
||||
t = t.as_strided(new_shape, new_strides, new_storage_offset)
|
||||
args.append(t)
|
||||
else:
|
||||
args.append(val)
|
||||
with fake_mode:
|
||||
for node in graph.graph.nodes:
|
||||
if node.op != "placeholder":
|
||||
break
|
||||
val = node.meta["example_value"]
|
||||
if isinstance(val, torch.SymInt):
|
||||
args.append(concretize(val))
|
||||
elif isinstance(val, torch.Tensor):
|
||||
new_shape = tuple(concretize(d) for d in val.shape)
|
||||
new_strides = tuple(concretize(s) for s in val.stride())
|
||||
new_storage_offset = concretize(val.storage_offset())
|
||||
needed_size = compute_required_storage_length(
|
||||
new_shape, new_strides, new_storage_offset
|
||||
)
|
||||
t = torch.empty(needed_size, dtype=val.dtype, device=val.device)
|
||||
t = t.as_strided(new_shape, new_strides, new_storage_offset)
|
||||
args.append(t)
|
||||
else:
|
||||
args.append(val)
|
||||
return args
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user