[torch.compile] transparent compilation with more logging (#12246)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -524,6 +524,7 @@ class VllmBackend:
|
|||||||
|
|
||||||
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
|
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
|
||||||
|
|
||||||
|
vllm_config = self.vllm_config
|
||||||
if not self.compilation_config.cache_dir:
|
if not self.compilation_config.cache_dir:
|
||||||
# no provided cache dir, generate one based on the known factors
|
# no provided cache dir, generate one based on the known factors
|
||||||
# that affect the compilation. if none of the factors change,
|
# that affect the compilation. if none of the factors change,
|
||||||
@@ -532,7 +533,6 @@ class VllmBackend:
|
|||||||
|
|
||||||
# 1. factors come from the vllm_config (it mainly summarizes how the
|
# 1. factors come from the vllm_config (it mainly summarizes how the
|
||||||
# model is created)
|
# model is created)
|
||||||
vllm_config = self.vllm_config
|
|
||||||
config_hash = vllm_config.compute_hash()
|
config_hash = vllm_config.compute_hash()
|
||||||
|
|
||||||
# 2. factors come from the code files that are traced by Dynamo (
|
# 2. factors come from the code files that are traced by Dynamo (
|
||||||
@@ -556,20 +556,26 @@ class VllmBackend:
|
|||||||
hash_key = hashlib.md5(
|
hash_key = hashlib.md5(
|
||||||
f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
|
f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
|
||||||
cache_dir = os.path.join(
|
cache_dir = os.path.join(
|
||||||
envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
|
envs.VLLM_CACHE_ROOT,
|
||||||
f"rank_{vllm_config.parallel_config.rank}")
|
"torch_compile_cache",
|
||||||
else:
|
hash_key,
|
||||||
cache_dir = self.compilation_config.cache_dir
|
)
|
||||||
|
self.compilation_config.cache_dir = cache_dir
|
||||||
|
|
||||||
|
cache_dir = self.compilation_config.cache_dir
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
local_cache_dir = os.path.join(
|
||||||
|
cache_dir, f"rank_{vllm_config.parallel_config.rank}")
|
||||||
|
self.compilation_config.local_cache_dir = local_cache_dir
|
||||||
|
|
||||||
disabled = envs.VLLM_DISABLE_COMPILE_CACHE
|
disabled = envs.VLLM_DISABLE_COMPILE_CACHE
|
||||||
self.inductor_hash_cache: InductorHashCache = InductorHashCache(
|
self.inductor_hash_cache: InductorHashCache = InductorHashCache(
|
||||||
cache_dir, disabled=disabled)
|
local_cache_dir, disabled=disabled)
|
||||||
if disabled:
|
if disabled:
|
||||||
logger.info("vLLM's torch.compile cache is disabled.")
|
logger.info("vLLM's torch.compile cache is disabled.")
|
||||||
else:
|
else:
|
||||||
logger.info("Using cache directory: %s for vLLM's torch.compile",
|
logger.info("Using cache directory: %s for vLLM's torch.compile",
|
||||||
cache_dir)
|
local_cache_dir)
|
||||||
|
|
||||||
# when dynamo calls the backend, it means the bytecode
|
# when dynamo calls the backend, it means the bytecode
|
||||||
# transform and analysis are done
|
# transform and analysis are done
|
||||||
@@ -609,6 +615,18 @@ class VllmBackend:
|
|||||||
self.vllm_config, self.graph_pool,
|
self.vllm_config, self.graph_pool,
|
||||||
self).run(*example_inputs)
|
self).run(*example_inputs)
|
||||||
|
|
||||||
|
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
|
||||||
|
if not os.path.exists(graph_path):
|
||||||
|
# code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
|
||||||
|
# use `print_readable` because it can include submodules
|
||||||
|
src = "from __future__ import annotations\nimport torch\n" + \
|
||||||
|
self.split_gm.print_readable(print_output=False)
|
||||||
|
src = src.replace("<lambda>", "GraphModule")
|
||||||
|
with open(graph_path, "w") as f:
|
||||||
|
f.write(src)
|
||||||
|
|
||||||
|
logger.debug("Computation graph saved to %s", graph_path)
|
||||||
|
|
||||||
self._called = True
|
self._called = True
|
||||||
|
|
||||||
if not self.compilation_config.use_cudagraph or \
|
if not self.compilation_config.use_cudagraph or \
|
||||||
|
|||||||
@@ -198,6 +198,8 @@ def _support_torch_compile(
|
|||||||
f" {dims} for argument {k} with type {type(arg)}.")
|
f" {dims} for argument {k} with type {type(arg)}.")
|
||||||
# here, it is the starting point of the `torch.compile` process
|
# here, it is the starting point of the `torch.compile` process
|
||||||
start_monitoring_torch_compile(self.vllm_config)
|
start_monitoring_torch_compile(self.vllm_config)
|
||||||
|
logger.debug("Start compiling function %s",
|
||||||
|
self.original_code_object)
|
||||||
|
|
||||||
# if we don't use custom dispatcher, we can directly call the
|
# if we don't use custom dispatcher, we can directly call the
|
||||||
# compiled function and let torch.compile handle the dispatching,
|
# compiled function and let torch.compile handle the dispatching,
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ import torch
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import CompilationLevel, get_current_vllm_config
|
from vllm.config import CompilationLevel, get_current_vllm_config
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TorchCompileWrapperWithCustomDispatcher:
|
class TorchCompileWrapperWithCustomDispatcher:
|
||||||
@@ -82,6 +85,25 @@ class TorchCompileWrapperWithCustomDispatcher:
|
|||||||
return
|
return
|
||||||
|
|
||||||
self.compiled_codes.append(new_code)
|
self.compiled_codes.append(new_code)
|
||||||
|
local_cache_dir = self.vllm_config.compilation_config.local_cache_dir
|
||||||
|
if isinstance(local_cache_dir, str):
|
||||||
|
decompiled_file = os.path.join(local_cache_dir,
|
||||||
|
"transformed_code.py")
|
||||||
|
if not os.path.exists(decompiled_file):
|
||||||
|
try:
|
||||||
|
# usually the decompilation will succeed for most models,
|
||||||
|
# as we guarantee a full-graph compilation in Dynamo.
|
||||||
|
# but there's no 100% guarantee, since decompilation is
|
||||||
|
# not a reversible process.
|
||||||
|
import depyf
|
||||||
|
src = depyf.decompile(new_code)
|
||||||
|
with open(decompiled_file, "w") as f:
|
||||||
|
f.write(src)
|
||||||
|
|
||||||
|
logger.debug("Dynamo transformed code saved to %s",
|
||||||
|
decompiled_file)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
if self.vllm_config.compilation_config.use_cudagraph and \
|
if self.vllm_config.compilation_config.use_cudagraph and \
|
||||||
"update" in new_code.co_names:
|
"update" in new_code.co_names:
|
||||||
|
|||||||
@@ -2785,6 +2785,7 @@ class CompilationConfig(BaseModel):
|
|||||||
compile_sizes: List[int] = PrivateAttr
|
compile_sizes: List[int] = PrivateAttr
|
||||||
capture_sizes: List[int] = PrivateAttr
|
capture_sizes: List[int] = PrivateAttr
|
||||||
max_capture_size: int = PrivateAttr
|
max_capture_size: int = PrivateAttr
|
||||||
|
local_cache_dir: str = PrivateAttr # local cache dir for each rank
|
||||||
# optimization:
|
# optimization:
|
||||||
# Intuitively, bs_to_padded_graph_size should be Dict[int, int].
|
# Intuitively, bs_to_padded_graph_size should be Dict[int, int].
|
||||||
# since we know all keys are in a range [0, max_capture_size],
|
# since we know all keys are in a range [0, max_capture_size],
|
||||||
|
|||||||
Reference in New Issue
Block a user