diff --git a/tests/compile/fullgraph/test_simple.py b/tests/compile/fullgraph/test_simple.py
index 36cc1510e..ed9c7a351 100644
--- a/tests/compile/fullgraph/test_simple.py
+++ b/tests/compile/fullgraph/test_simple.py
@@ -27,10 +27,29 @@ from ...utils import create_new_process_for_each_test
 from ..silly_attention import get_global_counter, reset_global_counter
 
 
+# Custom op that returns an unbacked symint during graph capture
+@torch.library.custom_op("mylib::foo", mutates_args=())
+def foo(x: torch.Tensor) -> int:
+    return 3
+
+
+@foo.register_fake
+def _(x):
+    return torch.library.get_ctx().new_dynamic_size()
+
+
 @support_torch_compile
 class SillyModel(nn.Module):
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        intermediate_unbacked=False,
+        **kwargs,
+    ) -> None:
         super().__init__()
+        self.intermediate_unbacked = intermediate_unbacked
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -44,6 +63,13 @@ class SillyModel(nn.Module):
         torch.ops.silly.attention(x, x, x, out)
         x = out
         x = x - 2
+
+        if self.intermediate_unbacked:
+            # Test for unbacked symints: the following is a fancy way to multiply by 1
+            u0 = foo(x)
+            ones = x.new_ones(x.shape[0], u0).sum(-1) / 3
+            x = x * ones
+
         x = x - 1
         out = torch.empty_like(x)
         torch.ops.silly.attention(x, x, x, out)
@@ -52,6 +78,7 @@ class SillyModel(nn.Module):
         return x
 
 
+@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
 def _run_simple_model(
     splitting_ops,
     use_inductor_graph_partition,
@@ -60,6 +87,8 @@ def _run_simple_model(
     expected_num_piecewise_capturable_graphs_seen,
     expected_num_backend_compilations,
     expected_num_cudagraph_captured,
+    *,
+    intermediate_unbacked=False,
 ):
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
@@ -72,7 +101,11 @@ def _run_simple_model(
         )
     )
     with set_current_vllm_config(vllm_config):
-        model = SillyModel(vllm_config=vllm_config, prefix="")
+        model = SillyModel(
+            vllm_config=vllm_config,
+            prefix="",
+            intermediate_unbacked=intermediate_unbacked,
+        )
 
     inputs = torch.randn(100).cuda()
 
@@ -125,9 +158,10 @@ def _run_simple_model(
 
 
 @pytest.mark.parametrize("backend", ["inductor", "eager"])
+@pytest.mark.parametrize("intermediate_unbacked", [True, False])
 @torch.inference_mode()
 @create_new_process_for_each_test("spawn")
-def test_simple_piecewise_compile(backend):
+def test_simple_piecewise_compile(backend, intermediate_unbacked):
     _run_simple_model(
         splitting_ops=["silly::attention"],
         use_inductor_graph_partition=False,
@@ -140,6 +174,7 @@ def test_simple_piecewise_compile(backend):
         expected_num_backend_compilations=3,
         # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
         expected_num_cudagraph_captured=6,
+        intermediate_unbacked=intermediate_unbacked,
     )
 
 
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index e5cdb2d33..315bac73f 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -570,7 +570,9 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
     ) -> Any:
         assert isinstance(target, str)
 
-        output = super().call_module(target, args, kwargs)
+        gm = getattr(self.module, target)
+        outputs = gm.graph.output_node().args[0]
+        output = fx.map_arg(outputs, lambda node: node.meta["example_value"])
 
         if target in self.compile_submod_names:
             index = self.compile_submod_names.index(target)
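
Note on the test pattern: `mylib::foo` above is the standard `torch.library` recipe for surfacing an unbacked symint to the compiler. The real implementation returns a concrete int, while the fake implementation allocates a fresh dynamic size via `get_ctx().new_dynamic_size()`, so during capture the value stays symbolic (`u0`) and the graph cannot specialize on it; `capture_dynamic_output_shape_ops=True` is what lets Dynamo trace through such an op instead of graph-breaking. A minimal standalone sketch of the same mechanism outside vLLM (the op name `mylib::bar` and function `f` are illustrative only, not part of this PR):

```python
import torch


# Real implementation: runs in eager and returns a concrete int.
@torch.library.custom_op("mylib::bar", mutates_args=())
def bar(x: torch.Tensor) -> int:
    return int(x.shape[0])


# Fake implementation: hands the compiler a fresh unbacked SymInt,
# so the traced graph cannot bake in the runtime value.
@bar.register_fake
def _(x):
    return torch.library.get_ctx().new_dynamic_size()


# Without capture_dynamic_output_shape_ops, Dynamo would graph-break on
# an op whose output value is only known at runtime.
@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
@torch.compile(fullgraph=True)
def f(x):
    n = bar(x)  # n is an unbacked SymInt (u0) during capture
    return x.new_ones(n).sum()  # data-dependent shape flows onward


print(f(torch.randn(4)))  # tensor(4.)
```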
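
On the `backends.py` change: previously `call_module` re-ran each piecewise submodule through `super().call_module(...)` just to produce inputs for the next subgraph, which both re-executes real kernels and breaks when an intermediate output is an unbacked symint with no concrete value to propagate. Since Dynamo has already recorded `meta["example_value"]` on every node during capture, the outputs can be recovered directly from the submodule's graph. A rough sketch of the two FX helpers involved, assuming a PyTorch recent enough to have `Graph.output_node()`; the meta key is hand-populated here purely for demonstration (in vLLM it is filled in by Dynamo, not `symbolic_trace`):

```python
import torch
from torch import fx


def output_example_values(gm: fx.GraphModule):
    # output_node() returns the graph's single `output` node; args[0] holds
    # the (possibly nested) structure of nodes the module returns.
    outputs = gm.graph.output_node().args[0]
    # map_arg walks that structure, applying the lambda to every fx.Node
    # and leaving non-Node leaves untouched.
    return fx.map_arg(outputs, lambda node: node.meta["example_value"])


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1, x * 2


gm = fx.symbolic_trace(M())
# Stand-in for what Dynamo records during capture.
for node in gm.graph.nodes:
    node.meta["example_value"] = f"<value of {node.name}>"

print(output_example_values(gm))
# ('<value of add>', '<value of mul>')
```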