[torch.compile] Use FakeTensors instead of real GPU tensors for single-size compilation (#36093)
Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
@@ -127,6 +127,88 @@ def test_compile_config_get_compile_ranges():
|
||||
]
|
||||
|
||||
|
||||
class PostGradStaticShapeChecker(InductorPass):
    """Asserts that compile_sizes entries produce graphs with fully concrete
    (non-symbolic) shapes, and compile_ranges entries have symbolic shapes.

    Counts are accumulated across invocations so the test can verify how
    many static vs. dynamic compilations actually happened.
    """

    def __init__(self):
        # Number of single-size (fully static) compilations observed.
        self.num_static_calls = 0
        # Number of range (dynamic-shape) compilations observed.
        self.num_dynamic_calls = 0

    def __call__(self, graph: fx.Graph):
        from torch.fx.experimental.symbolic_shapes import is_symbolic

        compile_range = get_pass_context().compile_range
        single_size = compile_range.is_single_size()

        for node in graph.nodes:
            example = node.meta.get("val")
            if example is None:
                example = node.meta.get("example_value")
            if not isinstance(example, torch.Tensor):
                continue

            symbolic_dim = any(is_symbolic(d) for d in example.shape)
            if single_size:
                # Single-size compilation: every tensor dim must be concrete.
                assert not symbolic_dim, (
                    f"compile_sizes entry {compile_range}: "
                    f"node '{node.name}' has symbolic shape "
                    f"{example.shape}"
                )
            elif symbolic_dim:
                # compile_ranges should have at least some
                # symbolic shapes (the batch dimension)
                self.num_dynamic_calls += 1
                return

        if single_size:
            self.num_static_calls += 1

    def uuid(self) -> str:
        # Stateless pass for caching purposes: hash an empty state dict.
        state: dict[str, Any] = {}
        return InductorPass.hash_dict(state)
|
||||
|
||||
|
||||
def test_compile_sizes_produce_static_shapes(use_fresh_inductor_cache):
    """Verify that compile_sizes entries are compiled with fully concrete
    shapes (no SymInts), while compile_ranges entries retain dynamic shapes."""
    shape_checker = PostGradStaticShapeChecker()
    torch.set_default_device("cuda")

    # Build the config pieces separately for readability; the checker is
    # installed as a post-grad custom pass so it sees every compiled graph.
    scheduler = SchedulerConfig(
        max_num_batched_tokens=8192,
        max_model_len=8192,
        is_encoder_decoder=False,
    )
    compilation = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        compile_ranges_endpoints=[8],
        compile_sizes=[16],
        inductor_compile_config={
            "post_grad_custom_post_pass": shape_checker,
        },
    )
    vllm_config = VllmConfig(
        scheduler_config=scheduler,
        compilation_config=compilation,
    )

    with set_current_vllm_config(vllm_config):
        model = TestModel(vllm_config=vllm_config, prefix="").eval()
        # 3 compilations: Range(1,8), Range(9,8192), single-size 16
        with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=1,
            num_backend_compilations=3,
        ):
            run_model(vllm_config, model, [1, 16, 64])

    # compile_sizes=16 should produce static shapes
    assert shape_checker.num_static_calls == 1, (
        f"Expected 1 static compilation, got {shape_checker.num_static_calls}"
    )
    # compile_ranges should produce dynamic shapes
    assert shape_checker.num_dynamic_calls == 2, (
        f"Expected 2 dynamic compilations, got {shape_checker.num_dynamic_calls}"
    )
|
||||
|
||||
|
||||
def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
|
||||
# To force multiple compilations, we disable the compile cache
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
|
||||
|
||||
Reference in New Issue
Block a user