[Frontend][torch.compile] CompilationConfig Overhaul (#20283): Set up -O infrastructure (#26847)

Signed-off-by: morrison-turnansky <mturnans@redhat.com>
Signed-off-by: adabeyta <aabeyta@redhat.com>
Signed-off-by: Morrison Turnansky <mturnans@redhat.com>
Co-authored-by: adabeyta <aabeyta@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Authored by Morrison Turnansky on 2025-11-27 04:55:58 -05:00; committed by GitHub
parent 00d3310d2d
commit 0838b52e2e
13 changed files with 735 additions and 64 deletions
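The test diff below exercises the new -O infrastructure end to end: OptimizationLevel enumerates the levels O0 through O3, OPTIMIZATION_LEVEL_TO_CONFIG maps each level to its default compilation settings, and VllmConfig fills in any field the user left unset. As a quick orientation, here is a hypothetical usage sketch; it is not part of this diff, but it uses only names the tests below import:

# Hypothetical orientation sketch (not part of this diff); all names here
# are imported by the tests that follow.
from vllm.config import VllmConfig
from vllm.config.vllm import OPTIMIZATION_LEVEL_TO_CONFIG, OptimizationLevel

# Build a config at -O2 and look up that level's default table.
config = VllmConfig(optimization_level=OptimizationLevel.O2)
defaults = OPTIMIZATION_LEVEL_TO_CONFIG[OptimizationLevel.O2]

# Per the tests below, O2 defaults to VLLM_COMPILE with FULL_AND_PIECEWISE
# cudagraphs; unset fields are filled from the table.
print(config.compilation_config.mode)
print(config.compilation_config.cudagraph_mode)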


@@ -8,9 +8,20 @@ from unittest.mock import patch
 import pytest
 
 from vllm.compilation.backends import VllmBackend
-from vllm.config import ModelConfig, PoolerConfig, VllmConfig, update_config
+from vllm.config import (
+    CompilationConfig,
+    ModelConfig,
+    PoolerConfig,
+    VllmConfig,
+    update_config,
+)
+from vllm.config.compilation import CompilationMode, CUDAGraphMode
 from vllm.config.load import LoadConfig
 from vllm.config.utils import get_field
+from vllm.config.vllm import (
+    OPTIMIZATION_LEVEL_TO_CONFIG,
+    OptimizationLevel,
+)
 from vllm.model_executor.layers.pooler import PoolingType
 from vllm.platforms import current_platform
@@ -235,6 +246,43 @@ def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
    assert model_config.pooler_config.pooling_type == pooling_type


@pytest.mark.parametrize(
    ("model_id", "expected_is_moe_model"),
    [
        ("RedHatAI/Qwen3-8B-speculator.eagle3", False),
        ("RedHatAI/Llama-3.1-8B-Instruct-NVFP4", False),
        ("RedHatAI/Llama-3.2-1B-FP8", False),
        ("RedHatAI/Mistral-Small-24B-Instruct-2501-quantized.w8a8", False),
        ("RedHatAI/gpt-oss-20b", True),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", True),
        ("RedHatAI/Llama-4-Scout-17B-16E-Instruct", True),
        ("RedHatAI/Mixtral-8x7B-Instruct-v0.1", True),
    ],
)
def test_moe_model_detection(model_id, expected_is_moe_model):
    model_config = ModelConfig(model_id)
    # is_model_moe() should classify each architecture correctly
    assert model_config.is_model_moe() == expected_is_moe_model


@pytest.mark.parametrize(
    ("model_id", "quantized"),
    [
        ("RedHatAI/Qwen3-8B-speculator.eagle3", False),
        ("RedHatAI/Llama-3.1-8B-Instruct-NVFP4", True),
        ("RedHatAI/Llama-3.2-1B-FP8", True),
        ("RedHatAI/Mistral-Small-24B-Instruct-2501-quantized.w8a8", True),
        ("RedHatAI/gpt-oss-20b", True),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", True),
        ("RedHatAI/Mixtral-8x7B-Instruct-v0.1", False),
    ],
)
def test_is_quantized(model_id, quantized):
    model_config = ModelConfig(model_id)
    # is_quantized() should match each checkpoint's quantization status
    assert model_config.is_quantized() == quantized


@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
)
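These two detectors exist so the optimization-level machinery can condition defaults on model properties. A minimal sketch of that pattern, written as plain functions rather than the lambdas test_vllm_config_callable_defaults uses later in this file (the function names are illustrative, not fixed API):

def enable_if_quantized(cfg) -> bool:
    # True only when a model is attached and its checkpoint is quantized.
    return cfg.model_config is not None and cfg.model_config.is_quantized()


def enable_if_sequential(cfg) -> bool:
    # True only for dense (non-MoE) models.
    return cfg.model_config is not None and not cfg.model_config.is_model_moe()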
@@ -552,3 +600,260 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
    assert os.path.exists(config1.tokenizer) and os.path.isdir(config1.tokenizer)
    assert os.path.exists(config2.model) and os.path.isdir(config2.model)
    assert os.path.exists(config2.tokenizer) and os.path.isdir(config2.tokenizer)


@pytest.mark.parametrize(
    ("backend", "custom_ops", "expected"),
    [
        ("eager", [], True),
        ("eager", ["+fused_layernorm"], True),
        ("eager", ["all", "-fused_layernorm"], False),
        ("inductor", [], False),
        ("inductor", ["none", "+fused_layernorm"], True),
        ("inductor", ["none", "-fused_layernorm"], False),
    ],
)
def test_is_custom_op_enabled(backend: str, custom_ops: list[str], expected: bool):
    """Test that is_custom_op_enabled works correctly."""
    config = VllmConfig(
        compilation_config=CompilationConfig(backend=backend, custom_ops=custom_ops)
    )
    assert config.compilation_config.is_custom_op_enabled("fused_layernorm") is expected
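The parametrization above implies a resolution order for custom_ops: an explicit per-op "+op"/"-op" entry wins, then a blanket "all"/"none", then a backend-dependent default (eager keeps custom ops on, inductor turns them off in favor of compiled kernels). A minimal sketch of that order, illustrative only; the real logic lives in CompilationConfig:

def resolve_custom_op(op: str, backend: str, custom_ops: list[str]) -> bool:
    # Per-op overrides take precedence over blanket settings.
    if f"+{op}" in custom_ops:
        return True
    if f"-{op}" in custom_ops:
        return False
    # Blanket settings come next.
    if "all" in custom_ops:
        return True
    if "none" in custom_ops:
        return False
    # Otherwise fall back to the backend default the table above implies.
    return backend == "eager"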


def test_vllm_config_defaults_are_none():
    """Verify that optimization-level defaults are None when not set by user."""
    # Test all optimization levels to ensure defaults work correctly
    for opt_level in OptimizationLevel:
        # Bypass __init__ so no optimization-level defaults get applied yet.
        config = object.__new__(VllmConfig)
        config.compilation_config = CompilationConfig()
        config.optimization_level = opt_level
        config.model_config = None

        # Use the global optimization level defaults
        default_config = OPTIMIZATION_LEVEL_TO_CONFIG[opt_level]

        # Verify that all pass_config values are None before defaults are applied
        for pass_k in default_config["compilation_config"]["pass_config"]:
            assert getattr(config.compilation_config.pass_config, pass_k) is None

        # Verify that other config values are None before defaults are applied
        for k in default_config["compilation_config"]:
            if k != "pass_config":
                assert getattr(config.compilation_config, k) is None
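For orientation, each OPTIMIZATION_LEVEL_TO_CONFIG entry is a nested dict mirroring the config tree, with a pass_config sub-dict whose values may be plain constants or callables evaluated against the finished VllmConfig. A hypothetical entry in that shape (the mode and cudagraph_mode values match what later assertions expect at O2; the pass_config values are purely illustrative):

# Hypothetical shape of one table entry; not copied from vllm.config.vllm.
example_o2_entry = {
    "compilation_config": {
        "mode": CompilationMode.VLLM_COMPILE,
        "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
        "pass_config": {
            # Illustrative: a plain constant default...
            "enable_noop": True,
            # ...and a callable default that inspects the whole VllmConfig.
            "enable_attn_fusion": lambda cfg: (
                cfg.model_config is not None and cfg.model_config.is_quantized()
            ),
        },
    },
}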


@pytest.mark.parametrize(
    ("model_id", "compilation_config", "optimization_level"),
    [
        (
            None,
            CompilationConfig(backend="eager", custom_ops=["+quant_fp8"]),
            OptimizationLevel.O0,
        ),
        (None, CompilationConfig(), OptimizationLevel.O0),
        (None, CompilationConfig(), OptimizationLevel.O1),
        (None, CompilationConfig(), OptimizationLevel.O2),
        (None, CompilationConfig(), OptimizationLevel.O3),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(backend="inductor", custom_ops=["+quant_fp8"]),
            OptimizationLevel.O2,
        ),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(),
            OptimizationLevel.O0,
        ),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(),
            OptimizationLevel.O1,
        ),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(),
            OptimizationLevel.O2,
        ),
        (
            "RedHatAI/Qwen3-8B-speculator.eagle3",
            CompilationConfig(),
            OptimizationLevel.O3,
        ),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O0),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O1),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O2),
        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O3),
    ],
)
def test_vllm_config_defaults(model_id, compilation_config, optimization_level):
    """Test that optimization-level defaults are correctly applied."""
    model_config = None
    if model_id is not None:
        model_config = ModelConfig(model_id)
        vllm_config = VllmConfig(
            model_config=model_config,
            compilation_config=compilation_config,
            optimization_level=optimization_level,
        )
    else:
        vllm_config = VllmConfig(
            compilation_config=compilation_config,
            optimization_level=optimization_level,
        )

    # Use the global optimization level defaults
    default_config = OPTIMIZATION_LEVEL_TO_CONFIG[optimization_level]

    # Verify pass_config defaults (nested under compilation_config)
    pass_config_dict = default_config["compilation_config"]["pass_config"]
    for pass_k, pass_v in pass_config_dict.items():
        actual = getattr(vllm_config.compilation_config.pass_config, pass_k)
        expected = pass_v(vllm_config) if callable(pass_v) else pass_v
        assert actual == expected, (
            f"pass_config.{pass_k}: expected {expected}, got {actual}"
        )

    # Verify other compilation_config defaults
    compilation_config_dict = default_config["compilation_config"]
    for k, v in compilation_config_dict.items():
        if k != "pass_config":
            actual = getattr(vllm_config.compilation_config, k)
            expected = v(vllm_config) if callable(v) else v
            assert actual == expected, (
                f"compilation_config.{k}: expected {expected}, got {actual}"
            )
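The assertion loop above also documents how defaults are meant to land: every field the user left as None is filled from the level's table, with callables evaluated against the complete VllmConfig. A compact sketch of that application step (illustrative; the real version runs during VllmConfig initialization):

def apply_level_defaults(vllm_config, level_defaults) -> None:
    # Fill each unset (None) compilation field from the level's table.
    cc = vllm_config.compilation_config
    for key, default in level_defaults["compilation_config"].items():
        if key == "pass_config":
            for pk, pv in default.items():
                if getattr(cc.pass_config, pk) is None:
                    value = pv(vllm_config) if callable(pv) else pv
                    setattr(cc.pass_config, pk, value)
        elif getattr(cc, key) is None:
            setattr(cc, key, default(vllm_config) if callable(default) else default)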


def test_vllm_config_callable_defaults():
    """Test that callable defaults work in the config system.

    Verifies that lambdas in default configs can inspect VllmConfig properties
    (e.g., is_quantized, is_model_moe) to conditionally set optimization flags.
    """
    config_no_model = VllmConfig(optimization_level=OptimizationLevel.O2)

    # Callable that checks if model exists
    has_model = lambda cfg: cfg.model_config is not None
    assert has_model(config_no_model) is False

    # Test with quantized model
    quantized_model = ModelConfig("RedHatAI/Llama-3.2-1B-FP8")
    config_quantized = VllmConfig(
        model_config=quantized_model, optimization_level=OptimizationLevel.O2
    )
    enable_if_quantized = lambda cfg: (
        cfg.model_config is not None and cfg.model_config.is_quantized()
    )
    assert enable_if_quantized(config_quantized) is True
    assert enable_if_quantized(config_no_model) is False

    # Test with MoE model
    moe_model = ModelConfig("deepseek-ai/DeepSeek-V2-Lite")
    config_moe = VllmConfig(
        model_config=moe_model, optimization_level=OptimizationLevel.O2
    )
    enable_if_sequential = lambda cfg: (
        cfg.model_config is not None and not cfg.model_config.is_model_moe()
    )
    assert enable_if_sequential(config_moe) is False
    assert enable_if_sequential(config_quantized) is True


def test_vllm_config_explicit_overrides():
    """Test that explicit property overrides work correctly with callable defaults.

    When users explicitly set configuration properties, those values
    take precedence over callable defaults, across different models and
    optimization levels.
    """
    from vllm.config.compilation import PassConfig

    quantized_model = ModelConfig("RedHatAI/Llama-3.2-1B-FP8")
    moe_model = ModelConfig("deepseek-ai/DeepSeek-V2-Lite")
    regular_model = ModelConfig("Qwen/Qwen1.5-7B")

    # Explicit compilation mode override on O0 (where default is NONE)
    compilation_config = CompilationConfig(mode=CompilationMode.VLLM_COMPILE)
    config = VllmConfig(
        optimization_level=OptimizationLevel.O0,
        compilation_config=compilation_config,
    )
    assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE

    # Explicit pass config flags to override defaults
    pass_config = PassConfig(enable_noop=True, enable_attn_fusion=True)
    compilation_config = CompilationConfig(pass_config=pass_config)
    config = VllmConfig(
        optimization_level=OptimizationLevel.O0,
        compilation_config=compilation_config,
    )
    assert config.compilation_config.pass_config.enable_noop is True
    assert config.compilation_config.pass_config.enable_attn_fusion is True

    # Explicit cudagraph mode override on quantized model at O2
    pass_config = PassConfig(enable_async_tp=True)
    compilation_config = CompilationConfig(
        cudagraph_mode=CUDAGraphMode.NONE, pass_config=pass_config
    )
    config = VllmConfig(
        model_config=quantized_model,
        optimization_level=OptimizationLevel.O2,
        compilation_config=compilation_config,
    )
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
    assert config.compilation_config.pass_config.enable_async_tp is True
    # Mode should still use default for O2
    assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE

    # Different optimization levels with same model
    config_o0 = VllmConfig(
        model_config=regular_model, optimization_level=OptimizationLevel.O0
    )
    config_o2 = VllmConfig(
        model_config=regular_model, optimization_level=OptimizationLevel.O2
    )
    assert config_o0.compilation_config.mode == CompilationMode.NONE
    assert config_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert config_o0.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
    assert (
        config_o2.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
    )

    # Same optimization level across different model types
    config_moe_o2 = VllmConfig(
        model_config=moe_model, optimization_level=OptimizationLevel.O2
    )
    config_regular_o2 = VllmConfig(
        model_config=regular_model, optimization_level=OptimizationLevel.O2
    )
    config_quantized_o2 = VllmConfig(
        model_config=quantized_model, optimization_level=OptimizationLevel.O2
    )

    # All should have same base compilation settings at O2
    assert config_moe_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert config_regular_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert config_quantized_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert (
        config_moe_o2.compilation_config.cudagraph_mode
        == CUDAGraphMode.FULL_AND_PIECEWISE
    )
    assert (
        config_regular_o2.compilation_config.cudagraph_mode
        == CUDAGraphMode.FULL_AND_PIECEWISE
    )

    # Override one field but not others
    pass_config = PassConfig(enable_noop=False)
    compilation_config = CompilationConfig(pass_config=pass_config)
    config = VllmConfig(
        model_config=regular_model,
        optimization_level=OptimizationLevel.O2,
        compilation_config=compilation_config,
    )
    # Explicit override should be respected
    assert config.compilation_config.pass_config.enable_noop is False
    # Other fields should still use defaults
    assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
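Taken together, the precedence these tests pin down is: explicit user values first, then the -O level's defaults (callables included), resolved field by field. A condensed restatement using the same public API as the tests above (PassConfig imported from vllm.config.compilation, as in the last test):

# Condensed restatement of the precedence the tests above establish.
config = VllmConfig(
    model_config=ModelConfig("Qwen/Qwen1.5-7B"),
    optimization_level=OptimizationLevel.O2,
    compilation_config=CompilationConfig(
        pass_config=PassConfig(enable_noop=False),  # explicit user override
    ),
)
assert config.compilation_config.pass_config.enable_noop is False  # user value wins
assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE  # O2 default
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE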