[ez] Add structured torch.compile logs (#33213)

Signed-off-by: angelayi <yiangela7@gmail.com>
This commit is contained in:
Angela Yi
2026-01-31 05:00:54 -08:00
committed by GitHub
parent f0a1c8453a
commit 608b556507
3 changed files with 227 additions and 0 deletions

View File

@@ -0,0 +1,121 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import patch
import pytest
import regex as re
import torch
from torch import nn
import tests.compile.silly_attention # noqa
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.config.compilation import (
CompilationConfig,
CompilationMode,
CUDAGraphMode,
)
from vllm.config.scheduler import SchedulerConfig
from vllm.forward_context import set_forward_context
MLP_SIZE = 64
@support_torch_compile
class SimpleModel(nn.Module):
"""A simple model with a splitting op for piecewise compilation."""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + x
attn_output = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, attn_output)
x = attn_output * 2
return x
class TraceStructuredCapture:
"""Captures trace_structured calls for testing."""
def __init__(self):
self.calls: list[dict] = []
def __call__(self, event_type: str, metadata_fn=None, payload_fn=None, **kwargs):
"""Capture a trace_structured call."""
metadata = metadata_fn() if metadata_fn else {}
self.calls.append(
{
"event_type": event_type,
"metadata": metadata,
}
)
def get(self, event_type: str, name_pattern: str) -> list[dict]:
"""Get all calls with the given event type and name matching pattern.
Args:
event_type: The event type to filter by (e.g., "artifact", "graph_dump")
name_pattern: Regex pattern to match against the artifact name
"""
regex = re.compile(name_pattern)
return [
c
for c in self.calls
if c["event_type"] == event_type
and regex.fullmatch(c.get("metadata", {}).get("name", ""))
]
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def test_vllm_structured_logging_artifacts(use_fresh_inductor_cache):
"""Test that all expected vLLM artifacts are logged during compilation."""
torch.set_default_device("cuda")
capture = TraceStructuredCapture()
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
compile_sizes=[8],
splitting_ops=["silly::attention"],
),
scheduler_config=SchedulerConfig(
max_num_seqs=8,
max_model_len=8192,
is_encoder_decoder=False,
),
)
# Patch trace_structured to capture calls
with (
patch("vllm.compilation.backends.trace_structured", capture),
patch("vllm.compilation.piecewise_backend.trace_structured", capture),
set_current_vllm_config(vllm_config),
):
model = SimpleModel(vllm_config=vllm_config, prefix="test")
with set_forward_context({}, vllm_config=vllm_config):
model(torch.randn(8, MLP_SIZE))
config_artifacts = capture.get("artifact", "vllm_compilation_config")
assert len(config_artifacts) == 1, (
f"Expected 1 vllm_compilation_config, got {len(config_artifacts)}"
)
vllm_piecewise_split_graph = capture.get("graph_dump", "vllm_piecewise_split_graph")
assert len(vllm_piecewise_split_graph) == 1, (
"Expected 1 toplevel piecewise split graph, "
f"got {len(vllm_piecewise_split_graph)}"
)
compile_start_artifacts = capture.get("artifact", "vllm_piecewise_compile_start")
assert len(compile_start_artifacts) == 2, (
"Expected 2 vllm_piecewise_compile_start "
"(one for dynamic ranges, one for compile size), "
f"got {len(compile_start_artifacts)}"
)
submod_dumps = capture.get("graph_dump", r"vllm_submod_.*")
assert len(submod_dumps) == 2, (
"Expected 2 submods (one before attention, one after attention), "
f"got {len(submod_dumps)}"
)

View File

@@ -19,6 +19,7 @@ from typing import Any
import torch
import torch.fx as fx
from torch._dispatch.python import enable_python_dispatcher
from torch._logging._internal import trace_structured
import vllm.envs as envs
from vllm.compilation.inductor_pass import pass_context
@@ -529,6 +530,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
sym_shape_indices,
self.vllm_backend,
graph_returns_tuple(submod),
submod_name=target,
)
self.module.__dict__[target] = wrap_with_cudagraph_if_needed(
@@ -735,12 +737,61 @@ class VllmBackend:
)
self.inductor_config[self.pass_key] = self.pass_manager
def _log_compilation_config(self):
"""Log vLLM compilation config for TORCH_TRACE/tlparse."""
cc = self.compilation_config
pass_cfg = cc.pass_config
# Helper to convert lists to comma-separated strings for tlparse display
def list_to_str(lst: list | None) -> str:
if lst is None:
return ""
return ", ".join(str(x) for x in lst)
# Get enabled passes by introspecting dataclass fields
enabled_passes = [
f.name
for f in dataclasses.fields(pass_cfg)
if isinstance(getattr(pass_cfg, f.name), bool) and getattr(pass_cfg, f.name)
]
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "vllm_compilation_config",
"encoding": "json",
},
payload_fn=lambda: json.dumps(
{
"model": self.vllm_config.model_config.model,
"prefix": self.prefix,
"mode": str(cc.mode),
"backend": cc.backend,
"custom_ops": list_to_str(cc.custom_ops),
"splitting_ops": list_to_str(cc.splitting_ops),
"cudagraph_mode": str(cc.cudagraph_mode),
"compile_sizes": list_to_str(cc.compile_sizes),
"compile_ranges_split_points": list_to_str(
cc.compile_ranges_split_points
),
"use_inductor_graph_partition": cc.use_inductor_graph_partition,
"inductor_passes": list_to_str(list(cc.inductor_passes.keys())),
"enabled_passes": list_to_str(enabled_passes),
"dynamic_shapes_type": str(cc.dynamic_shapes_config.type),
"dynamic_shapes_evaluate_guards": cc.dynamic_shapes_config.evaluate_guards, # noqa: E501
}
),
)
def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
from .caching import (
VllmSerializableFunction,
)
vllm_config = self.vllm_config
self._log_compilation_config()
# Minimal hashing here with existing utilities, reused below.
env_factors = envs.compile_factors()
@@ -892,6 +943,13 @@ class VllmBackend:
lazy_format_graph_code("before split", self.graph)
lazy_format_graph_code("after split", self.split_gm)
# Log the piecewise split graph for TORCH_TRACE/tlparse
trace_structured(
"graph_dump",
metadata_fn=lambda: {"name": "vllm_piecewise_split_graph"},
payload_fn=lambda: self.split_gm.print_readable(print_output=False),
)
compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
submod_names_to_compile = [
item.submod_name

View File

@@ -3,6 +3,7 @@
import dataclasses
import io
import json
import pickle
from collections.abc import Callable
from pickle import Pickler
@@ -11,6 +12,7 @@ from typing import Any
import torch._functorch.config
import torch.fx as fx
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
from torch._logging._internal import trace_structured
from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
@@ -39,6 +41,7 @@ class PiecewiseBackend:
vllm_backend: VllmBackend,
returns_tuple: bool,
compiled_runnables: dict[str, Callable[..., Any]] | None = None,
submod_name: str = "",
):
"""
The backend for piecewise compilation.
@@ -70,6 +73,7 @@ class PiecewiseBackend:
self.total_piecewise_compiles = total_piecewise_compiles
self.vllm_backend = vllm_backend
self.compiled_runnables = compiled_runnables
self.submod_name = submod_name
self.is_first_graph = piecewise_compile_index == 0
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
@@ -131,6 +135,9 @@ class PiecewiseBackend:
compile_range=range,
)
# Track whether we've logged the graph for this subgraph (only log once)
self._graph_logged = False
# get the on_compilation_complete callback from context...
# PiecewiseBackend is created during the first call,
# which is when the context is set (see compilation/decorators.py)
@@ -221,6 +228,45 @@ class PiecewiseBackend:
assert len(fake_example_inputs) == len(args)
return fake_example_inputs
def _log_compile_start(self, compile_range: Range):
"""Log compilation event for TORCH_TRACE/tlparse."""
is_cudagraph_size = (
self.compile_sizes is not None and compile_range.start in self.compile_sizes
)
subgraph_index = self.piecewise_compile_index
submod_name = self.submod_name
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "vllm_piecewise_compile_start",
"encoding": "json",
},
payload_fn=lambda: json.dumps(
{
"piecewise_index": subgraph_index,
"submod_name": submod_name,
"total_piecewise_compiles": self.total_piecewise_compiles,
"compile_range_start": compile_range.start,
"compile_range_end": compile_range.end,
"is_single_size": compile_range.is_single_size(),
"is_cudagraph_capture_size": is_cudagraph_size,
}
),
)
# Log the subgraph graph dump only once per subgraph (not per size)
# to reduce log file size. The graph code is the same for all sizes.
if not self._graph_logged:
self._graph_logged = True
assert self.graph is not None
trace_structured(
"graph_dump",
metadata_fn=lambda: {
"name": f"vllm_{submod_name}",
},
payload_fn=lambda: self.graph.print_readable(print_output=False),
)
def _maybe_compile_for_range_entry(
self, range_entry: RangeEntry, args: tuple[Any, ...]
) -> Any:
@@ -230,6 +276,8 @@ class PiecewiseBackend:
self.compiled_runnables[str(range_entry.compile_range)]
)
else:
self._log_compile_start(range_entry.compile_range)
# args are real arguments
# fakify for range, real args for concrete size.
# For concrete size, we clear the shape env in