[ez] Add structured torch.compile logs (#33213)

Signed-off-by: angelayi <yiangela7@gmail.com>
This commit is contained in:
Angela Yi
2026-01-31 05:00:54 -08:00
committed by GitHub
parent f0a1c8453a
commit 608b556507
3 changed files with 227 additions and 0 deletions

View File

@@ -19,6 +19,7 @@ from typing import Any
import torch
import torch.fx as fx
from torch._dispatch.python import enable_python_dispatcher
from torch._logging._internal import trace_structured
import vllm.envs as envs
from vllm.compilation.inductor_pass import pass_context
@@ -529,6 +530,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
sym_shape_indices,
self.vllm_backend,
graph_returns_tuple(submod),
submod_name=target,
)
self.module.__dict__[target] = wrap_with_cudagraph_if_needed(
@@ -735,12 +737,61 @@ class VllmBackend:
)
self.inductor_config[self.pass_key] = self.pass_manager
def _log_compilation_config(self):
    """Emit the vLLM compilation config as a TORCH_TRACE/tlparse artifact.

    Serializes a flat, human-readable summary of the active
    CompilationConfig to JSON and records it under the
    ``vllm_compilation_config`` artifact name via torch's structured
    trace logging. Returns None; the only side effect is the
    trace_structured() call (a no-op when tracing is disabled).
    """
    cfg = self.compilation_config
    passes_cfg = cfg.pass_config

    # tlparse renders flat strings best, so collapse lists to "a, b, c"
    # (None becomes the empty string).
    def _join(items: list | None) -> str:
        if items is None:
            return ""
        return ", ".join(str(item) for item in items)

    # A pass counts as enabled when its dataclass field holds a bool
    # that is truthy — non-bool fields are ignored.
    active_passes = [
        field.name
        for field in dataclasses.fields(passes_cfg)
        if isinstance(value := getattr(passes_cfg, field.name), bool) and value
    ]

    # Build the payload lazily so no serialization work happens unless
    # structured tracing is actually enabled.
    def _payload() -> str:
        return json.dumps(
            {
                "model": self.vllm_config.model_config.model,
                "prefix": self.prefix,
                "mode": str(cfg.mode),
                "backend": cfg.backend,
                "custom_ops": _join(cfg.custom_ops),
                "splitting_ops": _join(cfg.splitting_ops),
                "cudagraph_mode": str(cfg.cudagraph_mode),
                "compile_sizes": _join(cfg.compile_sizes),
                "compile_ranges_split_points": _join(
                    cfg.compile_ranges_split_points
                ),
                "use_inductor_graph_partition": cfg.use_inductor_graph_partition,
                "inductor_passes": _join(list(cfg.inductor_passes.keys())),
                "enabled_passes": _join(active_passes),
                "dynamic_shapes_type": str(cfg.dynamic_shapes_config.type),
                "dynamic_shapes_evaluate_guards": cfg.dynamic_shapes_config.evaluate_guards,  # noqa: E501
            }
        )

    trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "vllm_compilation_config",
            "encoding": "json",
        },
        payload_fn=_payload,
    )
def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
from .caching import (
VllmSerializableFunction,
)
vllm_config = self.vllm_config
self._log_compilation_config()
# Minimal hashing here with existing utilities, reused below.
env_factors = envs.compile_factors()
@@ -892,6 +943,13 @@ class VllmBackend:
lazy_format_graph_code("before split", self.graph)
lazy_format_graph_code("after split", self.split_gm)
# Log the piecewise split graph for TORCH_TRACE/tlparse
trace_structured(
"graph_dump",
metadata_fn=lambda: {"name": "vllm_piecewise_split_graph"},
payload_fn=lambda: self.split_gm.print_readable(print_output=False),
)
compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
submod_names_to_compile = [
item.submod_name

View File

@@ -3,6 +3,7 @@
import dataclasses
import io
import json
import pickle
from collections.abc import Callable
from pickle import Pickler
@@ -11,6 +12,7 @@ from typing import Any
import torch._functorch.config
import torch.fx as fx
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
from torch._logging._internal import trace_structured
from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
@@ -39,6 +41,7 @@ class PiecewiseBackend:
vllm_backend: VllmBackend,
returns_tuple: bool,
compiled_runnables: dict[str, Callable[..., Any]] | None = None,
submod_name: str = "",
):
"""
The backend for piecewise compilation.
@@ -70,6 +73,7 @@ class PiecewiseBackend:
self.total_piecewise_compiles = total_piecewise_compiles
self.vllm_backend = vllm_backend
self.compiled_runnables = compiled_runnables
self.submod_name = submod_name
self.is_first_graph = piecewise_compile_index == 0
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
@@ -131,6 +135,9 @@ class PiecewiseBackend:
compile_range=range,
)
# Track whether we've logged the graph for this subgraph (only log once)
self._graph_logged = False
# get the on_compilation_complete callback from context...
# PiecewiseBackend is created during the first call,
# which is when the context is set (see compilation/decorators.py)
@@ -221,6 +228,45 @@ class PiecewiseBackend:
assert len(fake_example_inputs) == len(args)
return fake_example_inputs
def _log_compile_start(self, compile_range: Range):
    """Record TORCH_TRACE/tlparse events when a piecewise compile begins.

    Emits a JSON artifact describing this (subgraph, compile range)
    pair, and — only the first time it fires for this subgraph — dumps
    the subgraph's readable code. The graph text is identical across
    compile sizes, so dumping it once keeps the trace file small.
    """
    sizes = self.compile_sizes
    # True when the range's start coincides with a configured
    # cudagraph capture size.
    captures_cudagraph = sizes is not None and compile_range.start in sizes
    index = self.piecewise_compile_index
    name = self.submod_name

    # Deferred so json.dumps only runs if structured tracing is active.
    def _event_payload() -> str:
        return json.dumps(
            {
                "piecewise_index": index,
                "submod_name": name,
                "total_piecewise_compiles": self.total_piecewise_compiles,
                "compile_range_start": compile_range.start,
                "compile_range_end": compile_range.end,
                "is_single_size": compile_range.is_single_size(),
                "is_cudagraph_capture_size": captures_cudagraph,
            }
        )

    trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "vllm_piecewise_compile_start",
            "encoding": "json",
        },
        payload_fn=_event_payload,
    )

    # Guard clause: the graph dump happens at most once per subgraph.
    if self._graph_logged:
        return
    self._graph_logged = True
    assert self.graph is not None
    trace_structured(
        "graph_dump",
        metadata_fn=lambda: {
            "name": f"vllm_{name}",
        },
        payload_fn=lambda: self.graph.print_readable(print_output=False),
    )
def _maybe_compile_for_range_entry(
self, range_entry: RangeEntry, args: tuple[Any, ...]
) -> Any:
@@ -230,6 +276,8 @@ class PiecewiseBackend:
self.compiled_runnables[str(range_entry.compile_range)]
)
else:
self._log_compile_start(range_entry.compile_range)
# args are real arguments
# fakify for range, real args for concrete size.
# For concrete size, we clear the shape env in