[2/N][torch.compile] make compilation cfg part of vllm cfg (#10383)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-16 18:02:14 -08:00
committed by GitHub
parent 661a34fd4f
commit 4fd9375028
27 changed files with 359 additions and 283 deletions

View File

@@ -1,12 +1,10 @@
from functools import lru_cache
from typing import Dict, Type
import torch.nn as nn
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.plugins import get_current_vllm_config
from vllm.utils import print_warning_once
logger = init_logger(__name__)
@@ -87,6 +85,8 @@ class CustomOp(nn.Module):
@classmethod
def enabled(cls) -> bool:
# if no name, then it was not registered
compilation_config = get_current_vllm_config().compilation_config
custom_ops = compilation_config.custom_ops
if not hasattr(cls, "name"):
print_warning_once(
f"Custom op {cls.__name__} was not registered, "
@@ -94,22 +94,25 @@ class CustomOp(nn.Module):
f"It will be enabled/disabled based on the global settings.")
return CustomOp.default_on()
enabled = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS
disabled = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS
enabled = f"+{cls.name}" in custom_ops
disabled = f"-{cls.name}" in custom_ops
assert not (enabled
and disabled), f"Cannot enable and disable {cls.name}"
return (CustomOp.default_on() or enabled) and not disabled
# On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE
# Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence.
@staticmethod
@lru_cache
def default_on() -> bool:
count_none = envs.VLLM_CUSTOM_OPS.count("none")
count_all = envs.VLLM_CUSTOM_OPS.count("all")
assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE and \
"""
On by default if level < CompilationLevel.PIECEWISE
Specifying 'all' or 'none' in custom_op takes precedence.
"""
from vllm.config import CompilationLevel
compilation_config = get_current_vllm_config().compilation_config
custom_ops = compilation_config.custom_ops
count_none = custom_ops.count("none")
count_all = custom_ops.count("all")
return compilation_config.level < CompilationLevel.PIECEWISE and \
not count_none > 0 or count_all > 0
# Dictionary of all custom ops (classes, indexed by registered name).