[ez] Add structured torch.compile logs (#33213)
Signed-off-by: angelayi <yiangela7@gmail.com>
tests/compile/test_structured_logging.py (new file, +121 lines)
@@ -0,0 +1,121 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import patch

import pytest
import regex as re
import torch
from torch import nn

import tests.compile.silly_attention  # noqa
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.config.compilation import (
    CompilationConfig,
    CompilationMode,
    CUDAGraphMode,
)
from vllm.config.scheduler import SchedulerConfig
from vllm.forward_context import set_forward_context

MLP_SIZE = 64


@support_torch_compile
class SimpleModel(nn.Module):
    """A simple model with a splitting op for piecewise compilation."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + x
        attn_output = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, attn_output)
        x = attn_output * 2
        return x


class TraceStructuredCapture:
    """Captures trace_structured calls for testing."""

    def __init__(self):
        self.calls: list[dict] = []

    def __call__(self, event_type: str, metadata_fn=None, payload_fn=None, **kwargs):
        """Capture a trace_structured call."""
        metadata = metadata_fn() if metadata_fn else {}
        self.calls.append(
            {
                "event_type": event_type,
                "metadata": metadata,
            }
        )

    def get(self, event_type: str, name_pattern: str) -> list[dict]:
        """Get all calls with the given event type and name matching pattern.

        Args:
            event_type: The event type to filter by (e.g., "artifact", "graph_dump")
            name_pattern: Regex pattern to match against the artifact name
        """
        regex = re.compile(name_pattern)
        return [
            c
            for c in self.calls
            if c["event_type"] == event_type
            and regex.fullmatch(c.get("metadata", {}).get("name", ""))
        ]


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def test_vllm_structured_logging_artifacts(use_fresh_inductor_cache):
    """Test that all expected vLLM artifacts are logged during compilation."""
    torch.set_default_device("cuda")

    capture = TraceStructuredCapture()

    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
            compile_sizes=[8],
            splitting_ops=["silly::attention"],
        ),
        scheduler_config=SchedulerConfig(
            max_num_seqs=8,
            max_model_len=8192,
            is_encoder_decoder=False,
        ),
    )

    # Patch trace_structured to capture calls
    with (
        patch("vllm.compilation.backends.trace_structured", capture),
        patch("vllm.compilation.piecewise_backend.trace_structured", capture),
        set_current_vllm_config(vllm_config),
    ):
        model = SimpleModel(vllm_config=vllm_config, prefix="test")
        with set_forward_context({}, vllm_config=vllm_config):
            model(torch.randn(8, MLP_SIZE))

    config_artifacts = capture.get("artifact", "vllm_compilation_config")
    assert len(config_artifacts) == 1, (
        f"Expected 1 vllm_compilation_config, got {len(config_artifacts)}"
    )
    vllm_piecewise_split_graph = capture.get("graph_dump", "vllm_piecewise_split_graph")
    assert len(vllm_piecewise_split_graph) == 1, (
        "Expected 1 toplevel piecewise split graph, "
        f"got {len(vllm_piecewise_split_graph)}"
    )
    compile_start_artifacts = capture.get("artifact", "vllm_piecewise_compile_start")
    assert len(compile_start_artifacts) == 2, (
        "Expected 2 vllm_piecewise_compile_start "
        "(one for dynamic ranges, one for compile size), "
        f"got {len(compile_start_artifacts)}"
    )
    submod_dumps = capture.get("graph_dump", r"vllm_submod_.*")
    assert len(submod_dumps) == 2, (
        "Expected 2 submods (one before attention, one after attention), "
        f"got {len(submod_dumps)}"
    )
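Usage sketch (not part of this commit): when PyTorch structured tracing is enabled via the TORCH_TRACE environment variable, these vllm_* artifacts are written alongside the stock Dynamo/Inductor events and can be rendered with tlparse. The workload script and paths below are placeholders.

# Illustrative only: run a compiling workload with structured tracing on,
# then render the resulting log with the tlparse CLI tool.
import os
import subprocess

trace_dir = "/tmp/vllm_torch_trace"
env = {**os.environ, "TORCH_TRACE": trace_dir}  # PyTorch writes structured logs here
subprocess.run(["python", "my_vllm_workload.py"], env=env, check=True)  # placeholder script
# From a shell afterwards:  tlparse /tmp/vllm_torch_trace/<log file>
# The vllm_compilation_config and vllm_piecewise_* artifacts appear in the report.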
vllm/compilation/backends.py

@@ -19,6 +19,7 @@ from typing import Any
import torch
import torch.fx as fx
from torch._dispatch.python import enable_python_dispatcher
from torch._logging._internal import trace_structured

import vllm.envs as envs
from vllm.compilation.inductor_pass import pass_context
@@ -529,6 +530,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
                sym_shape_indices,
                self.vllm_backend,
                graph_returns_tuple(submod),
                submod_name=target,
            )

            self.module.__dict__[target] = wrap_with_cudagraph_if_needed(
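Aside (my reading of the code, not stated in the diff): target here is the name of the fx submodule produced by graph splitting, e.g. "submod_0", so forwarding it gives each PiecewiseBackend a stable name for its trace artifacts. A toy, self-contained illustration of where such names come from:

# Toy illustration: splitting an fx graph produces child GraphModules
# named "submod_<i>". This uses only public torch.fx utilities.
import torch
from torch import fx, nn
from torch.fx.passes.split_module import split_module

class Tiny(nn.Module):
    def forward(self, x):
        return (x + 1) * 2

root = Tiny()
gm = fx.symbolic_trace(root)

# Put each call node into its own partition; split_module then creates
# child modules named submod_0, submod_1, ...
counter = {"i": 0}
def split_callback(node):
    i = counter["i"]
    if node.op == "call_function":
        counter["i"] += 1
    return i

split = split_module(gm, root, split_callback)
print([name for name, _ in split.named_children()])  # expected: ['submod_0', 'submod_1']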
@@ -735,12 +737,61 @@ class VllmBackend:
        )
        self.inductor_config[self.pass_key] = self.pass_manager

    def _log_compilation_config(self):
        """Log vLLM compilation config for TORCH_TRACE/tlparse."""
        cc = self.compilation_config
        pass_cfg = cc.pass_config

        # Helper to convert lists to comma-separated strings for tlparse display
        def list_to_str(lst: list | None) -> str:
            if lst is None:
                return ""
            return ", ".join(str(x) for x in lst)

        # Get enabled passes by introspecting dataclass fields
        enabled_passes = [
            f.name
            for f in dataclasses.fields(pass_cfg)
            if isinstance(getattr(pass_cfg, f.name), bool) and getattr(pass_cfg, f.name)
        ]

        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "vllm_compilation_config",
                "encoding": "json",
            },
            payload_fn=lambda: json.dumps(
                {
                    "model": self.vllm_config.model_config.model,
                    "prefix": self.prefix,
                    "mode": str(cc.mode),
                    "backend": cc.backend,
                    "custom_ops": list_to_str(cc.custom_ops),
                    "splitting_ops": list_to_str(cc.splitting_ops),
                    "cudagraph_mode": str(cc.cudagraph_mode),
                    "compile_sizes": list_to_str(cc.compile_sizes),
                    "compile_ranges_split_points": list_to_str(
                        cc.compile_ranges_split_points
                    ),
                    "use_inductor_graph_partition": cc.use_inductor_graph_partition,
                    "inductor_passes": list_to_str(list(cc.inductor_passes.keys())),
                    "enabled_passes": list_to_str(enabled_passes),
                    "dynamic_shapes_type": str(cc.dynamic_shapes_config.type),
                    "dynamic_shapes_evaluate_guards": cc.dynamic_shapes_config.evaluate_guards,  # noqa: E501
                }
            ),
        )

    def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
        from .caching import (
            VllmSerializableFunction,
        )

        vllm_config = self.vllm_config

        self._log_compilation_config()

        # Minimal hashing here with existing utilities, reused below.

        env_factors = envs.compile_factors()
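For illustration, a decoded vllm_compilation_config payload might look like the dict below. Every value is a placeholder (not output from a real run); only the keys are taken from the json.dumps(...) call above.

# Hypothetical decoded "vllm_compilation_config" payload; all values are
# illustrative placeholders, keys mirror the fields logged above.
example_payload = {
    "model": "some-org/some-model",
    "prefix": "",
    "mode": "CompilationMode.VLLM_COMPILE",
    "backend": "inductor",
    "custom_ops": "",
    "splitting_ops": "silly::attention",
    "cudagraph_mode": "CUDAGraphMode.PIECEWISE",
    "compile_sizes": "8",
    "compile_ranges_split_points": "",
    "use_inductor_graph_partition": False,
    "inductor_passes": "",
    "enabled_passes": "enable_fusion",  # hypothetical pass name
    "dynamic_shapes_type": "...",  # string form of the configured enum
    "dynamic_shapes_evaluate_guards": False,
}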
@@ -892,6 +943,13 @@ class VllmBackend:
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        # Log the piecewise split graph for TORCH_TRACE/tlparse
        trace_structured(
            "graph_dump",
            metadata_fn=lambda: {"name": "vllm_piecewise_split_graph"},
            payload_fn=lambda: self.split_gm.print_readable(print_output=False),
        )

        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name
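A design note worth spelling out (my inference; the diff itself does not explain it): trace_structured takes metadata_fn/payload_fn callables rather than values, so potentially expensive work such as json.dumps or print_readable runs only when structured tracing is actually enabled. A minimal sketch of that pattern, with a hypothetical enabled flag standing in for torch's internal check:

# Minimal sketch of the lazy-callback logging pattern, assuming a simple
# boolean guard; the real torch API consults its own logging state instead.
from collections.abc import Callable

def trace_structured_sketch(
    event_type: str,
    metadata_fn: Callable[[], dict] | None = None,
    payload_fn: Callable[[], str] | None = None,
    *,
    enabled: bool = False,  # stand-in for torch's internal "is tracing on" check
) -> None:
    # When tracing is off, return before calling the fns, so no payload is built.
    if not enabled:
        return
    metadata = metadata_fn() if metadata_fn else {}
    payload = payload_fn() if payload_fn else None
    print(event_type, metadata, payload)  # the real API writes to the trace log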
vllm/compilation/piecewise_backend.py

@@ -3,6 +3,7 @@

import dataclasses
import io
import json
import pickle
from collections.abc import Callable
from pickle import Pickler
@@ -11,6 +12,7 @@ from typing import Any
import torch._functorch.config
import torch.fx as fx
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
from torch._logging._internal import trace_structured

from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
@@ -39,6 +41,7 @@ class PiecewiseBackend:
        vllm_backend: VllmBackend,
        returns_tuple: bool,
        compiled_runnables: dict[str, Callable[..., Any]] | None = None,
        submod_name: str = "",
    ):
        """
        The backend for piecewise compilation.
@@ -70,6 +73,7 @@ class PiecewiseBackend:
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend
        self.compiled_runnables = compiled_runnables
        self.submod_name = submod_name

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
@@ -131,6 +135,9 @@ class PiecewiseBackend:
                compile_range=range,
            )

        # Track whether we've logged the graph for this subgraph (only log once)
        self._graph_logged = False

        # get the on_compilation_complete callback from context...
        # PiecewiseBackend is created during the first call,
        # which is when the context is set (see compilation/decorators.py)
@@ -221,6 +228,45 @@ class PiecewiseBackend:
        assert len(fake_example_inputs) == len(args)
        return fake_example_inputs

    def _log_compile_start(self, compile_range: Range):
        """Log compilation event for TORCH_TRACE/tlparse."""
        is_cudagraph_size = (
            self.compile_sizes is not None and compile_range.start in self.compile_sizes
        )
        subgraph_index = self.piecewise_compile_index
        submod_name = self.submod_name
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "vllm_piecewise_compile_start",
                "encoding": "json",
            },
            payload_fn=lambda: json.dumps(
                {
                    "piecewise_index": subgraph_index,
                    "submod_name": submod_name,
                    "total_piecewise_compiles": self.total_piecewise_compiles,
                    "compile_range_start": compile_range.start,
                    "compile_range_end": compile_range.end,
                    "is_single_size": compile_range.is_single_size(),
                    "is_cudagraph_capture_size": is_cudagraph_size,
                }
            ),
        )

        # Log the subgraph graph dump only once per subgraph (not per size)
        # to reduce log file size. The graph code is the same for all sizes.
        if not self._graph_logged:
            self._graph_logged = True
            assert self.graph is not None
            trace_structured(
                "graph_dump",
                metadata_fn=lambda: {
                    "name": f"vllm_{submod_name}",
                },
                payload_fn=lambda: self.graph.print_readable(print_output=False),
            )

    def _maybe_compile_for_range_entry(
        self, range_entry: RangeEntry, args: tuple[Any, ...]
    ) -> Any:
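To connect this to the test above (my reading, with illustrative values): submod_name is the fx submodule name such as "submod_0", so the graph dumps are named vllm_submod_0, vllm_submod_1, ..., which is exactly what the test's vllm_submod_.* pattern matches. A decoded vllm_piecewise_compile_start event for the two-submod test model might look like:

# Hypothetical decoded "vllm_piecewise_compile_start" event; every value is
# illustrative, not captured output, and the range semantics are assumed.
example_event = {
    "piecewise_index": 0,
    "submod_name": "submod_0",
    "total_piecewise_compiles": 2,
    "compile_range_start": 8,
    "compile_range_end": 8,
    "is_single_size": True,
    "is_cudagraph_capture_size": True,
}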
@@ -230,6 +276,8 @@ class PiecewiseBackend:
                self.compiled_runnables[str(range_entry.compile_range)]
            )
        else:
            self._log_compile_start(range_entry.compile_range)

            # args are real arguments
            # fakify for range, real args for concrete size.
            # For concrete size, we clear the shape env in