[torch.compile] Use FakeTensors instead of real GPU tensors for single-size compilation (#36093)

Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
Richard Zou
2026-03-11 12:07:09 -04:00
committed by GitHub
parent bea02cdf93
commit 822e250ab7
3 changed files with 137 additions and 25 deletions

View File

@@ -127,6 +127,88 @@ def test_compile_config_get_compile_ranges():
]
class PostGradStaticShapeChecker(InductorPass):
    """Asserts that compile_sizes entries produce graphs with fully concrete
    (non-symbolic) shapes, and compile_ranges entries have symbolic shapes."""

    def __init__(self):
        # Tallies of how many pass invocations were fully static vs. dynamic.
        self.num_static_calls = 0
        self.num_dynamic_calls = 0

    def __call__(self, graph: fx.Graph):
        from torch.fx.experimental.symbolic_shapes import is_symbolic

        compile_range = get_pass_context().compile_range
        # A single-size entry (from compile_sizes) must be fully concrete;
        # a range entry (from compile_ranges) should stay symbolic.
        is_single = compile_range.is_single_size()
        for node in graph.nodes:
            # Prefer "val"; fall back to "example_value" when absent.
            val = node.meta.get("val")
            if val is None:
                val = node.meta.get("example_value")
            if not isinstance(val, torch.Tensor):
                continue
            symbolic_found = any(is_symbolic(d) for d in val.shape)
            if is_single:
                assert not symbolic_found, (
                    f"compile_sizes entry {compile_range}: "
                    f"node '{node.name}' has symbolic shape "
                    f"{val.shape}"
                )
            elif symbolic_found:
                # compile_ranges should have at least some
                # symbolic shapes (the batch dimension)
                self.num_dynamic_calls += 1
                return
        if is_single:
            self.num_static_calls += 1

    def uuid(self) -> str:
        # Empty state: this pass has no configuration that affects caching.
        state: dict[str, Any] = {}
        return InductorPass.hash_dict(state)
def test_compile_sizes_produce_static_shapes(use_fresh_inductor_cache):
    """Verify that compile_sizes entries are compiled with fully concrete
    shapes (no SymInts), while compile_ranges entries retain dynamic shapes."""
    checker = PostGradStaticShapeChecker()
    torch.set_default_device("cuda")

    # Build the config in named pieces so each knob is easy to spot.
    scheduler_cfg = SchedulerConfig(
        max_num_batched_tokens=8192,
        max_model_len=8192,
        is_encoder_decoder=False,
    )
    compilation_cfg = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        compile_ranges_endpoints=[8],
        compile_sizes=[16],
        # Hook our shape checker into the post-grad pass pipeline.
        inductor_compile_config={
            "post_grad_custom_post_pass": checker,
        },
    )
    vllm_config = VllmConfig(
        scheduler_config=scheduler_cfg,
        compilation_config=compilation_cfg,
    )

    with set_current_vllm_config(vllm_config):
        model = TestModel(vllm_config=vllm_config, prefix="").eval()
        # 3 compilations: Range(1,8), Range(9,8192), single-size 16
        with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=1,
            num_backend_compilations=3,
        ):
            run_model(vllm_config, model, [1, 16, 64])
        # compile_sizes=16 should produce static shapes
        assert checker.num_static_calls == 1, (
            f"Expected 1 static compilation, got {checker.num_static_calls}"
        )
        # compile_ranges should produce dynamic shapes
        assert checker.num_dynamic_calls == 2, (
            f"Expected 2 dynamic compilations, got {checker.num_dynamic_calls}"
        )
def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
# To force multiple compilations, we disable the compile cache
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")