[compile] Fix torch.compile time discrepancy in logging. (#34912)
Signed-off-by: zhxchen17 <zhxchen17@fb.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -249,7 +249,7 @@ class CompilerManager:
|
|||||||
if graph_index == 0:
|
if graph_index == 0:
|
||||||
# before compiling the first graph, record the start time
|
# before compiling the first graph, record the start time
|
||||||
global compilation_start_time
|
global compilation_start_time
|
||||||
compilation_start_time = time.time()
|
compilation_start_time = time.perf_counter()
|
||||||
|
|
||||||
compilation_counter.num_backend_compilations += 1
|
compilation_counter.num_backend_compilations += 1
|
||||||
|
|
||||||
@@ -261,8 +261,7 @@ class CompilerManager:
|
|||||||
if graph_index == num_graphs - 1:
|
if graph_index == num_graphs - 1:
|
||||||
# after loading the last graph for this shape, record the time.
|
# after loading the last graph for this shape, record the time.
|
||||||
# there can be multiple graphs due to piecewise compilation.
|
# there can be multiple graphs due to piecewise compilation.
|
||||||
now = time.time()
|
elapsed = time.perf_counter() - compilation_start_time
|
||||||
elapsed = now - compilation_start_time
|
|
||||||
compilation_config.compilation_time += elapsed
|
compilation_config.compilation_time += elapsed
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
"Directly load the compiled graph(s) for compile range %s "
|
"Directly load the compiled graph(s) for compile range %s "
|
||||||
@@ -362,8 +361,7 @@ class CompilerManager:
|
|||||||
|
|
||||||
# after compiling the last graph, record the end time
|
# after compiling the last graph, record the end time
|
||||||
if graph_index == num_graphs - 1:
|
if graph_index == num_graphs - 1:
|
||||||
now = time.time()
|
elapsed = time.perf_counter() - compilation_start_time
|
||||||
elapsed = now - compilation_start_time
|
|
||||||
compilation_config.compilation_time += elapsed
|
compilation_config.compilation_time += elapsed
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
"Compiling a graph for compile range %s takes %.2f s",
|
"Compiling a graph for compile range %s takes %.2f s",
|
||||||
@@ -974,7 +972,7 @@ class VllmBackend:
|
|||||||
compilation_counter.num_graphs_seen += 1
|
compilation_counter.num_graphs_seen += 1
|
||||||
from .monitor import torch_compile_start_time
|
from .monitor import torch_compile_start_time
|
||||||
|
|
||||||
dynamo_time = time.time() - torch_compile_start_time
|
dynamo_time = time.perf_counter() - torch_compile_start_time
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
"Dynamo bytecode transform time: %.2f s", dynamo_time, scope="local"
|
"Dynamo bytecode transform time: %.2f s", dynamo_time, scope="local"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ torch_compile_start_time: float = 0.0
|
|||||||
|
|
||||||
def start_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
|
def start_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
|
||||||
global torch_compile_start_time
|
global torch_compile_start_time
|
||||||
torch_compile_start_time = time.time()
|
torch_compile_start_time = time.perf_counter()
|
||||||
|
|
||||||
compilation_config: CompilationConfig = vllm_config.compilation_config
|
compilation_config: CompilationConfig = vllm_config.compilation_config
|
||||||
path = vllm_config.compile_debug_dump_path()
|
path = vllm_config.compile_debug_dump_path()
|
||||||
@@ -30,10 +30,11 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
|
|||||||
|
|
||||||
def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
|
def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
|
||||||
compilation_config: CompilationConfig = vllm_config.compilation_config
|
compilation_config: CompilationConfig = vllm_config.compilation_config
|
||||||
|
total_compile_time: float = time.perf_counter() - torch_compile_start_time
|
||||||
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
|
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
"torch.compile takes %.2f s in total",
|
"torch.compile takes %.2f s in total",
|
||||||
compilation_config.compilation_time,
|
total_compile_time,
|
||||||
scope="local",
|
scope="local",
|
||||||
)
|
)
|
||||||
global context_manager
|
global context_manager
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import dataclasses
|
|||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import pickle
|
import pickle
|
||||||
|
import time
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from pickle import Pickler
|
from pickle import Pickler
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -164,7 +165,16 @@ class PiecewiseBackend:
|
|||||||
if self.is_last_graph and not self.to_be_compiled_ranges:
|
if self.is_last_graph and not self.to_be_compiled_ranges:
|
||||||
# no specific sizes to compile
|
# no specific sizes to compile
|
||||||
# save the hash of the inductor graph for the next run
|
# save the hash of the inductor graph for the next run
|
||||||
|
time_before_saving = time.perf_counter()
|
||||||
self.vllm_backend.compiler_manager.save_to_file()
|
self.vllm_backend.compiler_manager.save_to_file()
|
||||||
|
elapsed = time.perf_counter() - time_before_saving
|
||||||
|
if elapsed > 1:
|
||||||
|
logger.info_once(
|
||||||
|
"Saved compiler manager cache in %.2f seconds.",
|
||||||
|
elapsed,
|
||||||
|
scope="local",
|
||||||
|
)
|
||||||
|
|
||||||
end_monitoring_torch_compile(self.vllm_config)
|
end_monitoring_torch_compile(self.vllm_config)
|
||||||
# Call the completion callback (e.g., to save AOT compiled function)
|
# Call the completion callback (e.g., to save AOT compiled function)
|
||||||
if self.on_compilation_complete is not None:
|
if self.on_compilation_complete is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user