[2/N][torch.compile] make compilation cfg part of vllm cfg (#10383)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -1,10 +1,15 @@
|
||||
import enum
|
||||
import random
|
||||
from typing import NamedTuple, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
VllmConfig = None
|
||||
|
||||
|
||||
class PlatformEnum(enum.Enum):
|
||||
CUDA = enum.auto()
|
||||
@@ -129,6 +134,19 @@ class Platform:
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
"""
|
||||
Check and update the configuration for the current platform.
|
||||
|
||||
It can raise an exception if the configuration is not compatible with
|
||||
the current platform, or it can update the configuration to make it
|
||||
compatible with the current platform.
|
||||
|
||||
The config is passed by reference, so it can be modified in place.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class UnspecifiedPlatform(Platform):
|
||||
_enum = PlatformEnum.UNSPECIFIED
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.levels import CompilationLevel
|
||||
from vllm.plugins import set_torch_compile_backend
|
||||
|
||||
from .interface import Platform, PlatformEnum
|
||||
|
||||
if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
|
||||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE)
|
||||
|
||||
assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE,\
|
||||
"TPU does not support Inductor."
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
VllmConfig = None
|
||||
|
||||
set_torch_compile_backend("openxla")
|
||||
|
||||
@@ -31,3 +29,12 @@ class TpuPlatform(Platform):
|
||||
@classmethod
|
||||
def inference_mode(cls):
|
||||
return torch.no_grad()
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
from vllm.config import CompilationLevel
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
|
||||
compilation_config.level = CompilationLevel.DYNAMO_ONCE
|
||||
assert compilation_config.level < CompilationLevel.PIECEWISE,\
|
||||
"TPU does not support Inductor."
|
||||
|
||||
Reference in New Issue
Block a user