[ez] Add structured torch.compile logs (#33213)
Signed-off-by: angelayi <yiangela7@gmail.com>
This commit is contained in:
@@ -19,6 +19,7 @@ from typing import Any
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._logging._internal import trace_structured
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.inductor_pass import pass_context
|
||||
@@ -529,6 +530,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
|
||||
sym_shape_indices,
|
||||
self.vllm_backend,
|
||||
graph_returns_tuple(submod),
|
||||
submod_name=target,
|
||||
)
|
||||
|
||||
self.module.__dict__[target] = wrap_with_cudagraph_if_needed(
|
||||
@@ -735,12 +737,61 @@ class VllmBackend:
|
||||
)
|
||||
self.inductor_config[self.pass_key] = self.pass_manager
|
||||
|
||||
def _log_compilation_config(self):
|
||||
"""Log vLLM compilation config for TORCH_TRACE/tlparse."""
|
||||
cc = self.compilation_config
|
||||
pass_cfg = cc.pass_config
|
||||
|
||||
# Helper to convert lists to comma-separated strings for tlparse display
|
||||
def list_to_str(lst: list | None) -> str:
|
||||
if lst is None:
|
||||
return ""
|
||||
return ", ".join(str(x) for x in lst)
|
||||
|
||||
# Get enabled passes by introspecting dataclass fields
|
||||
enabled_passes = [
|
||||
f.name
|
||||
for f in dataclasses.fields(pass_cfg)
|
||||
if isinstance(getattr(pass_cfg, f.name), bool) and getattr(pass_cfg, f.name)
|
||||
]
|
||||
|
||||
trace_structured(
|
||||
"artifact",
|
||||
metadata_fn=lambda: {
|
||||
"name": "vllm_compilation_config",
|
||||
"encoding": "json",
|
||||
},
|
||||
payload_fn=lambda: json.dumps(
|
||||
{
|
||||
"model": self.vllm_config.model_config.model,
|
||||
"prefix": self.prefix,
|
||||
"mode": str(cc.mode),
|
||||
"backend": cc.backend,
|
||||
"custom_ops": list_to_str(cc.custom_ops),
|
||||
"splitting_ops": list_to_str(cc.splitting_ops),
|
||||
"cudagraph_mode": str(cc.cudagraph_mode),
|
||||
"compile_sizes": list_to_str(cc.compile_sizes),
|
||||
"compile_ranges_split_points": list_to_str(
|
||||
cc.compile_ranges_split_points
|
||||
),
|
||||
"use_inductor_graph_partition": cc.use_inductor_graph_partition,
|
||||
"inductor_passes": list_to_str(list(cc.inductor_passes.keys())),
|
||||
"enabled_passes": list_to_str(enabled_passes),
|
||||
"dynamic_shapes_type": str(cc.dynamic_shapes_config.type),
|
||||
"dynamic_shapes_evaluate_guards": cc.dynamic_shapes_config.evaluate_guards, # noqa: E501
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
|
||||
from .caching import (
|
||||
VllmSerializableFunction,
|
||||
)
|
||||
|
||||
vllm_config = self.vllm_config
|
||||
|
||||
self._log_compilation_config()
|
||||
|
||||
# Minimal hashing here with existing utilities, reused below.
|
||||
|
||||
env_factors = envs.compile_factors()
|
||||
@@ -892,6 +943,13 @@ class VllmBackend:
|
||||
lazy_format_graph_code("before split", self.graph)
|
||||
lazy_format_graph_code("after split", self.split_gm)
|
||||
|
||||
# Log the piecewise split graph for TORCH_TRACE/tlparse
|
||||
trace_structured(
|
||||
"graph_dump",
|
||||
metadata_fn=lambda: {"name": "vllm_piecewise_split_graph"},
|
||||
payload_fn=lambda: self.split_gm.print_readable(print_output=False),
|
||||
)
|
||||
|
||||
compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
|
||||
submod_names_to_compile = [
|
||||
item.submod_name
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import dataclasses
|
||||
import io
|
||||
import json
|
||||
import pickle
|
||||
from collections.abc import Callable
|
||||
from pickle import Pickler
|
||||
@@ -11,6 +12,7 @@ from typing import Any
|
||||
import torch._functorch.config
|
||||
import torch.fx as fx
|
||||
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
|
||||
from torch._logging._internal import trace_structured
|
||||
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
from vllm.compilation.monitor import end_monitoring_torch_compile
|
||||
@@ -39,6 +41,7 @@ class PiecewiseBackend:
|
||||
vllm_backend: VllmBackend,
|
||||
returns_tuple: bool,
|
||||
compiled_runnables: dict[str, Callable[..., Any]] | None = None,
|
||||
submod_name: str = "",
|
||||
):
|
||||
"""
|
||||
The backend for piecewise compilation.
|
||||
@@ -70,6 +73,7 @@ class PiecewiseBackend:
|
||||
self.total_piecewise_compiles = total_piecewise_compiles
|
||||
self.vllm_backend = vllm_backend
|
||||
self.compiled_runnables = compiled_runnables
|
||||
self.submod_name = submod_name
|
||||
|
||||
self.is_first_graph = piecewise_compile_index == 0
|
||||
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
|
||||
@@ -131,6 +135,9 @@ class PiecewiseBackend:
|
||||
compile_range=range,
|
||||
)
|
||||
|
||||
# Track whether we've logged the graph for this subgraph (only log once)
|
||||
self._graph_logged = False
|
||||
|
||||
# get the on_compilation_complete callback from context...
|
||||
# PiecewiseBackend is created during the first call,
|
||||
# which is when the context is set (see compilation/decorators.py)
|
||||
@@ -221,6 +228,45 @@ class PiecewiseBackend:
|
||||
assert len(fake_example_inputs) == len(args)
|
||||
return fake_example_inputs
|
||||
|
||||
def _log_compile_start(self, compile_range: Range):
|
||||
"""Log compilation event for TORCH_TRACE/tlparse."""
|
||||
is_cudagraph_size = (
|
||||
self.compile_sizes is not None and compile_range.start in self.compile_sizes
|
||||
)
|
||||
subgraph_index = self.piecewise_compile_index
|
||||
submod_name = self.submod_name
|
||||
trace_structured(
|
||||
"artifact",
|
||||
metadata_fn=lambda: {
|
||||
"name": "vllm_piecewise_compile_start",
|
||||
"encoding": "json",
|
||||
},
|
||||
payload_fn=lambda: json.dumps(
|
||||
{
|
||||
"piecewise_index": subgraph_index,
|
||||
"submod_name": submod_name,
|
||||
"total_piecewise_compiles": self.total_piecewise_compiles,
|
||||
"compile_range_start": compile_range.start,
|
||||
"compile_range_end": compile_range.end,
|
||||
"is_single_size": compile_range.is_single_size(),
|
||||
"is_cudagraph_capture_size": is_cudagraph_size,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
# Log the subgraph graph dump only once per subgraph (not per size)
|
||||
# to reduce log file size. The graph code is the same for all sizes.
|
||||
if not self._graph_logged:
|
||||
self._graph_logged = True
|
||||
assert self.graph is not None
|
||||
trace_structured(
|
||||
"graph_dump",
|
||||
metadata_fn=lambda: {
|
||||
"name": f"vllm_{submod_name}",
|
||||
},
|
||||
payload_fn=lambda: self.graph.print_readable(print_output=False),
|
||||
)
|
||||
|
||||
def _maybe_compile_for_range_entry(
|
||||
self, range_entry: RangeEntry, args: tuple[Any, ...]
|
||||
) -> Any:
|
||||
@@ -230,6 +276,8 @@ class PiecewiseBackend:
|
||||
self.compiled_runnables[str(range_entry.compile_range)]
|
||||
)
|
||||
else:
|
||||
self._log_compile_start(range_entry.compile_range)
|
||||
|
||||
# args are real arguments
|
||||
# fakify for range, real args for concrete size.
|
||||
# For concrete size, we clear the shape env in
|
||||
|
||||
Reference in New Issue
Block a user