[torch.compile] integration with compilation control (#9058)

This commit is contained in:
youkaichao
2024-10-10 12:39:36 -07:00
committed by GitHub
parent 78c0b4166c
commit e4d652ea3e
22 changed files with 404 additions and 98 deletions

View File

@@ -3,12 +3,14 @@ import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
from typing import Callable, List
from typing import Callable, List, Optional
import torch
import vllm.envs as envs
from .levels import CompilationLevel
class TorchCompileWrapperWithCustomDispatcher:
"""
@@ -23,7 +25,26 @@ class TorchCompileWrapperWithCustomDispatcher:
`torch.compile` over the forward method.
"""
def __init__(self, compiled_callable: Callable):
def __init__(self, compiled_callable: Optional[Callable] = None):
if compiled_callable is None:
# default compilation settings
# compiling the forward method
# choose the compile backend
# if the user has set the backend, use it
from vllm.plugins import get_torch_compile_backend
backend = get_torch_compile_backend()
if backend is None:
from vllm.compilation.backends import select_default_backend
backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL)
compiled_callable = torch.compile(
self.forward,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=backend)
self.compiled_callable = compiled_callable
self.original_code_object = self.__class__.forward.__code__
self.compiled_codes: List[CodeType] = []
@@ -33,7 +54,7 @@ class TorchCompileWrapperWithCustomDispatcher:
# subclasses can use this to switch between the custom dispatcher
# and the default Dynamo guard mechanism.
self.use_custom_dispatcher: bool = \
envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER
envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.DYNAMO_ONCE
def __call__(self, *args, **kwargs):
"""Implement the dispatch logic here, beyond the torch.compile level.