[2/N][torch.compile] make compilation cfg part of vllm cfg (#10383)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-16 18:02:14 -08:00
committed by GitHub
parent 661a34fd4f
commit 4fd9375028
27 changed files with 359 additions and 283 deletions

View File

@@ -3,11 +3,13 @@ from typing import List
import pytest
from vllm.config import CompilationConfig, VllmConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation,
SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.plugins import set_current_vllm_config
# Registered subclass for test
@@ -51,42 +53,40 @@ class Relu3(ReLUSquaredActivation):
])
def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
default_on: bool):
os.environ["VLLM_CUSTOM_OPS"] = env
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)
vllm_config = VllmConfig(compilation_config=CompilationConfig(
custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config):
assert CustomOp.default_on() == default_on
# Reset default_on (computed once):
CustomOp.default_on.cache_clear()
ops_enabled = [bool(x) for x in ops_enabled]
assert CustomOp.default_on() == default_on
assert RMSNorm(1024).enabled() == ops_enabled[0]
assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
ops_enabled = [bool(x) for x in ops_enabled]
assert SiluAndMul().enabled() == ops_enabled[1]
assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
assert RMSNorm(1024).enabled() == ops_enabled[0]
assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
assert GeluAndMul().enabled() == ops_enabled[2]
assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
assert SiluAndMul().enabled() == ops_enabled[1]
assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
# If registered, subclasses should follow their own name
assert Relu3().enabled() == ops_enabled[3]
assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
assert GeluAndMul().enabled() == ops_enabled[2]
assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
# Unregistered subclass
class SiluAndMul2(SiluAndMul):
pass
# If registered, subclasses should follow their own name
assert Relu3().enabled() == ops_enabled[3]
assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
# Unregistered subclass
class SiluAndMul2(SiluAndMul):
pass
# Subclasses should not require registration
assert SiluAndMul2().enabled() == SiluAndMul().enabled()
# Subclasses should not require registration
assert SiluAndMul2().enabled() == SiluAndMul().enabled()
@pytest.mark.parametrize(
"env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
def test_enabled_ops_invalid(env: str):
os.environ["VLLM_CUSTOM_OPS"] = env
CustomOp.default_on.cache_clear()
with pytest.raises(AssertionError):
RMSNorm(1024).enabled()
with pytest.raises(Exception): # noqa
vllm_config = VllmConfig(compilation_config=CompilationConfig(
custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config):
RMSNorm(1024).enabled()