[aot_compile] Change vLLM backend to read fake args from example_value (#29104)
Signed-off-by: Laith Sakka <lsakka@meta.com>
This commit is contained in:
@@ -402,6 +402,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
self.extra_traceback = False
|
||||
|
||||
def run(self, *args):
|
||||
# maybe instead just assert inputs are fake?
|
||||
fake_args = [
|
||||
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
|
||||
for t in args
|
||||
@@ -416,11 +417,13 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
kwargs: dict[str, Any],
|
||||
) -> Any:
|
||||
assert isinstance(target, str)
|
||||
|
||||
output = super().call_module(target, args, kwargs)
|
||||
|
||||
if target in self.compile_submod_names:
|
||||
index = self.compile_submod_names.index(target)
|
||||
submod = self.fetch_attr(target)
|
||||
|
||||
sym_shape_indices = [
|
||||
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
|
||||
]
|
||||
@@ -746,11 +749,21 @@ class VllmBackend:
|
||||
if not item.is_splitting_graph
|
||||
]
|
||||
|
||||
# Extract fake values from the graph to use them when needed.
|
||||
all_fake_values = []
|
||||
for i in graph.graph.find_nodes(op="placeholder"):
|
||||
all_fake_values.append(i.meta["example_value"])
|
||||
|
||||
fake_args = [
|
||||
all_fake_values[i] if isinstance(t, torch.Tensor) else t
|
||||
for i, t in enumerate(example_inputs)
|
||||
]
|
||||
|
||||
# propagate the split graph to the piecewise backend,
|
||||
# compile submodules with symbolic shapes
|
||||
PiecewiseCompileInterpreter(
|
||||
self.split_gm, submod_names_to_compile, self.vllm_config, self
|
||||
).run(*example_inputs)
|
||||
).run(*fake_args)
|
||||
|
||||
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
|
||||
if not os.path.exists(graph_path):
|
||||
@@ -780,14 +793,7 @@ class VllmBackend:
|
||||
)
|
||||
|
||||
# if we need to copy input buffers for cudagraph
|
||||
from torch._guards import detect_fake_mode
|
||||
|
||||
fake_mode = detect_fake_mode()
|
||||
fake_args = [
|
||||
fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
|
||||
for t in example_inputs
|
||||
]
|
||||
|
||||
#
|
||||
# index of tensors that have symbolic shapes (batch size)
|
||||
# for weights and static buffers, they will have concrete shapes.
|
||||
# symbolic shape only happens for input tensors.
|
||||
|
||||
@@ -433,7 +433,6 @@ def _support_torch_compile(
|
||||
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
|
||||
|
||||
# This is the path for the first compilation.
|
||||
|
||||
# the first compilation needs to have dynamic shapes marked
|
||||
_mark_dynamic_inputs(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user