[compile] Fix torch.compile time discrepancy in logging. (#34912)

Signed-off-by: zhxchen17 <zhxchen17@fb.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Zhengxu Chen
2026-02-20 11:47:14 -05:00
committed by GitHub
parent e4a5d8c653
commit f863994084
3 changed files with 17 additions and 8 deletions

View File

@@ -249,7 +249,7 @@ class CompilerManager:
if graph_index == 0: if graph_index == 0:
# before compiling the first graph, record the start time # before compiling the first graph, record the start time
global compilation_start_time global compilation_start_time
compilation_start_time = time.time() compilation_start_time = time.perf_counter()
compilation_counter.num_backend_compilations += 1 compilation_counter.num_backend_compilations += 1
@@ -261,8 +261,7 @@ class CompilerManager:
if graph_index == num_graphs - 1: if graph_index == num_graphs - 1:
# after loading the last graph for this shape, record the time. # after loading the last graph for this shape, record the time.
# there can be multiple graphs due to piecewise compilation. # there can be multiple graphs due to piecewise compilation.
now = time.time() elapsed = time.perf_counter() - compilation_start_time
elapsed = now - compilation_start_time
compilation_config.compilation_time += elapsed compilation_config.compilation_time += elapsed
logger.info_once( logger.info_once(
"Directly load the compiled graph(s) for compile range %s " "Directly load the compiled graph(s) for compile range %s "
@@ -362,8 +361,7 @@ class CompilerManager:
# after compiling the last graph, record the end time # after compiling the last graph, record the end time
if graph_index == num_graphs - 1: if graph_index == num_graphs - 1:
now = time.time() elapsed = time.perf_counter() - compilation_start_time
elapsed = now - compilation_start_time
compilation_config.compilation_time += elapsed compilation_config.compilation_time += elapsed
logger.info_once( logger.info_once(
"Compiling a graph for compile range %s takes %.2f s", "Compiling a graph for compile range %s takes %.2f s",
@@ -974,7 +972,7 @@ class VllmBackend:
compilation_counter.num_graphs_seen += 1 compilation_counter.num_graphs_seen += 1
from .monitor import torch_compile_start_time from .monitor import torch_compile_start_time
dynamo_time = time.time() - torch_compile_start_time dynamo_time = time.perf_counter() - torch_compile_start_time
logger.info_once( logger.info_once(
"Dynamo bytecode transform time: %.2f s", dynamo_time, scope="local" "Dynamo bytecode transform time: %.2f s", dynamo_time, scope="local"
) )

View File

@@ -14,7 +14,7 @@ torch_compile_start_time: float = 0.0
def start_monitoring_torch_compile(vllm_config: VllmConfig) -> None: def start_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
global torch_compile_start_time global torch_compile_start_time
torch_compile_start_time = time.time() torch_compile_start_time = time.perf_counter()
compilation_config: CompilationConfig = vllm_config.compilation_config compilation_config: CompilationConfig = vllm_config.compilation_config
path = vllm_config.compile_debug_dump_path() path = vllm_config.compile_debug_dump_path()
@@ -30,10 +30,11 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None: def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
compilation_config: CompilationConfig = vllm_config.compilation_config compilation_config: CompilationConfig = vllm_config.compilation_config
total_compile_time: float = time.perf_counter() - torch_compile_start_time
if compilation_config.mode == CompilationMode.VLLM_COMPILE: if compilation_config.mode == CompilationMode.VLLM_COMPILE:
logger.info_once( logger.info_once(
"torch.compile takes %.2f s in total", "torch.compile takes %.2f s in total",
compilation_config.compilation_time, total_compile_time,
scope="local", scope="local",
) )
global context_manager global context_manager

View File

@@ -5,6 +5,7 @@ import dataclasses
import io import io
import json import json
import pickle import pickle
import time
from collections.abc import Callable from collections.abc import Callable
from pickle import Pickler from pickle import Pickler
from typing import Any from typing import Any
@@ -164,7 +165,16 @@ class PiecewiseBackend:
if self.is_last_graph and not self.to_be_compiled_ranges: if self.is_last_graph and not self.to_be_compiled_ranges:
# no specific sizes to compile # no specific sizes to compile
# save the hash of the inductor graph for the next run # save the hash of the inductor graph for the next run
time_before_saving = time.perf_counter()
self.vllm_backend.compiler_manager.save_to_file() self.vllm_backend.compiler_manager.save_to_file()
elapsed = time.perf_counter() - time_before_saving
if elapsed > 1:
logger.info_once(
"Saved compiler manager cache in %.2f seconds.",
elapsed,
scope="local",
)
end_monitoring_torch_compile(self.vllm_config) end_monitoring_torch_compile(self.vllm_config)
# Call the completion callback (e.g., to save AOT compiled function) # Call the completion callback (e.g., to save AOT compiled function)
if self.on_compilation_complete is not None: if self.on_compilation_complete is not None: