[aot_compile]change VLLM backend to read fake args from example_value (#29104)

Signed-off-by: Laith Sakka <lsakka@meta.com>
This commit is contained in:
Laith Sakka
2025-12-04 14:33:45 -08:00
committed by GitHub
parent c8ab988b15
commit 1f0d184590
3 changed files with 81 additions and 10 deletions

View File

@@ -402,6 +402,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
self.extra_traceback = False
def run(self, *args):
# maybe instead just assert inputs are fake?
fake_args = [
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
for t in args
@@ -416,11 +417,13 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
kwargs: dict[str, Any],
) -> Any:
assert isinstance(target, str)
output = super().call_module(target, args, kwargs)
if target in self.compile_submod_names:
index = self.compile_submod_names.index(target)
submod = self.fetch_attr(target)
sym_shape_indices = [
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
]
@@ -746,11 +749,21 @@ class VllmBackend:
if not item.is_splitting_graph
]
# Extract fake values from the graph to use them when needed.
all_fake_values = []
for i in graph.graph.find_nodes(op="placeholder"):
all_fake_values.append(i.meta["example_value"])
fake_args = [
all_fake_values[i] if isinstance(t, torch.Tensor) else t
for i, t in enumerate(example_inputs)
]
# propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes
PiecewiseCompileInterpreter(
self.split_gm, submod_names_to_compile, self.vllm_config, self
).run(*example_inputs)
).run(*fake_args)
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
if not os.path.exists(graph_path):
@@ -780,14 +793,7 @@ class VllmBackend:
)
# if we need to copy input buffers for cudagraph
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode()
fake_args = [
fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
for t in example_inputs
]
#
# index of tensors that have symbolic shapes (batch size)
# for weights and static buffers, they will have concrete shapes.
# symbolic shape only happens for input tensors.

View File

@@ -433,7 +433,6 @@ def _support_torch_compile(
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
# This is the path for the first compilation.
# the first compilation needs to have dynamic shapes marked
_mark_dynamic_inputs(
self,