diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index dee7cdde7..2846193e7 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -31,6 +31,7 @@ from vllm.logging_utils import lazy from vllm.platforms import current_platform from vllm.tracing import instrument, instrument_manual from vllm.utils.import_utils import resolve_obj_by_qualname +from vllm.utils.torch_utils import is_torch_equal_or_newer from .compiler_interface import ( CompilerInterface, @@ -575,11 +576,14 @@ def split_graph( # the semantics of the graph will change when we # have mutations in the graph with _use_lazy_graph_module(True): + has_tuple_return = is_torch_equal_or_newer("2.12.0.dev") + tuple_return_kwarg = {"tuple_return": True} if has_tuple_return else {} split_gm = torch.fx.passes.split_module.split_module( graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True, + **tuple_return_kwarg, ) outputs = []