[torch.compile] use depyf to dump torch.compile internals (#10972)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -33,3 +33,4 @@ six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that need
|
|||||||
setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
|
setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
|
||||||
einops # Required for Qwen2-VL.
|
einops # Required for Qwen2-VL.
|
||||||
compressed-tensors == 0.8.0 # required for compressed-tensors
|
compressed-tensors == 0.8.0 # required for compressed-tensors
|
||||||
|
depyf==0.18.0 # required for profiling and debugging torch.compile
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import torch
|
|||||||
import torch.fx as fx
|
import torch.fx as fx
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import CompilationConfig
|
from vllm.config import CompilationConfig, VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import weak_ref_tensors
|
from vllm.utils import weak_ref_tensors
|
||||||
|
|
||||||
@@ -149,14 +149,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, module: torch.fx.GraphModule,
|
def __init__(self, module: torch.fx.GraphModule,
|
||||||
compile_submod_names: List[str],
|
compile_submod_names: List[str], vllm_config: VllmConfig,
|
||||||
compilation_configs: CompilationConfig, graph_pool):
|
graph_pool):
|
||||||
super().__init__(module)
|
super().__init__(module)
|
||||||
from torch._guards import detect_fake_mode
|
from torch._guards import detect_fake_mode
|
||||||
self.fake_mode = detect_fake_mode()
|
self.fake_mode = detect_fake_mode()
|
||||||
self.compile_submod_names = compile_submod_names
|
self.compile_submod_names = compile_submod_names
|
||||||
self.compilation_configs = compilation_configs
|
self.compilation_config = vllm_config.compilation_config
|
||||||
self.graph_pool = graph_pool
|
self.graph_pool = graph_pool
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
|
||||||
def run(self, *args):
|
def run(self, *args):
|
||||||
fake_args = [
|
fake_args = [
|
||||||
@@ -182,15 +183,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
|||||||
compiled_graph_for_general_shape = wrap_inductor(
|
compiled_graph_for_general_shape = wrap_inductor(
|
||||||
submod,
|
submod,
|
||||||
args,
|
args,
|
||||||
self.compilation_configs.inductor_compile_config,
|
self.compilation_config.inductor_compile_config,
|
||||||
self.compilation_configs,
|
self.compilation_config,
|
||||||
graph_index=index,
|
graph_index=index,
|
||||||
num_graphs=len(self.compile_submod_names),
|
num_graphs=len(self.compile_submod_names),
|
||||||
runtime_shape=None,
|
runtime_shape=None,
|
||||||
use_inductor=self.compilation_configs.use_inductor)
|
use_inductor=self.compilation_config.use_inductor)
|
||||||
|
|
||||||
self.module.__dict__[target] = PiecewiseBackend(
|
self.module.__dict__[target] = PiecewiseBackend(
|
||||||
submod, self.compilation_configs, self.graph_pool, index,
|
submod, self.vllm_config, self.graph_pool, index,
|
||||||
len(self.compile_submod_names), sym_shape_indices,
|
len(self.compile_submod_names), sym_shape_indices,
|
||||||
compiled_graph_for_general_shape)
|
compiled_graph_for_general_shape)
|
||||||
|
|
||||||
@@ -211,7 +212,8 @@ class VllmBackend:
|
|||||||
which handles the post-grad passes.
|
which handles the post-grad passes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
compilation_configs: CompilationConfig
|
vllm_config: VllmConfig
|
||||||
|
compilation_config: CompilationConfig
|
||||||
graph_pool: Any
|
graph_pool: Any
|
||||||
_called: bool = False
|
_called: bool = False
|
||||||
# the graph we compiled
|
# the graph we compiled
|
||||||
@@ -227,7 +229,7 @@ class VllmBackend:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
compilation_configs: CompilationConfig,
|
vllm_config: VllmConfig,
|
||||||
):
|
):
|
||||||
global global_graph_pool
|
global global_graph_pool
|
||||||
if global_graph_pool is None:
|
if global_graph_pool is None:
|
||||||
@@ -244,13 +246,14 @@ class VllmBackend:
|
|||||||
self.sym_tensor_indices = []
|
self.sym_tensor_indices = []
|
||||||
self.input_buffers = []
|
self.input_buffers = []
|
||||||
|
|
||||||
self.compilation_configs = compilation_configs
|
self.vllm_config = vllm_config
|
||||||
|
self.compilation_config = vllm_config.compilation_config
|
||||||
|
|
||||||
# `torch.compile` is JIT compiled, so we don't need to
|
# `torch.compile` is JIT compiled, so we don't need to
|
||||||
# do anything here
|
# do anything here
|
||||||
|
|
||||||
def configure_post_pass(self):
|
def configure_post_pass(self):
|
||||||
config = self.compilation_configs
|
config = self.compilation_config
|
||||||
self.post_grad_pass_manager.configure(config.pass_config)
|
self.post_grad_pass_manager.configure(config.pass_config)
|
||||||
|
|
||||||
# Post-grad custom passes are run using the post_grad_custom_post_pass
|
# Post-grad custom passes are run using the post_grad_custom_post_pass
|
||||||
@@ -271,7 +274,7 @@ class VllmBackend:
|
|||||||
from .monitor import torch_compile_start_time
|
from .monitor import torch_compile_start_time
|
||||||
dynamo_time = time.time() - torch_compile_start_time
|
dynamo_time = time.time() - torch_compile_start_time
|
||||||
logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
|
logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
|
||||||
self.compilation_configs.compilation_time += dynamo_time
|
self.compilation_config.compilation_time += dynamo_time
|
||||||
|
|
||||||
# we control the compilation process, each instance can only be
|
# we control the compilation process, each instance can only be
|
||||||
# called once
|
# called once
|
||||||
@@ -281,7 +284,7 @@ class VllmBackend:
|
|||||||
self.configure_post_pass()
|
self.configure_post_pass()
|
||||||
|
|
||||||
self.split_gm, self.piecewise_graphs = split_graph(
|
self.split_gm, self.piecewise_graphs = split_graph(
|
||||||
graph, self.compilation_configs.splitting_ops)
|
graph, self.compilation_config.splitting_ops)
|
||||||
|
|
||||||
from torch._dynamo.utils import lazy_format_graph_code
|
from torch._dynamo.utils import lazy_format_graph_code
|
||||||
logger.debug("%s", lazy_format_graph_code("before split", self.graph))
|
logger.debug("%s", lazy_format_graph_code("before split", self.graph))
|
||||||
@@ -298,13 +301,13 @@ class VllmBackend:
|
|||||||
# propagate the split graph to the piecewise backend,
|
# propagate the split graph to the piecewise backend,
|
||||||
# compile submodules with symbolic shapes
|
# compile submodules with symbolic shapes
|
||||||
PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
|
PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
|
||||||
self.compilation_configs,
|
self.vllm_config,
|
||||||
self.graph_pool).run(*example_inputs)
|
self.graph_pool).run(*example_inputs)
|
||||||
|
|
||||||
self._called = True
|
self._called = True
|
||||||
|
|
||||||
if not self.compilation_configs.use_cudagraph or \
|
if not self.compilation_config.use_cudagraph or \
|
||||||
not self.compilation_configs.cudagraph_copy_inputs:
|
not self.compilation_config.cudagraph_copy_inputs:
|
||||||
return self.split_gm
|
return self.split_gm
|
||||||
|
|
||||||
# if we need to copy input buffers for cudagraph
|
# if we need to copy input buffers for cudagraph
|
||||||
@@ -364,10 +367,9 @@ class ConcreteSizeEntry:
|
|||||||
|
|
||||||
class PiecewiseBackend:
|
class PiecewiseBackend:
|
||||||
|
|
||||||
def __init__(self, graph: fx.GraphModule,
|
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
|
||||||
compilation_configs: CompilationConfig, graph_pool: Any,
|
graph_pool: Any, piecewise_compile_index: int,
|
||||||
piecewise_compile_index: int, total_piecewise_compiles: int,
|
total_piecewise_compiles: int, sym_shape_indices: List[int],
|
||||||
sym_shape_indices: List[int],
|
|
||||||
compiled_graph_for_general_shape: Callable):
|
compiled_graph_for_general_shape: Callable):
|
||||||
"""
|
"""
|
||||||
The backend for piecewise compilation.
|
The backend for piecewise compilation.
|
||||||
@@ -375,7 +377,7 @@ class PiecewiseBackend:
|
|||||||
|
|
||||||
We will compile `self.graph` once for the general shape,
|
We will compile `self.graph` once for the general shape,
|
||||||
and then compile for different shapes specified in
|
and then compile for different shapes specified in
|
||||||
`compilation_configs.compile_sizes`.
|
`compilation_config.compile_sizes`.
|
||||||
|
|
||||||
Independently, we will capture cudagraph for different shapes.
|
Independently, we will capture cudagraph for different shapes.
|
||||||
|
|
||||||
@@ -383,7 +385,8 @@ class PiecewiseBackend:
|
|||||||
compile it first, and then capture cudagraph.
|
compile it first, and then capture cudagraph.
|
||||||
"""
|
"""
|
||||||
self.graph = graph
|
self.graph = graph
|
||||||
self.compilation_configs = compilation_configs
|
self.vllm_config = vllm_config
|
||||||
|
self.compilation_config = vllm_config.compilation_config
|
||||||
self.graph_pool = graph_pool
|
self.graph_pool = graph_pool
|
||||||
self.piecewise_compile_index = piecewise_compile_index
|
self.piecewise_compile_index = piecewise_compile_index
|
||||||
self.total_piecewise_compiles = total_piecewise_compiles
|
self.total_piecewise_compiles = total_piecewise_compiles
|
||||||
@@ -393,10 +396,10 @@ class PiecewiseBackend:
|
|||||||
piecewise_compile_index == total_piecewise_compiles - 1)
|
piecewise_compile_index == total_piecewise_compiles - 1)
|
||||||
|
|
||||||
self.compile_sizes: Set[int] = set(
|
self.compile_sizes: Set[int] = set(
|
||||||
self.compilation_configs.compile_sizes)
|
self.compilation_config.compile_sizes)
|
||||||
self.capture_sizes: Set[int] = set(
|
self.capture_sizes: Set[int] = set(
|
||||||
self.compilation_configs.capture_sizes
|
self.compilation_config.capture_sizes
|
||||||
) if self.compilation_configs.use_cudagraph else set()
|
) if self.compilation_config.use_cudagraph else set()
|
||||||
|
|
||||||
self.first_run_finished = False
|
self.first_run_finished = False
|
||||||
|
|
||||||
@@ -423,7 +426,7 @@ class PiecewiseBackend:
|
|||||||
self.first_run_finished = True
|
self.first_run_finished = True
|
||||||
# no specific sizes to compile
|
# no specific sizes to compile
|
||||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||||
end_monitoring_torch_compile(self.compilation_configs)
|
end_monitoring_torch_compile(self.vllm_config)
|
||||||
return self.compiled_graph_for_general_shape(*args)
|
return self.compiled_graph_for_general_shape(*args)
|
||||||
|
|
||||||
runtime_shape = args[self.sym_shape_indices[0]]
|
runtime_shape = args[self.sym_shape_indices[0]]
|
||||||
@@ -443,28 +446,28 @@ class PiecewiseBackend:
|
|||||||
entry.runnable = wrap_inductor(
|
entry.runnable = wrap_inductor(
|
||||||
self.graph,
|
self.graph,
|
||||||
args,
|
args,
|
||||||
self.compilation_configs.inductor_compile_config,
|
self.compilation_config.inductor_compile_config,
|
||||||
self.compilation_configs,
|
self.compilation_config,
|
||||||
graph_index=self.piecewise_compile_index,
|
graph_index=self.piecewise_compile_index,
|
||||||
num_graphs=self.total_piecewise_compiles,
|
num_graphs=self.total_piecewise_compiles,
|
||||||
runtime_shape=runtime_shape,
|
runtime_shape=runtime_shape,
|
||||||
use_inductor=self.compilation_configs.use_inductor)
|
use_inductor=self.compilation_config.use_inductor)
|
||||||
|
|
||||||
# finished compilations for all required shapes
|
# finished compilations for all required shapes
|
||||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||||
end_monitoring_torch_compile(self.compilation_configs)
|
end_monitoring_torch_compile(self.vllm_config)
|
||||||
|
|
||||||
if not entry.use_cudagraph:
|
if not entry.use_cudagraph:
|
||||||
return entry.runnable(*args)
|
return entry.runnable(*args)
|
||||||
|
|
||||||
if entry.cudagraph is None:
|
if entry.cudagraph is None:
|
||||||
if entry.num_finished_warmup < self.compilation_configs.cudagraph_num_of_warmups: # noqa
|
if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa
|
||||||
entry.num_finished_warmup += 1
|
entry.num_finished_warmup += 1
|
||||||
if self.is_first_graph:
|
if self.is_first_graph:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Warming up %s/%s for shape %s",
|
"Warming up %s/%s for shape %s",
|
||||||
entry.num_finished_warmup,
|
entry.num_finished_warmup,
|
||||||
self.compilation_configs.cudagraph_num_of_warmups,
|
self.compilation_config.cudagraph_num_of_warmups,
|
||||||
runtime_shape)
|
runtime_shape)
|
||||||
return entry.runnable(*args)
|
return entry.runnable(*args)
|
||||||
|
|
||||||
|
|||||||
@@ -185,7 +185,7 @@ def _support_torch_compile(
|
|||||||
"Unsupported dynamic dimensions"
|
"Unsupported dynamic dimensions"
|
||||||
f" {dims} for argument {k} with type {type(arg)}.")
|
f" {dims} for argument {k} with type {type(arg)}.")
|
||||||
# here, it is the starting point of the `torch.compile` process
|
# here, it is the starting point of the `torch.compile` process
|
||||||
start_monitoring_torch_compile(self.vllm_config.compilation_config)
|
start_monitoring_torch_compile(self.vllm_config)
|
||||||
|
|
||||||
# if we don't use custom dispatcher, we can directly call the
|
# if we don't use custom dispatcher, we can directly call the
|
||||||
# compiled function and let torch.compile handle the dispatching,
|
# compiled function and let torch.compile handle the dispatching,
|
||||||
|
|||||||
@@ -1,19 +1,36 @@
|
|||||||
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from vllm.config import CompilationConfig, CompilationLevel
|
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
context_manager = None
|
||||||
torch_compile_start_time: float = 0.0
|
torch_compile_start_time: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
def start_monitoring_torch_compile(compilation_config: CompilationConfig):
|
def start_monitoring_torch_compile(vllm_config: VllmConfig):
|
||||||
global torch_compile_start_time
|
global torch_compile_start_time
|
||||||
torch_compile_start_time = time.time()
|
torch_compile_start_time = time.time()
|
||||||
|
|
||||||
|
compilation_config: CompilationConfig = vllm_config.compilation_config
|
||||||
|
if compilation_config.level == CompilationLevel.PIECEWISE and \
|
||||||
|
compilation_config.debug_dump_path:
|
||||||
|
import depyf
|
||||||
|
path = os.path.join(compilation_config.debug_dump_path,
|
||||||
|
f"rank_{vllm_config.parallel_config.rank}")
|
||||||
|
global context_manager
|
||||||
|
context_manager = depyf.prepare_debug(path)
|
||||||
|
context_manager.__enter__()
|
||||||
|
|
||||||
def end_monitoring_torch_compile(compilation_config: CompilationConfig):
|
|
||||||
|
def end_monitoring_torch_compile(vllm_config: VllmConfig):
|
||||||
|
compilation_config: CompilationConfig = vllm_config.compilation_config
|
||||||
if compilation_config.level == CompilationLevel.PIECEWISE:
|
if compilation_config.level == CompilationLevel.PIECEWISE:
|
||||||
logger.info("torch.compile takes %.2f s in total",
|
logger.info("torch.compile takes %.2f s in total",
|
||||||
compilation_config.compilation_time)
|
compilation_config.compilation_time)
|
||||||
|
global context_manager
|
||||||
|
if context_manager is not None:
|
||||||
|
context_manager.__exit__(None, None, None)
|
||||||
|
context_manager = None
|
||||||
|
|||||||
@@ -32,8 +32,8 @@ class TorchCompileWrapperWithCustomDispatcher:
|
|||||||
# default compilation settings
|
# default compilation settings
|
||||||
# compiling the forward method
|
# compiling the forward method
|
||||||
|
|
||||||
backend = get_current_vllm_config(
|
vllm_config = get_current_vllm_config()
|
||||||
).compilation_config.init_backend()
|
backend = vllm_config.compilation_config.init_backend(vllm_config)
|
||||||
|
|
||||||
compiled_callable = torch.compile(
|
compiled_callable = torch.compile(
|
||||||
self.forward,
|
self.forward,
|
||||||
|
|||||||
@@ -2222,6 +2222,7 @@ class CompilationConfig(BaseModel):
|
|||||||
- 1: dynamo as is.
|
- 1: dynamo as is.
|
||||||
- 2: dynamo once.
|
- 2: dynamo once.
|
||||||
- 3: piecewise compilation.
|
- 3: piecewise compilation.
|
||||||
|
- debug_dump_path: the path to dump the debug information.
|
||||||
- backend: the backend for compilation. It needs to be a string.
|
- backend: the backend for compilation. It needs to be a string.
|
||||||
- "" (empty string): use the default backend.
|
- "" (empty string): use the default backend.
|
||||||
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
|
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
|
||||||
@@ -2289,6 +2290,7 @@ class CompilationConfig(BaseModel):
|
|||||||
certain small batchsizes, where inductor is good at optimizing.
|
certain small batchsizes, where inductor is good at optimizing.
|
||||||
""" # noqa
|
""" # noqa
|
||||||
level: int = 0
|
level: int = 0
|
||||||
|
debug_dump_path: str = ""
|
||||||
backend: str = ""
|
backend: str = ""
|
||||||
custom_ops: List[str] = Field(default_factory=list)
|
custom_ops: List[str] = Field(default_factory=list)
|
||||||
splitting_ops: List[str] = Field(default_factory=lambda: [
|
splitting_ops: List[str] = Field(default_factory=lambda: [
|
||||||
@@ -2394,7 +2396,7 @@ class CompilationConfig(BaseModel):
|
|||||||
self.static_forward_context = {}
|
self.static_forward_context = {}
|
||||||
self.compilation_time = 0.0
|
self.compilation_time = 0.0
|
||||||
|
|
||||||
def init_backend(self) -> Union[str, Callable]:
|
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
|
||||||
if self.level == CompilationLevel.NO_COMPILATION:
|
if self.level == CompilationLevel.NO_COMPILATION:
|
||||||
raise ValueError("No compilation level is set.")
|
raise ValueError("No compilation level is set.")
|
||||||
|
|
||||||
@@ -2413,7 +2415,7 @@ class CompilationConfig(BaseModel):
|
|||||||
# merge with the config use_inductor
|
# merge with the config use_inductor
|
||||||
assert self.level == CompilationLevel.PIECEWISE
|
assert self.level == CompilationLevel.PIECEWISE
|
||||||
from vllm.compilation.backends import VllmBackend
|
from vllm.compilation.backends import VllmBackend
|
||||||
return VllmBackend(self)
|
return VllmBackend(vllm_config)
|
||||||
|
|
||||||
def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
|
def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
|
||||||
"""To complete the initialization of config,
|
"""To complete the initialization of config,
|
||||||
|
|||||||
@@ -1162,7 +1162,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
|||||||
|
|
||||||
if self.vllm_config.compilation_config.level ==\
|
if self.vllm_config.compilation_config.level ==\
|
||||||
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
||||||
backend = self.vllm_config.compilation_config.init_backend()
|
backend = self.vllm_config.compilation_config.init_backend(
|
||||||
|
self.vllm_config)
|
||||||
self.model = torch.compile(
|
self.model = torch.compile(
|
||||||
self.model,
|
self.model,
|
||||||
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||||
|
|||||||
Reference in New Issue
Block a user