[torch.compile] Stop doing unnecessary FakeTensorProp in PiecewiseCompileInterpreter (#34093)
Signed-off-by: Richard Zou <zou3519@gmail.com>
@@ -27,10 +27,29 @@ from ...utils import create_new_process_for_each_test
 from ..silly_attention import get_global_counter, reset_global_counter
 
 
+# Custom op that returns an unbacked symint during graph capture
+@torch.library.custom_op("mylib::foo", mutates_args=())
+def foo(x: torch.Tensor) -> int:
+    return 3
+
+
+@foo.register_fake
+def _(x):
+    return torch.library.get_ctx().new_dynamic_size()
+
+
 @support_torch_compile
 class SillyModel(nn.Module):
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        intermediate_unbacked=False,
+        **kwargs,
+    ) -> None:
         super().__init__()
+        self.intermediate_unbacked = intermediate_unbacked
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -44,6 +63,13 @@ class SillyModel(nn.Module):
         torch.ops.silly.attention(x, x, x, out)
         x = out
         x = x - 2
+
+        if self.intermediate_unbacked:
+            # Test for unbacked symints: the following is a fancy way to multiply by 1
+            u0 = foo(x)
+            ones = x.new_ones(x.shape[0], u0).sum(-1) / 3
+            x = x * ones
+
         x = x - 1
         out = torch.empty_like(x)
         torch.ops.silly.attention(x, x, x, out)
@@ -52,6 +78,7 @@ class SillyModel(nn.Module):
         return x
 
 
+@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
 def _run_simple_model(
     splitting_ops,
     use_inductor_graph_partition,
@@ -60,6 +87,8 @@ def _run_simple_model(
     expected_num_piecewise_capturable_graphs_seen,
     expected_num_backend_compilations,
     expected_num_cudagraph_captured,
+    *,
+    intermediate_unbacked=False,
 ):
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
@@ -72,7 +101,11 @@ def _run_simple_model(
         )
     )
     with set_current_vllm_config(vllm_config):
-        model = SillyModel(vllm_config=vllm_config, prefix="")
+        model = SillyModel(
+            vllm_config=vllm_config,
+            prefix="",
+            intermediate_unbacked=intermediate_unbacked,
+        )
 
         inputs = torch.randn(100).cuda()
 
@@ -125,9 +158,10 @@ def _run_simple_model(
 
 
 @pytest.mark.parametrize("backend", ["inductor", "eager"])
+@pytest.mark.parametrize("intermediate_unbacked", [True, False])
 @torch.inference_mode()
 @create_new_process_for_each_test("spawn")
-def test_simple_piecewise_compile(backend):
+def test_simple_piecewise_compile(backend, intermediate_unbacked):
     _run_simple_model(
         splitting_ops=["silly::attention"],
         use_inductor_graph_partition=False,
@@ -140,6 +174,7 @@ def test_simple_piecewise_compile(backend):
         expected_num_backend_compilations=3,
         # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
         expected_num_cudagraph_captured=6,
+        intermediate_unbacked=intermediate_unbacked,
     )
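
For orientation, below is a minimal standalone sketch (not part of this commit) of the pattern the new test exercises: a custom op whose fake implementation returns an unbacked symint via torch.library.get_ctx().new_dynamic_size(), traced with capture_dynamic_output_shape_ops enabled. The op name mylib::bar, the helper run_unbacked, and the input size are illustrative assumptions; the commit's test drives the same pattern through vLLM's support_torch_compile rather than a bare torch.compile call.

    # Standalone sketch of the unbacked-symint custom-op pattern (illustrative names).
    import torch


    @torch.library.custom_op("mylib::bar", mutates_args=())
    def bar(x: torch.Tensor) -> int:
        # Eager implementation returns a concrete value the compiler cannot see.
        return 3


    @bar.register_fake
    def _(x):
        # During graph capture the result is an unbacked symint, not a known integer.
        return torch.library.get_ctx().new_dynamic_size()


    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    def run_unbacked(x: torch.Tensor) -> torch.Tensor:
        def fn(t):
            u0 = bar(t)
            # A fancy way to multiply by 1 that forces the unbacked symint into a shape.
            ones = t.new_ones(t.shape[0], u0).sum(-1) / 3
            return t * ones

        compiled = torch.compile(fn)
        return compiled(x)


    if __name__ == "__main__":
        print(run_unbacked(torch.randn(8)))

Parametrizing test_simple_piecewise_compile over intermediate_unbacked keeps the original fully-backed path covered while also exercising the unbacked-symint path on both the inductor and eager backends.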