[torch.compile] CUDAGraph Inductor partition integration (#24281)

Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: Boyuan Feng <fby.1994@gmail.com>
Signed-off-by: boyuanfeng <boyuan@meta.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Committed by GitHub on behalf of Boyuan Feng, 2025-09-19 18:02:15 -07:00
commit 8945b001db (parent b8a287a0a8)
9 changed files with 280 additions and 32 deletions


@@ -299,6 +299,26 @@ class CompilationConfig:
     minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
     """
 
+    use_inductor_graph_partition: bool = False
+    """Use inductor graph partition to split the graph at cudagraph_unsafe ops.
+    This partition happens at inductor codegen time, after all passes and
+    fusions have finished. It generates a single `call` function that wraps
+    cudagraph-safe ops into partition functions and leaves cudagraph-unsafe
+    ops outside the partition functions. For a graph with N cudagraph-unsafe
+    ops (e.g., Attention), there would be N+1 partitions. To mark an op as
+    cudagraph-unsafe, add `tags=(torch._C.Tag.cudagraph_unsafe,)` when
+    registering the custom op.
+
+    This config supports both full cudagraph and piecewise cudagraph without
+    compiling twice. For piecewise cudagraph, it applies the vLLM CUDAGraph
+    wrapper to each partition, so for N+1 partitions there are N+1 CUDAGraph
+    wrapper instances.
+
+    For full cudagraph, we always apply a single CUDAGraph wrapper outside
+    the inductor `call` function in the model runner; the top-level full
+    cudagraph capture ignores all partitioning.
+    """
+
     pass_config: PassConfig = field(default_factory=PassConfig)
     """Custom inductor passes, see PassConfig for more details"""
@@ -461,6 +481,12 @@ class CompilationConfig:
"since full_cuda_graph is deprecated.")
self.cudagraph_mode = CUDAGraphMode.FULL
if (self.use_inductor_graph_partition
and not is_torch_equal_or_newer("2.9.0.dev")):
raise ValueError("use_inductor_graph_partition is only "
"supported with torch>=2.9.0.dev. Set "
"use_inductor_graph_partition=False instead.")
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")
@@ -540,19 +566,36 @@ class CompilationConfig:
"set_splitting_ops_for_v1 should only be called when "
"level is CompilationLevel.PIECEWISE")
use_inductor_graph_partition_msg = (
"When use_inductor_graph_partition=True, splitting_ops "
"are ignored and set to an empty list. Instead, "
"\"tags=(torch._C.Tag.cudagraph_unsafe, ),\" is "
"used to annotate custom ops for graph partition.")
if self.splitting_ops is None:
# NOTE: When using full cudagraph, instead of setting an empty
# list and capture the full cudagraph inside the flattened fx
# graph, we keep the piecewise fx graph structure but capture the
# full cudagraph outside the fx graph. This reduces some cpu
# overhead when the runtime batch_size is not cudagraph captured.
# see https://github.com/vllm-project/vllm/pull/20059 for details.
# make a copy to avoid mutating the class-level list via reference.
self.splitting_ops = list(self._attention_ops)
if self.use_inductor_graph_partition:
# When using inductor graph partition, we set splitting_ops
# to be empty and rely on torch._C.Tag.cudagraph_unsafe to
# annotate custom ops as splitting ops.
logger.warning_once(use_inductor_graph_partition_msg)
self.splitting_ops = []
else:
# NOTE: When using full cudagraph, instead of setting an empty
# list and capture the full cudagraph inside the flattened fx
# graph, we keep the piecewise fx graph structure but capture
# the full cudagraph outside the fx graph. This reduces some
# cpu overhead when the runtime batch_size is not cudagraph
# captured. see https://github.com/vllm-project/vllm/pull/20059
# for details. make a copy to avoid mutating the class-level
# list via reference.
self.splitting_ops = list(self._attention_ops)
elif len(self.splitting_ops) == 0:
logger.warning_once("Using piecewise compilation with empty "
"splitting_ops.")
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.warning_once(
"Using piecewise compilation with empty "
"splitting_ops and use_inductor_graph_partition"
f"={self.use_inductor_graph_partition}.")
if (self.cudagraph_mode == CUDAGraphMode.PIECEWISE
and not self.use_inductor_graph_partition):
logger.warning_once(
"When compilation level is piecewise with empty "
"splitting_ops, PIECEWISE cudagraph_mode will be "
@@ -562,7 +605,26 @@ class CompilationConfig:
"any problems.")
self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = []
elif self.use_inductor_graph_partition:
logger.warning_once(use_inductor_graph_partition_msg)
self.splitting_ops = []
def splitting_ops_contain_attention(self) -> bool:
return self.splitting_ops is not None and all(
op in self.splitting_ops for op in self._attention_ops)
def is_attention_compiled_piecewise(self) -> bool:
use_fx_graph_piecewise_compilation = (
self.level == CompilationLevel.PIECEWISE
and self.splitting_ops_contain_attention())
inductor_used = (self.level == CompilationLevel.PIECEWISE
and self.use_inductor) or (
self.level >= CompilationLevel.DYNAMO_AS_IS
and self.backend == "inductor")
use_inductor_piecewise_compilation = (
inductor_used and self.use_inductor_graph_partition
and not self.splitting_ops_contain_attention())
return use_fx_graph_piecewise_compilation or \
use_inductor_piecewise_compilation
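
To illustrate how the two new helpers interact with set_splitting_ops_for_v1, here is a minimal sketch (not part of this diff). It assumes CompilationConfig can be constructed directly with these fields, that use_inductor keeps its default of True, and that a torch>=2.9.0.dev nightly is installed for the second path.

from vllm.config import CompilationConfig, CompilationLevel

# Path 1: fx-graph splitting. splitting_ops defaults to the attention ops.
cfg = CompilationConfig(level=CompilationLevel.PIECEWISE)
cfg.set_splitting_ops_for_v1()
assert cfg.splitting_ops_contain_attention()
assert cfg.is_attention_compiled_piecewise()

# Path 2: inductor graph partition. splitting_ops is cleared; attention is
# kept out of cudagraphs via the cudagraph_unsafe tag instead.
cfg = CompilationConfig(level=CompilationLevel.PIECEWISE,
                        use_inductor_graph_partition=True)
cfg.set_splitting_ops_for_v1()
assert not cfg.splitting_ops_contain_attention()
assert cfg.is_attention_compiled_piecewise()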