[torch.compile] CUDAGraph Inductor partition integration (#24281)
Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Boyuan Feng <fby.1994@gmail.com> Signed-off-by: boyuanfeng <boyuan@meta.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -299,6 +299,26 @@ class CompilationConfig:
|
||||
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
|
||||
"""
|
||||
|
||||
use_inductor_graph_partition: bool = False
|
||||
"""Use inductor graph partition to split the graph at cudagraph_unsafe ops.
|
||||
This partition happens at inductor codegen time after all passes and fusions
|
||||
are finished. It generates a single `call` function which wraps
|
||||
cudagraph-safe ops into partition functions and leave cudagraph-unsafe ops
|
||||
outside the partition functions. For a graph with N cudagraph-unsafe ops
|
||||
(e.g., Attention), there would be N+1 partitions. To mark an op as
|
||||
cudagraph unsafe, we can add `tags=(torch._C.Tag.cudagraph_unsafe)` when
|
||||
register the custom op.
|
||||
|
||||
This config supports both full cudagraph and piecewise cudagraph without
|
||||
compiling twice. For piecewise cudagraph, it applies vLLM CUDAGraph wrapper
|
||||
to each partition. For N+1 partitions, there would be N+1
|
||||
CUDAGraph wrapper instances.
|
||||
|
||||
For full CUDAGraph, we always apply a single CUDAGraph wrapper outside the
|
||||
inductor `call` function in the model runner. The top-level full cudagraph
|
||||
capture ignores all partitioning.
|
||||
"""
|
||||
|
||||
pass_config: PassConfig = field(default_factory=PassConfig)
|
||||
"""Custom inductor passes, see PassConfig for more details"""
|
||||
|
||||
@@ -461,6 +481,12 @@ class CompilationConfig:
|
||||
"since full_cuda_graph is deprecated.")
|
||||
self.cudagraph_mode = CUDAGraphMode.FULL
|
||||
|
||||
if (self.use_inductor_graph_partition
|
||||
and not is_torch_equal_or_newer("2.9.0.dev")):
|
||||
raise ValueError("use_inductor_graph_partition is only "
|
||||
"supported with torch>=2.9.0.dev. Set "
|
||||
"use_inductor_graph_partition=False instead.")
|
||||
|
||||
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
|
||||
if self.level == CompilationLevel.NO_COMPILATION:
|
||||
raise ValueError("No compilation level is set.")
|
||||
@@ -540,19 +566,36 @@ class CompilationConfig:
|
||||
"set_splitting_ops_for_v1 should only be called when "
|
||||
"level is CompilationLevel.PIECEWISE")
|
||||
|
||||
use_inductor_graph_partition_msg = (
|
||||
"When use_inductor_graph_partition=True, splitting_ops "
|
||||
"are ignored and set to an empty list. Instead, "
|
||||
"\"tags=(torch._C.Tag.cudagraph_unsafe, ),\" is "
|
||||
"used to annotate custom ops for graph partition.")
|
||||
|
||||
if self.splitting_ops is None:
|
||||
# NOTE: When using full cudagraph, instead of setting an empty
|
||||
# list and capture the full cudagraph inside the flattened fx
|
||||
# graph, we keep the piecewise fx graph structure but capture the
|
||||
# full cudagraph outside the fx graph. This reduces some cpu
|
||||
# overhead when the runtime batch_size is not cudagraph captured.
|
||||
# see https://github.com/vllm-project/vllm/pull/20059 for details.
|
||||
# make a copy to avoid mutating the class-level list via reference.
|
||||
self.splitting_ops = list(self._attention_ops)
|
||||
if self.use_inductor_graph_partition:
|
||||
# When using inductor graph partition, we set splitting_ops
|
||||
# to be empty and rely on torch._C.Tag.cudagraph_unsafe to
|
||||
# annotate custom ops as splitting ops.
|
||||
logger.warning_once(use_inductor_graph_partition_msg)
|
||||
self.splitting_ops = []
|
||||
else:
|
||||
# NOTE: When using full cudagraph, instead of setting an empty
|
||||
# list and capture the full cudagraph inside the flattened fx
|
||||
# graph, we keep the piecewise fx graph structure but capture
|
||||
# the full cudagraph outside the fx graph. This reduces some
|
||||
# cpu overhead when the runtime batch_size is not cudagraph
|
||||
# captured. see https://github.com/vllm-project/vllm/pull/20059
|
||||
# for details. make a copy to avoid mutating the class-level
|
||||
# list via reference.
|
||||
self.splitting_ops = list(self._attention_ops)
|
||||
elif len(self.splitting_ops) == 0:
|
||||
logger.warning_once("Using piecewise compilation with empty "
|
||||
"splitting_ops.")
|
||||
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
|
||||
logger.warning_once(
|
||||
"Using piecewise compilation with empty "
|
||||
"splitting_ops and use_inductor_graph_partition"
|
||||
f"={self.use_inductor_graph_partition}.")
|
||||
if (self.cudagraph_mode == CUDAGraphMode.PIECEWISE
|
||||
and not self.use_inductor_graph_partition):
|
||||
logger.warning_once(
|
||||
"When compilation level is piecewise with empty "
|
||||
"splitting_ops, PIECEWISE cudagraph_mode will be "
|
||||
@@ -562,7 +605,26 @@ class CompilationConfig:
|
||||
"any problems.")
|
||||
self.cudagraph_mode = CUDAGraphMode.FULL
|
||||
self.splitting_ops = []
|
||||
elif self.use_inductor_graph_partition:
|
||||
logger.warning_once(use_inductor_graph_partition_msg)
|
||||
self.splitting_ops = []
|
||||
|
||||
def splitting_ops_contain_attention(self) -> bool:
|
||||
return self.splitting_ops is not None and all(
|
||||
op in self.splitting_ops for op in self._attention_ops)
|
||||
|
||||
def is_attention_compiled_piecewise(self) -> bool:
|
||||
use_fx_graph_piecewise_compilation = (
|
||||
self.level == CompilationLevel.PIECEWISE
|
||||
and self.splitting_ops_contain_attention())
|
||||
|
||||
inductor_used = (self.level == CompilationLevel.PIECEWISE
|
||||
and self.use_inductor) or (
|
||||
self.level >= CompilationLevel.DYNAMO_AS_IS
|
||||
and self.backend == "inductor")
|
||||
use_inductor_piecewise_compilation = (
|
||||
inductor_used and self.use_inductor_graph_partition
|
||||
and not self.splitting_ops_contain_attention())
|
||||
|
||||
return use_fx_graph_piecewise_compilation or \
|
||||
use_inductor_piecewise_compilation
|
||||
|
||||
Reference in New Issue
Block a user