[2/N][torch.compile] make compilation cfg part of vllm cfg (#10383)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-16 18:02:14 -08:00
committed by GitHub
parent 661a34fd4f
commit 4fd9375028
27 changed files with 359 additions and 283 deletions

View File

@@ -8,8 +8,7 @@ from typing import Callable, List, Optional
import torch
import vllm.envs as envs
from .levels import CompilationLevel
from vllm.config import CompilationLevel
class TorchCompileWrapperWithCustomDispatcher:
@@ -25,7 +24,9 @@ class TorchCompileWrapperWithCustomDispatcher:
`torch.compile` over the forward method.
"""
def __init__(self, compiled_callable: Optional[Callable] = None):
def __init__(self,
compiled_callable: Optional[Callable] = None,
compilation_level: int = 0):
if compiled_callable is None:
# default compilation settings
@@ -38,7 +39,7 @@ class TorchCompileWrapperWithCustomDispatcher:
backend = get_torch_compile_backend()
if backend is None:
from vllm.compilation.backends import select_default_backend
backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL)
backend = select_default_backend(compilation_level)
compiled_callable = torch.compile(
self.forward,
@@ -54,7 +55,7 @@ class TorchCompileWrapperWithCustomDispatcher:
# subclasses can use this to switch between the custom dispatcher
# and the default Dynamo guard mechanism.
self.use_custom_dispatcher: bool = \
envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.DYNAMO_ONCE
compilation_level >= CompilationLevel.DYNAMO_ONCE
def __call__(self, *args, **kwargs):
"""Implement the dispatch logic here, beyond the torch.compile level.