[2/N][torch.compile] make compilation cfg part of vllm cfg (#10383)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-16 18:02:14 -08:00
committed by GitHub
parent 661a34fd4f
commit 4fd9375028
27 changed files with 359 additions and 283 deletions

View File

@@ -1,10 +1,15 @@
import enum
import random
from typing import NamedTuple, Optional, Tuple, Union
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
import numpy as np
import torch
if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None
class PlatformEnum(enum.Enum):
CUDA = enum.auto()
@@ -129,6 +134,19 @@ class Platform:
np.random.seed(seed)
torch.manual_seed(seed)
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
    """
    Validate and, if needed, adjust ``vllm_config`` for this platform.

    An implementation may raise an exception when the configuration is
    not compatible with the current platform, or it may modify the
    configuration so that it becomes compatible. The config object is
    passed by reference, so changes are made in place and nothing is
    returned.

    This base implementation performs no checks and leaves the
    configuration untouched.
    """
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED

View File

@@ -1,18 +1,16 @@
import os
from typing import TYPE_CHECKING
import torch
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.plugins import set_torch_compile_backend
from .interface import Platform, PlatformEnum
# NOTE(review): this view appears to interleave removed and added diff lines;
# the environment default and assert below look superseded by
# TpuPlatform.check_and_update_config further down -- confirm against the
# actual file before relying on them.
# When the user has not set VLLM_TORCH_COMPILE_LEVEL, default it to
# DYNAMO_ONCE by writing the value into the process environment.
if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE)
# Import-time check: the level read through vllm.envs must be below PIECEWISE.
assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE,\
"TPU does not support Inductor."
# VllmConfig is imported only for static type checkers; at runtime the name
# is bound to None so annotations that mention it still resolve.
if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None
# Runs at import time: passes "openxla" to set_torch_compile_backend
# (presumably selecting the torch.compile backend used on TPU -- confirm).
set_torch_compile_backend("openxla")
@@ -31,3 +29,12 @@ class TpuPlatform(Platform):
@classmethod
def inference_mode(cls):
    """Return the context manager under which inference runs on TPU.

    The returned object is a fresh ``torch.no_grad()`` instance.
    """
    no_grad_ctx = torch.no_grad()
    return no_grad_ctx
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
    """
    Check and update the compilation settings of ``vllm_config`` for TPU.

    When the ``VLLM_TORCH_COMPILE_LEVEL`` environment variable is not set,
    the compilation level is overridden with
    ``CompilationLevel.DYNAMO_ONCE``. The resulting level must then be
    below ``CompilationLevel.PIECEWISE``.

    The config is passed by reference and is modified in place.

    Raises:
        AssertionError: if the compilation level is not below
            ``CompilationLevel.PIECEWISE``.
    """
    # Imported inside the method; at module level vllm.config is only
    # imported under TYPE_CHECKING (presumably to avoid an import cycle).
    from vllm.config import CompilationLevel

    compilation_config = vllm_config.compilation_config
    if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
        compilation_config.level = CompilationLevel.DYNAMO_ONCE

    # Raise explicitly instead of using `assert`, which is removed when
    # Python runs with -O and would let an unsupported level slip through.
    # AssertionError is kept so the exception type seen by callers is
    # unchanged.
    if not compilation_config.level < CompilationLevel.PIECEWISE:
        raise AssertionError("TPU does not support Inductor.")