[2/N][torch.compile] make compilation cfg part of vllm cfg (#10383)
Signed-off-by: youkaichao <youkaichao@gmail.com>
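The diff below migrates the custom-op enablement tests from the VLLM_CUSTOM_OPS environment variable to the new config-based flow: the custom_ops list now lives on CompilationConfig, which is carried by VllmConfig and activated through the set_current_vllm_config context manager. A minimal sketch of that pattern, assuming the "+op_name"/"-op_name" toggle syntax implied by the test's env strings:

    # Sketch of the flow this commit's test changes exercise; the
    # "+rms_norm" toggle syntax is inferred from the parametrization.
    from vllm.config import CompilationConfig, VllmConfig
    from vllm.model_executor.layers.layernorm import RMSNorm
    from vllm.plugins import set_current_vllm_config

    # Enablement is read from the current vllm config rather than
    # from the VLLM_CUSTOM_OPS environment variable.
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        custom_ops=["+rms_norm"]))

    with set_current_vllm_config(vllm_config):
        print(RMSNorm(1024).enabled())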
@@ -3,11 +3,13 @@ from typing import List
 
 import pytest
 
+from vllm.config import CompilationConfig, VllmConfig
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import (GeluAndMul,
                                                     ReLUSquaredActivation,
                                                     SiluAndMul)
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.plugins import set_current_vllm_config
 
 
 # Registered subclass for test
@@ -51,42 +53,40 @@ class Relu3(ReLUSquaredActivation):
 ])
 def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
                      default_on: bool):
-    os.environ["VLLM_CUSTOM_OPS"] = env
     os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)
-
-    # Reset default_on (computed once):
-    CustomOp.default_on.cache_clear()
-
-    assert CustomOp.default_on() == default_on
-
-    ops_enabled = [bool(x) for x in ops_enabled]
-
-    assert RMSNorm(1024).enabled() == ops_enabled[0]
-    assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
-
-    assert SiluAndMul().enabled() == ops_enabled[1]
-    assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
-
-    assert GeluAndMul().enabled() == ops_enabled[2]
-    assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
-
-    # If registered, subclasses should follow their own name
-    assert Relu3().enabled() == ops_enabled[3]
-    assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
-
-    # Unregistered subclass
-    class SiluAndMul2(SiluAndMul):
-        pass
-
-    # Subclasses should not require registration
-    assert SiluAndMul2().enabled() == SiluAndMul().enabled()
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        custom_ops=env.split(",")))
+    with set_current_vllm_config(vllm_config):
+        assert CustomOp.default_on() == default_on
+
+        ops_enabled = [bool(x) for x in ops_enabled]
+
+        assert RMSNorm(1024).enabled() == ops_enabled[0]
+        assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
+
+        assert SiluAndMul().enabled() == ops_enabled[1]
+        assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
+
+        assert GeluAndMul().enabled() == ops_enabled[2]
+        assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
+
+        # If registered, subclasses should follow their own name
+        assert Relu3().enabled() == ops_enabled[3]
+        assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
+
+        # Unregistered subclass
+        class SiluAndMul2(SiluAndMul):
+            pass
+
+        # Subclasses should not require registration
+        assert SiluAndMul2().enabled() == SiluAndMul().enabled()
 
 
 @pytest.mark.parametrize(
     "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
 def test_enabled_ops_invalid(env: str):
-    os.environ["VLLM_CUSTOM_OPS"] = env
-    CustomOp.default_on.cache_clear()
-
-    with pytest.raises(AssertionError):
-        RMSNorm(1024).enabled()
+    with pytest.raises(Exception):  # noqa
+        vllm_config = VllmConfig(compilation_config=CompilationConfig(
+            custom_ops=env.split(",")))
+        with set_current_vllm_config(vllm_config):
+            RMSNorm(1024).enabled()
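For context, the Relu3 named in the hunk header is the test file's registered subclass (per the "# Registered subclass for test" comment in the import hunk). A sketch of how such a registration typically looks, with CustomOp.register assumed as the mechanism:

    # Registered under its own name, so Relu3 follows the "relu3"
    # toggle rather than its parent's (CustomOp.register as the
    # registration mechanism is an assumption here).
    from vllm.model_executor.custom_op import CustomOp
    from vllm.model_executor.layers.activation import ReLUSquaredActivation

    @CustomOp.register("relu3")
    class Relu3(ReLUSquaredActivation):
        pass

    # Unregistered subclasses (like SiluAndMul2 in the test) need no
    # entry in CustomOp.op_registry; they inherit their parent's
    # enablement, as the final assertion checks.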
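Each env string in test_enabled_ops_invalid is self-contradictory, which is why the test now wraps config construction, activation, and the enabled() call together in a broad pytest.raises(Exception): the exact validation site and exception type are not pinned down. A short summary of why each case should be rejected, with reasons inferred from the strings themselves rather than quoted from vLLM's validation messages:

    # Reasons are inferred, not quoted from vLLM's error messages.
    invalid_envs = {
        "all,none": "enable-all and disable-all are mutually exclusive",
        "all,+rms_norm,all": '"all" may appear at most once',
        "+rms_norm,-rms_norm": "one op both enabled and disabled",
    }
    for env, why in invalid_envs.items():
        print(f"{env!r} rejected: {why}")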