[Fix][torch.compile] Enable custom ops by default when Inductor off (#20102)
Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
@@ -28,42 +28,49 @@ class Relu3(ReLUSquaredActivation):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env, torch_level, ops_enabled, default_on",
|
||||
"env, torch_level, use_inductor, ops_enabled, default_on",
|
||||
[
|
||||
# Default values based on compile level
|
||||
("", 0, [True] * 4, True),
|
||||
("", 1, [True] * 4, True),
|
||||
("", 2, [True] * 4, True), # All by default
|
||||
("", 3, [False] * 4, False),
|
||||
("", 4, [False] * 4, False), # None by default
|
||||
# - All by default (no Inductor compilation)
|
||||
("", 0, False, [True] * 4, True),
|
||||
("", 1, True, [True] * 4, True),
|
||||
("", 2, False, [True] * 4, True),
|
||||
# - None by default (with Inductor)
|
||||
("", 3, True, [False] * 4, False),
|
||||
("", 4, True, [False] * 4, False),
|
||||
# - All by default (without Inductor)
|
||||
("", 3, False, [True] * 4, True),
|
||||
("", 4, False, [True] * 4, True),
|
||||
# Explicitly enabling/disabling
|
||||
#
|
||||
# Default: all
|
||||
#
|
||||
# All but SiluAndMul
|
||||
("+rms_norm,-silu_and_mul", 0, [1, 0, 1, 1], True),
|
||||
("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True),
|
||||
# Only ReLU3
|
||||
("none,-rms_norm,+relu3", 0, [0, 0, 0, 1], False),
|
||||
("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False),
|
||||
# All but SiluAndMul
|
||||
("all,-silu_and_mul", 1, [1, 0, 1, 1], True),
|
||||
("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
|
||||
# All but ReLU3 (even if ReLU2 is on)
|
||||
("-relu3,relu2", 1, [1, 1, 1, 0], True),
|
||||
# GeluAndMul and SiluAndMul
|
||||
("none,-relu3,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False),
|
||||
("-relu3,relu2", 3, False, [1, 1, 1, 0], True),
|
||||
# RMSNorm and SiluAndMul
|
||||
("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
|
||||
# All but RMSNorm
|
||||
("-rms_norm", 2, [0, 1, 1, 1], True),
|
||||
("-rms_norm", 3, False, [0, 1, 1, 1], True),
|
||||
#
|
||||
# Default: none
|
||||
#
|
||||
# Only ReLU3
|
||||
("-silu_and_mul,+relu3", 3, [0, 0, 0, 1], False),
|
||||
("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
|
||||
# All but RMSNorm
|
||||
("all,-rms_norm", 4, [0, 1, 1, 1], True),
|
||||
("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
|
||||
])
|
||||
def test_enabled_ops(env: str, torch_level: int, ops_enabled: list[int],
|
||||
default_on: bool):
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=torch_level, custom_ops=env.split(",")))
|
||||
def test_enabled_ops(env: str, torch_level: int, use_inductor: bool,
|
||||
ops_enabled: list[int], default_on: bool):
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(use_inductor=bool(use_inductor),
|
||||
level=torch_level,
|
||||
custom_ops=env.split(",")))
|
||||
with set_current_vllm_config(vllm_config):
|
||||
assert CustomOp.default_on() == default_on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user