Signed-off-by: arpitkh101 <arpit5khandelwal@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
(cherry picked from commit d7284a2604)
This commit is contained in:
committed by
Kevin H. Luu
parent
a1d627e40f
commit
4fd9d6a85c
@@ -326,7 +326,7 @@ def async_tp_pass_on_test_model(
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.compilation_config = CompilationConfig(
|
||||
pass_config=PassConfig(
|
||||
enable_async_tp=True,
|
||||
fuse_gemm_comms=True,
|
||||
),
|
||||
)
|
||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
@@ -413,7 +413,7 @@ def test_async_tp_pass_correctness(
|
||||
"mode": CompilationMode.VLLM_COMPILE,
|
||||
"compile_sizes": [2, 4, 8],
|
||||
"splitting_ops": [],
|
||||
"pass_config": {"enable_async_tp": async_tp_enabled},
|
||||
"pass_config": {"fuse_gemm_comms": async_tp_enabled},
|
||||
}
|
||||
|
||||
async_tp_args = [
|
||||
|
||||
@@ -295,7 +295,7 @@ def all_reduce_fusion_pass_on_test_model(
|
||||
)
|
||||
)
|
||||
vllm_config.compilation_config.pass_config = PassConfig(
|
||||
enable_fi_allreduce_fusion=True, enable_noop=True
|
||||
fuse_allreduce_rms=True, eliminate_noops=True
|
||||
)
|
||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
vllm_config.parallel_config.rank = local_rank # Setup rank for debug path
|
||||
|
||||
@@ -192,7 +192,7 @@ def test_attn_quant(
|
||||
splitting_ops=splitting_ops,
|
||||
# Common
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
|
||||
pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
|
||||
# Inductor caches custom passes by default as well via uuid
|
||||
inductor_compile_config={"force_disable_caches": True},
|
||||
)
|
||||
@@ -282,9 +282,9 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
|
||||
# Common
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(
|
||||
enable_attn_fusion=True,
|
||||
enable_noop=True,
|
||||
enable_fi_allreduce_fusion=True,
|
||||
fuse_attn_quant=True,
|
||||
eliminate_noops=True,
|
||||
fuse_allreduce_rms=True,
|
||||
),
|
||||
# Inductor caches custom passes by default as well via uuid
|
||||
inductor_compile_config={"force_disable_caches": True},
|
||||
@@ -384,10 +384,10 @@ def test_tp2_attn_quant_async_tp(
|
||||
# Common
|
||||
level=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(
|
||||
enable_attn_fusion=True,
|
||||
enable_noop=True,
|
||||
enable_sequence_parallelism=True,
|
||||
enable_async_tp=True,
|
||||
fuse_attn_quant=True,
|
||||
eliminate_noops=True,
|
||||
enable_sp=True,
|
||||
fuse_gemm_comms=True,
|
||||
),
|
||||
# Inductor caches custom passes by default as well via uuid
|
||||
inductor_compile_config={"force_disable_caches": True},
|
||||
|
||||
@@ -153,7 +153,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
]
|
||||
|
||||
def ops_in_model(self):
|
||||
if self.vllm_config.compilation_config.pass_config.enable_fusion:
|
||||
if self.vllm_config.compilation_config.pass_config.fuse_norm_quant:
|
||||
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
|
||||
elif RMSNorm.enabled():
|
||||
return [
|
||||
@@ -183,7 +183,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
@pytest.mark.parametrize("seq_len", [16])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("enable_fusion", [True, False])
|
||||
@pytest.mark.parametrize("fuse_norm_quant", [True, False])
|
||||
@pytest.mark.parametrize("dynamic", [False, True])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||
def test_sequence_parallelism_pass(
|
||||
@@ -193,7 +193,7 @@ def test_sequence_parallelism_pass(
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
enable_fusion: bool,
|
||||
fuse_norm_quant: bool,
|
||||
dynamic: bool,
|
||||
):
|
||||
num_processes = 2
|
||||
@@ -211,7 +211,7 @@ def test_sequence_parallelism_pass(
|
||||
seq_len,
|
||||
hidden_size,
|
||||
dtype,
|
||||
enable_fusion,
|
||||
fuse_norm_quant,
|
||||
dynamic,
|
||||
),
|
||||
nprocs=nprocs,
|
||||
@@ -229,7 +229,7 @@ def sequence_parallelism_pass_on_test_model(
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
enable_fusion: bool,
|
||||
fuse_norm_quant: bool,
|
||||
dynamic: bool,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
@@ -260,9 +260,9 @@ def sequence_parallelism_pass_on_test_model(
|
||||
cudagraph_mode=CUDAGraphMode.NONE, # avoid piecewise warnings
|
||||
custom_ops=custom_ops_list,
|
||||
pass_config=PassConfig(
|
||||
enable_sequence_parallelism=True,
|
||||
enable_fusion=enable_fusion,
|
||||
enable_noop=True,
|
||||
enable_sp=True,
|
||||
fuse_norm_quant=fuse_norm_quant,
|
||||
eliminate_noops=True,
|
||||
),
|
||||
) # NoOp needed for fusion
|
||||
device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
@@ -297,7 +297,7 @@ def sequence_parallelism_pass_on_test_model(
|
||||
sequence_parallelism_pass,
|
||||
]
|
||||
|
||||
if enable_fusion:
|
||||
if fuse_norm_quant:
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
passes_for_backend.append(fusion_pass)
|
||||
|
||||
|
||||
@@ -122,7 +122,9 @@ def test_full_graph(
|
||||
CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=["+rms_norm"],
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
||||
pass_config=PassConfig(
|
||||
fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
|
||||
),
|
||||
),
|
||||
*model_info,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -10,8 +11,9 @@ from pydantic import ValidationError
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||
from vllm.config.compilation import CompilationMode
|
||||
from vllm.config.compilation import CompilationMode, PassConfig
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.logger import _print_warning_once
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import _is_torch_equal_or_newer
|
||||
|
||||
@@ -191,7 +193,7 @@ def test_splitting_ops_dynamic():
|
||||
config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
||||
pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
)
|
||||
@@ -206,7 +208,7 @@ def test_splitting_ops_dynamic():
|
||||
config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
||||
pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
# work around for accessing all attntion ops
|
||||
@@ -219,7 +221,7 @@ def test_splitting_ops_dynamic():
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_inductor_graph_partition=True,
|
||||
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
||||
pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
)
|
||||
@@ -227,7 +229,7 @@ def test_splitting_ops_dynamic():
|
||||
# With inductor graph partition, attn_fusion and splitting_ops
|
||||
# work together. Default splitting_ops include attention ops.
|
||||
assert config.compilation_config.splitting_ops_contain_attention()
|
||||
# enable_attn_fusion is directly supported under
|
||||
# fuse_attn_quant is directly supported under
|
||||
# use_inductor_graph_partition=True, and cudagraph_mode
|
||||
# is unchanged.
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
|
||||
@@ -301,7 +303,7 @@ def test_should_split():
|
||||
"cudagraph_capture_sizes",
|
||||
"max_cudagraph_capture_size",
|
||||
"tp_size",
|
||||
"enable_sequence_parallelism",
|
||||
"enable_sp",
|
||||
"max_num_batched_tokens",
|
||||
"cudagraph_mode",
|
||||
"expected_max_size",
|
||||
@@ -339,7 +341,7 @@ def test_cudagraph_sizes_post_init(
|
||||
cudagraph_capture_sizes,
|
||||
max_cudagraph_capture_size,
|
||||
tp_size,
|
||||
enable_sequence_parallelism,
|
||||
enable_sp,
|
||||
max_num_batched_tokens,
|
||||
cudagraph_mode,
|
||||
expected_max_size,
|
||||
@@ -355,11 +357,12 @@ def test_cudagraph_sizes_post_init(
|
||||
compilation_config = CompilationConfig(
|
||||
cudagraph_capture_sizes=cudagraph_capture_sizes,
|
||||
max_cudagraph_capture_size=max_cudagraph_capture_size,
|
||||
pass_config={
|
||||
"enable_sequence_parallelism": enable_sequence_parallelism,
|
||||
"enable_fusion": True,
|
||||
"enable_noop": True,
|
||||
},
|
||||
pass_config=PassConfig(
|
||||
enable_sp=enable_sp,
|
||||
fuse_norm_quant=True,
|
||||
fuse_act_quant=True,
|
||||
eliminate_noops=True,
|
||||
),
|
||||
cudagraph_mode=cudagraph_mode,
|
||||
)
|
||||
engine_args = EngineArgs(
|
||||
@@ -375,3 +378,53 @@ def test_cudagraph_sizes_post_init(
|
||||
vllm_config.compilation_config.max_cudagraph_capture_size
|
||||
== expected_max_size
|
||||
)
|
||||
|
||||
|
||||
def test_pass_config_deprecation(caplog_vllm):
|
||||
caplog_vllm.set_level(logging.WARNING)
|
||||
|
||||
# Clear cache to ensure warnings are re-issued
|
||||
_print_warning_once.cache_clear()
|
||||
|
||||
# Test enable_fusion -> fuse_norm_quant, fuse_act_quant
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_fusion=True)
|
||||
assert "enable_fusion is deprecated" in caplog_vllm.text
|
||||
assert config.fuse_norm_quant is True
|
||||
assert config.fuse_act_quant is True
|
||||
assert config.enable_fusion is None
|
||||
|
||||
# Test enable_attn_fusion -> fuse_attn_quant
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_attn_fusion=True)
|
||||
assert "enable_attn_fusion is deprecated" in caplog_vllm.text
|
||||
assert config.fuse_attn_quant is True
|
||||
assert config.enable_attn_fusion is None
|
||||
|
||||
# Test enable_noop -> eliminate_noops
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_noop=True)
|
||||
assert "enable_noop is deprecated" in caplog_vllm.text
|
||||
assert config.eliminate_noops is True
|
||||
assert config.enable_noop is None
|
||||
|
||||
# Test enable_sequence_parallelism -> enable_sp
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_sequence_parallelism=True)
|
||||
assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text
|
||||
assert config.enable_sp is True
|
||||
assert config.enable_sequence_parallelism is None
|
||||
|
||||
# Test enable_async_tp -> fuse_gemm_comms
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_async_tp=True)
|
||||
assert "enable_async_tp is deprecated" in caplog_vllm.text
|
||||
assert config.fuse_gemm_comms is True
|
||||
assert config.enable_async_tp is None
|
||||
|
||||
# Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_fi_allreduce_fusion=True)
|
||||
assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text
|
||||
assert config.fuse_allreduce_rms is True
|
||||
assert config.enable_fi_allreduce_fusion is None
|
||||
|
||||
@@ -223,7 +223,11 @@ def test_fix_functionalization(
|
||||
model_config=ModelConfig(dtype=dtype),
|
||||
compilation_config=CompilationConfig(
|
||||
custom_ops=["all"],
|
||||
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True),
|
||||
pass_config=PassConfig(
|
||||
fuse_norm_quant=do_fusion,
|
||||
fuse_act_quant=do_fusion,
|
||||
eliminate_noops=True,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -159,7 +159,9 @@ def test_fusion_rmsnorm_quant(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=custom_ops,
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
||||
pass_config=PassConfig(
|
||||
fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
|
||||
),
|
||||
),
|
||||
)
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
|
||||
@@ -373,7 +373,7 @@ def test_attention_quant_pattern(
|
||||
|
||||
# Run model with attn fusion enabled
|
||||
vllm_config.compilation_config.pass_config = PassConfig(
|
||||
enable_attn_fusion=True, enable_noop=True
|
||||
fuse_attn_quant=True, eliminate_noops=True
|
||||
)
|
||||
with (
|
||||
set_current_vllm_config(vllm_config),
|
||||
|
||||
@@ -51,7 +51,7 @@ def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size):
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(enable_noop=True),
|
||||
pass_config=PassConfig(eliminate_noops=True),
|
||||
)
|
||||
)
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
@@ -99,7 +99,7 @@ def test_non_noop_slice_preserved():
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(enable_noop=True),
|
||||
pass_config=PassConfig(eliminate_noops=True),
|
||||
)
|
||||
)
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
|
||||
@@ -64,8 +64,11 @@ def test_pass_manager_uuid(callable):
|
||||
|
||||
# UUID should be different due to config change
|
||||
config2 = copy.deepcopy(config)
|
||||
config2.compilation_config.pass_config.enable_fusion = (
|
||||
not config2.compilation_config.pass_config.enable_fusion
|
||||
config2.compilation_config.pass_config.fuse_norm_quant = (
|
||||
not config2.compilation_config.pass_config.fuse_norm_quant
|
||||
)
|
||||
config2.compilation_config.pass_config.fuse_act_quant = (
|
||||
not config2.compilation_config.pass_config.fuse_act_quant
|
||||
)
|
||||
pass_manager3 = PostGradPassManager()
|
||||
pass_manager3.configure(config2)
|
||||
|
||||
@@ -140,7 +140,7 @@ def test_qk_norm_rope_fusion(
|
||||
custom_ops=custom_ops,
|
||||
pass_config=PassConfig(
|
||||
enable_qk_norm_rope_fusion=True,
|
||||
enable_noop=True,
|
||||
eliminate_noops=True,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -168,7 +168,7 @@ def test_fusion_silu_and_mul_quant(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=custom_ops,
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
||||
pass_config=PassConfig(fuse_act_quant=True, eliminate_noops=True),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -32,7 +32,8 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
||||
class ParallelSetup(NamedTuple):
|
||||
tp_size: int
|
||||
pp_size: int
|
||||
enable_fusion: bool
|
||||
fuse_norm_quant: bool
|
||||
fuse_act_quant: bool
|
||||
eager_mode: bool
|
||||
chunked_prefill: bool
|
||||
|
||||
@@ -66,7 +67,8 @@ class SPTestSettings:
|
||||
ParallelSetup(
|
||||
tp_size=tp_base,
|
||||
pp_size=pp_multiplier * pp_base,
|
||||
enable_fusion=False,
|
||||
fuse_norm_quant=False,
|
||||
fuse_act_quant=False,
|
||||
eager_mode=eager_mode_val,
|
||||
chunked_prefill=chunked_prefill_val,
|
||||
)
|
||||
@@ -97,7 +99,8 @@ class SPTestSettings:
|
||||
ParallelSetup(
|
||||
tp_size=tp_base,
|
||||
pp_size=pp_multiplier * pp_base,
|
||||
enable_fusion=False,
|
||||
fuse_norm_quant=False,
|
||||
fuse_act_quant=False,
|
||||
eager_mode=eager_mode_val,
|
||||
chunked_prefill=chunked_prefill_val,
|
||||
)
|
||||
@@ -126,7 +129,8 @@ class SPTestSettings:
|
||||
ParallelSetup(
|
||||
tp_size=tp_base,
|
||||
pp_size=pp_base,
|
||||
enable_fusion=fusion_val,
|
||||
fuse_norm_quant=fusion_val,
|
||||
fuse_act_quant=fusion_val,
|
||||
eager_mode=True,
|
||||
chunked_prefill=False,
|
||||
)
|
||||
@@ -162,7 +166,7 @@ def _compare_sp(
|
||||
test_options: SPTestOptions,
|
||||
num_gpus_available: int,
|
||||
use_inductor_graph_partition: bool,
|
||||
enable_async_tp: bool,
|
||||
fuse_gemm_comms: bool,
|
||||
*,
|
||||
method: Literal["generate", "encode"],
|
||||
is_multimodal: bool,
|
||||
@@ -170,7 +174,8 @@ def _compare_sp(
|
||||
(
|
||||
tp_size,
|
||||
pp_size,
|
||||
enable_fusion,
|
||||
fuse_norm_quant,
|
||||
fuse_act_quant,
|
||||
eager_mode,
|
||||
chunked_prefill,
|
||||
) = parallel_setup
|
||||
@@ -248,10 +253,11 @@ def _compare_sp(
|
||||
"mode": CompilationMode.VLLM_COMPILE,
|
||||
"compile_sizes": [4, 8],
|
||||
"pass_config": {
|
||||
"enable_sequence_parallelism": True,
|
||||
"enable_async_tp": enable_async_tp,
|
||||
"enable_fusion": enable_fusion,
|
||||
"enable_noop": True,
|
||||
"enable_sp": True,
|
||||
"fuse_gemm_comms": fuse_gemm_comms,
|
||||
"fuse_norm_quant": fuse_norm_quant,
|
||||
"fuse_act_quant": fuse_act_quant,
|
||||
"eliminate_noops": True,
|
||||
},
|
||||
"use_inductor_graph_partition": use_inductor_graph_partition,
|
||||
}
|
||||
@@ -309,7 +315,7 @@ SP_TEST_MODELS = [
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
|
||||
@pytest.mark.parametrize("enable_async_tp", [False]) # TODO: enable async TP
|
||||
@pytest.mark.parametrize("fuse_gemm_comms", [False]) # TODO: enable async TP
|
||||
@create_new_process_for_each_test()
|
||||
def test_tp_sp_generation(
|
||||
model_id: str,
|
||||
@@ -319,7 +325,7 @@ def test_tp_sp_generation(
|
||||
test_options: SPTestOptions,
|
||||
num_gpus_available,
|
||||
use_inductor_graph_partition: bool,
|
||||
enable_async_tp: bool,
|
||||
fuse_gemm_comms: bool,
|
||||
):
|
||||
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||
@@ -328,7 +334,7 @@ def test_tp_sp_generation(
|
||||
if (
|
||||
"fp8" in model_id.lower()
|
||||
and current_platform.get_device_capability() < (9, 0)
|
||||
and (not enable_async_tp)
|
||||
and (not fuse_gemm_comms)
|
||||
):
|
||||
pytest.skip("FP8 reduction support begins with sm90 capable devices.")
|
||||
|
||||
@@ -340,7 +346,7 @@ def test_tp_sp_generation(
|
||||
test_options,
|
||||
num_gpus_available,
|
||||
use_inductor_graph_partition,
|
||||
enable_async_tp=enable_async_tp,
|
||||
fuse_gemm_comms=fuse_gemm_comms,
|
||||
method="generate",
|
||||
is_multimodal=False,
|
||||
)
|
||||
|
||||
@@ -1023,17 +1023,17 @@ def test_vllm_config_explicit_overrides():
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
|
||||
|
||||
# Explicit pass config flags to override defaults
|
||||
pass_config = PassConfig(enable_noop=True, enable_attn_fusion=True)
|
||||
pass_config = PassConfig(eliminate_noops=True, fuse_attn_quant=True)
|
||||
compilation_config = CompilationConfig(pass_config=pass_config)
|
||||
config = VllmConfig(
|
||||
optimization_level=OptimizationLevel.O0,
|
||||
compilation_config=compilation_config,
|
||||
)
|
||||
assert config.compilation_config.pass_config.enable_noop is True
|
||||
assert config.compilation_config.pass_config.enable_attn_fusion is True
|
||||
assert config.compilation_config.pass_config.eliminate_noops is True
|
||||
assert config.compilation_config.pass_config.fuse_attn_quant is True
|
||||
|
||||
# Explicit cudagraph mode override on quantized model at O2
|
||||
pass_config = PassConfig(enable_async_tp=True)
|
||||
pass_config = PassConfig(fuse_gemm_comms=True)
|
||||
compilation_config = CompilationConfig(
|
||||
cudagraph_mode=CUDAGraphMode.NONE, pass_config=pass_config
|
||||
)
|
||||
@@ -1043,7 +1043,7 @@ def test_vllm_config_explicit_overrides():
|
||||
compilation_config=compilation_config,
|
||||
)
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
|
||||
assert config.compilation_config.pass_config.enable_async_tp is True
|
||||
assert config.compilation_config.pass_config.fuse_gemm_comms is True
|
||||
# Mode should still use default for O2
|
||||
assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
|
||||
|
||||
@@ -1093,7 +1093,7 @@ def test_vllm_config_explicit_overrides():
|
||||
compilation_config=compilation_config,
|
||||
)
|
||||
# Explicit override should be respected
|
||||
assert config.compilation_config.pass_config.enable_noop is False
|
||||
assert config.compilation_config.pass_config.eliminate_noops is False
|
||||
# Other fields should still use defaults
|
||||
assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
|
||||
|
||||
Reference in New Issue
Block a user