diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py
index 286ed4a8b..9fd8e9577 100644
--- a/tests/compile/test_compile_ranges.py
+++ b/tests/compile/test_compile_ranges.py
@@ -127,6 +127,88 @@ def test_compile_config_get_compile_ranges():
     ]
 
 
+class PostGradStaticShapeChecker(InductorPass):
+    """Post-grad pass asserting that single-size (compile_sizes) graphs carry
+    only concrete tensor shapes, while counting range graphs that stay symbolic."""
+
+    def __init__(self):
+        self.num_static_calls = 0
+        self.num_dynamic_calls = 0
+
+    def __call__(self, graph: fx.Graph):
+        from torch.fx.experimental.symbolic_shapes import is_symbolic
+
+        compile_range = get_pass_context().compile_range
+        is_single = compile_range.is_single_size()
+
+        for node in graph.nodes:
+            val = node.meta.get("val")
+            if val is None:
+                val = node.meta.get("example_value")
+            if isinstance(val, torch.Tensor):
+                has_symbolic = any(is_symbolic(d) for d in val.shape)
+                if is_single:
+                    assert not has_symbolic, (
+                        f"compile_sizes entry {compile_range}: "
+                        f"node '{node.name}' has symbolic shape "
+                        f"{val.shape}"
+                    )
+                else:
+                    # Range compilations should keep at least one symbolic dim
+                    # (the batch dimension); count the graph once and stop scanning.
+                    if has_symbolic:
+                        self.num_dynamic_calls += 1
+                        return
+
+        if is_single:
+            self.num_static_calls += 1
+
+    def uuid(self) -> str:
+        state: dict[str, Any] = {}
+        return InductorPass.hash_dict(state)
+
+
+def test_compile_sizes_produce_static_shapes(use_fresh_inductor_cache):
+    """Verify that compile_sizes entries are compiled with fully concrete
+    shapes (no SymInts), while compile_ranges entries retain dynamic shapes."""
+    checker = PostGradStaticShapeChecker()
+    torch.set_default_device("cuda")
+    vllm_config = VllmConfig(
+        scheduler_config=SchedulerConfig(
+            max_num_batched_tokens=8192,
+            max_model_len=8192,
+            is_encoder_decoder=False,
+        ),
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            compile_ranges_endpoints=[8],
+            compile_sizes=[16],
+            inductor_compile_config={
+                "post_grad_custom_post_pass": checker,
+            },
+        ),
+    )
+
+    with set_current_vllm_config(vllm_config):
+        model = TestModel(vllm_config=vllm_config, prefix="").eval()
+        # 3 compilations: Range(1,8), Range(9,8192), single-size 16
+        with compilation_counter.expect(
+            num_graphs_seen=1,
+            num_piecewise_graphs_seen=1,
+            num_backend_compilations=3,
+        ):
+            run_model(vllm_config, model, [1, 16, 64])
+
+    # compile_sizes=16 should produce static shapes
+    assert checker.num_static_calls == 1, (
+        f"Expected 1 static compilation, got {checker.num_static_calls}"
+    )
+    # compile_ranges should produce dynamic shapes
+    assert checker.num_dynamic_calls == 2, (
+        f"Expected 2 dynamic compilations, got {checker.num_dynamic_calls}"
+    )
+
+
 def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
     # To force multiple compilations, we disable the compile cache
     monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py
index 035370063..2242f0304 100644
--- a/vllm/compilation/compiler_interface.py
+++ b/vllm/compilation/compiler_interface.py
@@ -348,13 +348,39 @@ class InductorStandaloneAdaptor(CompilerInterface):
         # Can remove this after the following issue gets fixed
         # https://github.com/pytorch/pytorch/issues/174502
         if envs.VLLM_ENABLE_PREGRAD_PASSES:
-            ctx: Any = contextlib.nullcontext()
+            pregrad_ctx: Any = contextlib.nullcontext()
         else:
-            ctx = patch(
+            pregrad_ctx = patch(
                 "torch._inductor.compile_fx._recursive_pre_grad_passes",
                 lambda gm, _: gm,
             )
-        with ctx, _patch_constrain_to_fx_strides():
+
+        # When inputs are FakeTensors (from create_concrete_args),
+        # standalone_compile("from_example_inputs") would normally create
+        # a fresh FakeTensorMode, causing a mode mismatch assertion.
+        # Patch FakeTensorMode in standalone_compile so it reuses the
+        # mode already attached to our FakeTensors. This gives us both
+        # ignore_shape_env=True (from "from_example_inputs") and mode
+        # consistency (from reusing our mode).
+        # Can remove this after the following issue gets fixed:
+        # https://github.com/pytorch/pytorch/issues/176562
+        from torch._subclasses.fake_tensor import FakeTensor
+
+        input_fake_mode = None
+        for x in example_inputs:
+            if isinstance(x, FakeTensor):
+                input_fake_mode = x.fake_mode
+                break
+
+        if input_fake_mode is not None:
+            fake_mode_ctx: Any = patch(
+                "torch._inductor.standalone_compile.FakeTensorMode",
+                lambda *a, **kw: input_fake_mode,
+            )
+        else:
+            fake_mode_ctx = contextlib.nullcontext()
+
+        with pregrad_ctx, fake_mode_ctx, _patch_constrain_to_fx_strides():
             compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
 
         if use_aot:
diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py
index 5aeb51a7a..7474d0bf8 100644
--- a/vllm/compilation/piecewise_backend.py
+++ b/vllm/compilation/piecewise_backend.py
@@ -34,13 +34,14 @@ def get_fake_args_from_graph(graph: fx.GraphModule) -> list[Any]:
 
 
 def create_concrete_args(graph: fx.GraphModule, size: int) -> list[Any]:
-    """Create example inputs with symbolic dims replaced by a concrete size.
+    """Create FakeTensor example inputs with symbolic dims replaced by a concrete size.
 
-    Used for single-size eager compilation where we need concrete-shaped
-    inputs but don't have real runtime tensors yet.
+    Used for single-size compilation where we need concrete-shaped inputs.
+    The Dynamo-captured graph gives us example inputs with SymInts in them.
     """
     from torch._prims_common import compute_required_storage_length
-    from torch.fx.experimental.symbolic_shapes import is_symbolic
+    from torch._subclasses.fake_tensor import FakeTensorMode
+    from torch.fx.experimental.symbolic_shapes import ShapeEnv, is_symbolic
 
     def concretize(sym_val: Any) -> int:
         """Replace all symbolic variables in a SymInt expression with size."""
@@ -49,25 +50,28 @@ def create_concrete_args(graph: fx.GraphModule, size: int) -> list[Any]:
         expr = sym_val.node.expr
         return int(expr.subs({s: size for s in expr.free_symbols}))
 
+    fake_mode = FakeTensorMode(shape_env=ShapeEnv())
+
     args: list[Any] = []
-    for node in graph.graph.nodes:
-        if node.op != "placeholder":
-            break
-        val = node.meta["example_value"]
-        if isinstance(val, torch.SymInt):
-            args.append(concretize(val))
-        elif isinstance(val, torch.Tensor):
-            new_shape = tuple(concretize(d) for d in val.shape)
-            new_strides = tuple(concretize(s) for s in val.stride())
-            new_storage_offset = concretize(val.storage_offset())
-            needed_size = compute_required_storage_length(
-                new_shape, new_strides, new_storage_offset
-            )
-            t = torch.empty(needed_size, dtype=val.dtype, device=val.device)
-            t = t.as_strided(new_shape, new_strides, new_storage_offset)
-            args.append(t)
-        else:
-            args.append(val)
+    with fake_mode:
+        for node in graph.graph.nodes:
+            if node.op != "placeholder":
+                break
+            val = node.meta["example_value"]
+            if isinstance(val, torch.SymInt):
+                args.append(concretize(val))
+            elif isinstance(val, torch.Tensor):
+                new_shape = tuple(concretize(d) for d in val.shape)
+                new_strides = tuple(concretize(s) for s in val.stride())
+                new_storage_offset = concretize(val.storage_offset())
+                needed_size = compute_required_storage_length(
+                    new_shape, new_strides, new_storage_offset
+                )
+                t = torch.empty(needed_size, dtype=val.dtype, device=val.device)
+                t = t.as_strided(new_shape, new_strides, new_storage_offset)
+                args.append(t)
+            else:
+                args.append(val)
 
     return args
 