[torch.compile] use depyf to dump torch.compile internals (#10972)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-12-11 10:43:05 -08:00
committed by GitHub
parent fd22220687
commit 91642db952
7 changed files with 66 additions and 42 deletions

View File

@@ -33,3 +33,4 @@ six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that need
setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
einops # Required for Qwen2-VL. einops # Required for Qwen2-VL.
compressed-tensors == 0.8.0 # required for compressed-tensors compressed-tensors == 0.8.0 # required for compressed-tensors
depyf==0.18.0 # required for profiling and debugging torch.compile

View File

@@ -9,7 +9,7 @@ import torch
import torch.fx as fx import torch.fx as fx
import vllm.envs as envs import vllm.envs as envs
from vllm.config import CompilationConfig from vllm.config import CompilationConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors from vllm.utils import weak_ref_tensors
@@ -149,14 +149,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
""" """
def __init__(self, module: torch.fx.GraphModule, def __init__(self, module: torch.fx.GraphModule,
compile_submod_names: List[str], compile_submod_names: List[str], vllm_config: VllmConfig,
compilation_configs: CompilationConfig, graph_pool): graph_pool):
super().__init__(module) super().__init__(module)
from torch._guards import detect_fake_mode from torch._guards import detect_fake_mode
self.fake_mode = detect_fake_mode() self.fake_mode = detect_fake_mode()
self.compile_submod_names = compile_submod_names self.compile_submod_names = compile_submod_names
self.compilation_configs = compilation_configs self.compilation_config = vllm_config.compilation_config
self.graph_pool = graph_pool self.graph_pool = graph_pool
self.vllm_config = vllm_config
def run(self, *args): def run(self, *args):
fake_args = [ fake_args = [
@@ -182,15 +183,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
compiled_graph_for_general_shape = wrap_inductor( compiled_graph_for_general_shape = wrap_inductor(
submod, submod,
args, args,
self.compilation_configs.inductor_compile_config, self.compilation_config.inductor_compile_config,
self.compilation_configs, self.compilation_config,
graph_index=index, graph_index=index,
num_graphs=len(self.compile_submod_names), num_graphs=len(self.compile_submod_names),
runtime_shape=None, runtime_shape=None,
use_inductor=self.compilation_configs.use_inductor) use_inductor=self.compilation_config.use_inductor)
self.module.__dict__[target] = PiecewiseBackend( self.module.__dict__[target] = PiecewiseBackend(
submod, self.compilation_configs, self.graph_pool, index, submod, self.vllm_config, self.graph_pool, index,
len(self.compile_submod_names), sym_shape_indices, len(self.compile_submod_names), sym_shape_indices,
compiled_graph_for_general_shape) compiled_graph_for_general_shape)
@@ -211,7 +212,8 @@ class VllmBackend:
which handles the post-grad passes. which handles the post-grad passes.
""" """
compilation_configs: CompilationConfig vllm_config: VllmConfig
compilation_config: CompilationConfig
graph_pool: Any graph_pool: Any
_called: bool = False _called: bool = False
# the graph we compiled # the graph we compiled
@@ -227,7 +229,7 @@ class VllmBackend:
def __init__( def __init__(
self, self,
compilation_configs: CompilationConfig, vllm_config: VllmConfig,
): ):
global global_graph_pool global global_graph_pool
if global_graph_pool is None: if global_graph_pool is None:
@@ -244,13 +246,14 @@ class VllmBackend:
self.sym_tensor_indices = [] self.sym_tensor_indices = []
self.input_buffers = [] self.input_buffers = []
self.compilation_configs = compilation_configs self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
# `torch.compile` is JIT compiled, so we don't need to # `torch.compile` is JIT compiled, so we don't need to
# do anything here # do anything here
def configure_post_pass(self): def configure_post_pass(self):
config = self.compilation_configs config = self.compilation_config
self.post_grad_pass_manager.configure(config.pass_config) self.post_grad_pass_manager.configure(config.pass_config)
# Post-grad custom passes are run using the post_grad_custom_post_pass # Post-grad custom passes are run using the post_grad_custom_post_pass
@@ -271,7 +274,7 @@ class VllmBackend:
from .monitor import torch_compile_start_time from .monitor import torch_compile_start_time
dynamo_time = time.time() - torch_compile_start_time dynamo_time = time.time() - torch_compile_start_time
logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time) logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
self.compilation_configs.compilation_time += dynamo_time self.compilation_config.compilation_time += dynamo_time
# we control the compilation process, each instance can only be # we control the compilation process, each instance can only be
# called once # called once
@@ -281,7 +284,7 @@ class VllmBackend:
self.configure_post_pass() self.configure_post_pass()
self.split_gm, self.piecewise_graphs = split_graph( self.split_gm, self.piecewise_graphs = split_graph(
graph, self.compilation_configs.splitting_ops) graph, self.compilation_config.splitting_ops)
from torch._dynamo.utils import lazy_format_graph_code from torch._dynamo.utils import lazy_format_graph_code
logger.debug("%s", lazy_format_graph_code("before split", self.graph)) logger.debug("%s", lazy_format_graph_code("before split", self.graph))
@@ -298,13 +301,13 @@ class VllmBackend:
# propagate the split graph to the piecewise backend, # propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes # compile submodules with symbolic shapes
PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile, PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
self.compilation_configs, self.vllm_config,
self.graph_pool).run(*example_inputs) self.graph_pool).run(*example_inputs)
self._called = True self._called = True
if not self.compilation_configs.use_cudagraph or \ if not self.compilation_config.use_cudagraph or \
not self.compilation_configs.cudagraph_copy_inputs: not self.compilation_config.cudagraph_copy_inputs:
return self.split_gm return self.split_gm
# if we need to copy input buffers for cudagraph # if we need to copy input buffers for cudagraph
@@ -364,10 +367,9 @@ class ConcreteSizeEntry:
class PiecewiseBackend: class PiecewiseBackend:
def __init__(self, graph: fx.GraphModule, def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
compilation_configs: CompilationConfig, graph_pool: Any, graph_pool: Any, piecewise_compile_index: int,
piecewise_compile_index: int, total_piecewise_compiles: int, total_piecewise_compiles: int, sym_shape_indices: List[int],
sym_shape_indices: List[int],
compiled_graph_for_general_shape: Callable): compiled_graph_for_general_shape: Callable):
""" """
The backend for piecewise compilation. The backend for piecewise compilation.
@@ -375,7 +377,7 @@ class PiecewiseBackend:
We will compile `self.graph` once for the general shape, We will compile `self.graph` once for the general shape,
and then compile for different shapes specified in and then compile for different shapes specified in
`compilation_configs.compile_sizes`. `compilation_config.compile_sizes`.
Independently, we will capture cudagraph for different shapes. Independently, we will capture cudagraph for different shapes.
@@ -383,7 +385,8 @@ class PiecewiseBackend:
compile it first, and then capture cudagraph. compile it first, and then capture cudagraph.
""" """
self.graph = graph self.graph = graph
self.compilation_configs = compilation_configs self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.graph_pool = graph_pool self.graph_pool = graph_pool
self.piecewise_compile_index = piecewise_compile_index self.piecewise_compile_index = piecewise_compile_index
self.total_piecewise_compiles = total_piecewise_compiles self.total_piecewise_compiles = total_piecewise_compiles
@@ -393,10 +396,10 @@ class PiecewiseBackend:
piecewise_compile_index == total_piecewise_compiles - 1) piecewise_compile_index == total_piecewise_compiles - 1)
self.compile_sizes: Set[int] = set( self.compile_sizes: Set[int] = set(
self.compilation_configs.compile_sizes) self.compilation_config.compile_sizes)
self.capture_sizes: Set[int] = set( self.capture_sizes: Set[int] = set(
self.compilation_configs.capture_sizes self.compilation_config.capture_sizes
) if self.compilation_configs.use_cudagraph else set() ) if self.compilation_config.use_cudagraph else set()
self.first_run_finished = False self.first_run_finished = False
@@ -423,7 +426,7 @@ class PiecewiseBackend:
self.first_run_finished = True self.first_run_finished = True
# no specific sizes to compile # no specific sizes to compile
if self.is_last_graph and not self.to_be_compiled_sizes: if self.is_last_graph and not self.to_be_compiled_sizes:
end_monitoring_torch_compile(self.compilation_configs) end_monitoring_torch_compile(self.vllm_config)
return self.compiled_graph_for_general_shape(*args) return self.compiled_graph_for_general_shape(*args)
runtime_shape = args[self.sym_shape_indices[0]] runtime_shape = args[self.sym_shape_indices[0]]
@@ -443,28 +446,28 @@ class PiecewiseBackend:
entry.runnable = wrap_inductor( entry.runnable = wrap_inductor(
self.graph, self.graph,
args, args,
self.compilation_configs.inductor_compile_config, self.compilation_config.inductor_compile_config,
self.compilation_configs, self.compilation_config,
graph_index=self.piecewise_compile_index, graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles, num_graphs=self.total_piecewise_compiles,
runtime_shape=runtime_shape, runtime_shape=runtime_shape,
use_inductor=self.compilation_configs.use_inductor) use_inductor=self.compilation_config.use_inductor)
# finished compilations for all required shapes # finished compilations for all required shapes
if self.is_last_graph and not self.to_be_compiled_sizes: if self.is_last_graph and not self.to_be_compiled_sizes:
end_monitoring_torch_compile(self.compilation_configs) end_monitoring_torch_compile(self.vllm_config)
if not entry.use_cudagraph: if not entry.use_cudagraph:
return entry.runnable(*args) return entry.runnable(*args)
if entry.cudagraph is None: if entry.cudagraph is None:
if entry.num_finished_warmup < self.compilation_configs.cudagraph_num_of_warmups: # noqa if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa
entry.num_finished_warmup += 1 entry.num_finished_warmup += 1
if self.is_first_graph: if self.is_first_graph:
logger.debug( logger.debug(
"Warming up %s/%s for shape %s", "Warming up %s/%s for shape %s",
entry.num_finished_warmup, entry.num_finished_warmup,
self.compilation_configs.cudagraph_num_of_warmups, self.compilation_config.cudagraph_num_of_warmups,
runtime_shape) runtime_shape)
return entry.runnable(*args) return entry.runnable(*args)

View File

@@ -185,7 +185,7 @@ def _support_torch_compile(
"Unsupported dynamic dimensions" "Unsupported dynamic dimensions"
f" {dims} for argument {k} with type {type(arg)}.") f" {dims} for argument {k} with type {type(arg)}.")
# here, it is the starting point of the `torch.compile` process # here, it is the starting point of the `torch.compile` process
start_monitoring_torch_compile(self.vllm_config.compilation_config) start_monitoring_torch_compile(self.vllm_config)
# if we don't use custom dispatcher, we can directly call the # if we don't use custom dispatcher, we can directly call the
# compiled function and let torch.compile handle the dispatching, # compiled function and let torch.compile handle the dispatching,

View File

@@ -1,19 +1,36 @@
import os
import time import time
from vllm.config import CompilationConfig, CompilationLevel from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
context_manager = None
torch_compile_start_time: float = 0.0 torch_compile_start_time: float = 0.0
def start_monitoring_torch_compile(compilation_config: CompilationConfig): def start_monitoring_torch_compile(vllm_config: VllmConfig):
global torch_compile_start_time global torch_compile_start_time
torch_compile_start_time = time.time() torch_compile_start_time = time.time()
compilation_config: CompilationConfig = vllm_config.compilation_config
if compilation_config.level == CompilationLevel.PIECEWISE and \
compilation_config.debug_dump_path:
import depyf
path = os.path.join(compilation_config.debug_dump_path,
f"rank_{vllm_config.parallel_config.rank}")
global context_manager
context_manager = depyf.prepare_debug(path)
context_manager.__enter__()
def end_monitoring_torch_compile(compilation_config: CompilationConfig):
def end_monitoring_torch_compile(vllm_config: VllmConfig):
compilation_config: CompilationConfig = vllm_config.compilation_config
if compilation_config.level == CompilationLevel.PIECEWISE: if compilation_config.level == CompilationLevel.PIECEWISE:
logger.info("torch.compile takes %.2f s in total", logger.info("torch.compile takes %.2f s in total",
compilation_config.compilation_time) compilation_config.compilation_time)
global context_manager
if context_manager is not None:
context_manager.__exit__(None, None, None)
context_manager = None

View File

@@ -32,8 +32,8 @@ class TorchCompileWrapperWithCustomDispatcher:
# default compilation settings # default compilation settings
# compiling the forward method # compiling the forward method
backend = get_current_vllm_config( vllm_config = get_current_vllm_config()
).compilation_config.init_backend() backend = vllm_config.compilation_config.init_backend(vllm_config)
compiled_callable = torch.compile( compiled_callable = torch.compile(
self.forward, self.forward,

View File

@@ -2222,6 +2222,7 @@ class CompilationConfig(BaseModel):
- 1: dynamo as is. - 1: dynamo as is.
- 2: dynamo once. - 2: dynamo once.
- 3: piecewise compilation. - 3: piecewise compilation.
- debug_dump_path: the path to dump the debug information.
- backend: the backend for compilation. It needs to be a string. - backend: the backend for compilation. It needs to be a string.
- "" (empty string): use the default backend. - "" (empty string): use the default backend.
- "eager"/"openxla"/...: use the specified backend registered in PyTorch. - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
@@ -2289,6 +2290,7 @@ class CompilationConfig(BaseModel):
certain small batchsizes, where inductor is good at optimizing. certain small batchsizes, where inductor is good at optimizing.
""" # noqa """ # noqa
level: int = 0 level: int = 0
debug_dump_path: str = ""
backend: str = "" backend: str = ""
custom_ops: List[str] = Field(default_factory=list) custom_ops: List[str] = Field(default_factory=list)
splitting_ops: List[str] = Field(default_factory=lambda: [ splitting_ops: List[str] = Field(default_factory=lambda: [
@@ -2394,7 +2396,7 @@ class CompilationConfig(BaseModel):
self.static_forward_context = {} self.static_forward_context = {}
self.compilation_time = 0.0 self.compilation_time = 0.0
def init_backend(self) -> Union[str, Callable]: def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION: if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.") raise ValueError("No compilation level is set.")
@@ -2413,7 +2415,7 @@ class CompilationConfig(BaseModel):
# merge with the config use_inductor # merge with the config use_inductor
assert self.level == CompilationLevel.PIECEWISE assert self.level == CompilationLevel.PIECEWISE
from vllm.compilation.backends import VllmBackend from vllm.compilation.backends import VllmBackend
return VllmBackend(self) return VllmBackend(vllm_config)
def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
"""To complete the initialization of config, """To complete the initialization of config,

View File

@@ -1162,7 +1162,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
if self.vllm_config.compilation_config.level ==\ if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
backend = self.vllm_config.compilation_config.init_backend() backend = self.vllm_config.compilation_config.init_backend(
self.vllm_config)
self.model = torch.compile( self.model = torch.compile(
self.model, self.model,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,