[torch.compile] Stop doing unnecessary FakeTensorProp in PiecewiseCompileInterpreter (#34093)

Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
Richard Zou
2026-02-10 22:15:40 -05:00
committed by GitHub
parent 3bcd494ef4
commit e30cedd44b
2 changed files with 41 additions and 4 deletions

View File

@@ -27,10 +27,29 @@ from ...utils import create_new_process_for_each_test
from ..silly_attention import get_global_counter, reset_global_counter
# Custom op that returns an unbacked symint during graph capture
# Custom op whose real implementation returns a plain Python int constant.
# Its fake (meta) implementation registered below reports the result as an
# unbacked SymInt, so graph capture sees a data-dependent integer.
@torch.library.custom_op("mylib::foo", mutates_args=())
def foo(x: torch.Tensor) -> int:
    """Return the constant 3 (traced as an unbacked symint via the fake impl)."""
    return 3
# Fake/meta implementation of mylib::foo: instead of the concrete value 3,
# report a fresh data-dependent (unbacked) SymInt so the compiler must handle
# an integer whose value is unknown at trace time.
@foo.register_fake
def _(x):
    return torch.library.get_ctx().new_dynamic_size()
@support_torch_compile
class SillyModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
intermediate_unbacked=False,
**kwargs,
) -> None:
super().__init__()
self.intermediate_unbacked = intermediate_unbacked
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
@@ -44,6 +63,13 @@ class SillyModel(nn.Module):
torch.ops.silly.attention(x, x, x, out)
x = out
x = x - 2
if self.intermediate_unbacked:
# Test for unbacked symints: the following is a fancy way to multiply by 1
u0 = foo(x)
ones = x.new_ones(x.shape[0], u0).sum(-1) / 3
x = x * ones
x = x - 1
out = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, out)
@@ -52,6 +78,7 @@ class SillyModel(nn.Module):
return x
@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
def _run_simple_model(
splitting_ops,
use_inductor_graph_partition,
@@ -60,6 +87,8 @@ def _run_simple_model(
expected_num_piecewise_capturable_graphs_seen,
expected_num_backend_compilations,
expected_num_cudagraph_captured,
*,
intermediate_unbacked=False,
):
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
@@ -72,7 +101,11 @@ def _run_simple_model(
)
)
with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix="")
model = SillyModel(
vllm_config=vllm_config,
prefix="",
intermediate_unbacked=intermediate_unbacked,
)
inputs = torch.randn(100).cuda()
@@ -125,9 +158,10 @@ def _run_simple_model(
@pytest.mark.parametrize("backend", ["inductor", "eager"])
@pytest.mark.parametrize("intermediate_unbacked", [True, False])
@torch.inference_mode()
@create_new_process_for_each_test("spawn")
def test_simple_piecewise_compile(backend):
def test_simple_piecewise_compile(backend, intermediate_unbacked):
_run_simple_model(
splitting_ops=["silly::attention"],
use_inductor_graph_partition=False,
@@ -140,6 +174,7 @@ def test_simple_piecewise_compile(backend):
expected_num_backend_compilations=3,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
expected_num_cudagraph_captured=6,
intermediate_unbacked=intermediate_unbacked,
)

View File

@@ -570,7 +570,9 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
) -> Any:
assert isinstance(target, str)
output = super().call_module(target, args, kwargs)
gm = getattr(self.module, target)
outputs = gm.graph.output_node().args[0]
output = fx.map_arg(outputs, lambda node: node.meta["example_value"])
if target in self.compile_submod_names:
index = self.compile_submod_names.index(target)