[2/N][torch.compile] make compilation cfg part of vllm cfg (#10383)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -8,8 +8,7 @@ from typing import Callable, List, Optional
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
|
||||
from .levels import CompilationLevel
|
||||
from vllm.config import CompilationLevel
|
||||
|
||||
|
||||
class TorchCompileWrapperWithCustomDispatcher:
|
||||
@@ -25,7 +24,9 @@ class TorchCompileWrapperWithCustomDispatcher:
|
||||
`torch.compile` over the forward method.
|
||||
"""
|
||||
|
||||
def __init__(self, compiled_callable: Optional[Callable] = None):
|
||||
def __init__(self,
|
||||
compiled_callable: Optional[Callable] = None,
|
||||
compilation_level: int = 0):
|
||||
|
||||
if compiled_callable is None:
|
||||
# default compilation settings
|
||||
@@ -38,7 +39,7 @@ class TorchCompileWrapperWithCustomDispatcher:
|
||||
backend = get_torch_compile_backend()
|
||||
if backend is None:
|
||||
from vllm.compilation.backends import select_default_backend
|
||||
backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL)
|
||||
backend = select_default_backend(compilation_level)
|
||||
|
||||
compiled_callable = torch.compile(
|
||||
self.forward,
|
||||
@@ -54,7 +55,7 @@ class TorchCompileWrapperWithCustomDispatcher:
|
||||
# subclasses can use this to switch between the custom dispatcher
|
||||
# and the default Dynamo guard mechanism.
|
||||
self.use_custom_dispatcher: bool = \
|
||||
envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.DYNAMO_ONCE
|
||||
compilation_level >= CompilationLevel.DYNAMO_ONCE
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Implement the dispatch logic here, beyond the torch.compile level.
|
||||
|
||||
Reference in New Issue
Block a user