[Core] Rename PassConfig flags as per RFC #27995 (#29646)

Signed-off-by: arpitkh101 <arpit5khandelwal@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Arpit Khandelwal
2025-12-02 22:38:55 -05:00
committed by GitHub
parent 506ed87e87
commit d7284a2604
22 changed files with 318 additions and 123 deletions

View File

@@ -153,7 +153,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
]
def ops_in_model(self):
if self.vllm_config.compilation_config.pass_config.enable_fusion:
if self.vllm_config.compilation_config.pass_config.fuse_norm_quant:
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
elif RMSNorm.enabled():
return [
@@ -183,7 +183,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("enable_fusion", [True, False])
@pytest.mark.parametrize("fuse_norm_quant", [True, False])
@pytest.mark.parametrize("dynamic", [False, True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_sequence_parallelism_pass(
@@ -193,7 +193,7 @@ def test_sequence_parallelism_pass(
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_fusion: bool,
fuse_norm_quant: bool,
dynamic: bool,
):
num_processes = 2
@@ -211,7 +211,7 @@ def test_sequence_parallelism_pass(
seq_len,
hidden_size,
dtype,
enable_fusion,
fuse_norm_quant,
dynamic,
),
nprocs=nprocs,
@@ -229,7 +229,7 @@ def sequence_parallelism_pass_on_test_model(
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_fusion: bool,
fuse_norm_quant: bool,
dynamic: bool,
):
current_platform.seed_everything(0)
@@ -260,9 +260,9 @@ def sequence_parallelism_pass_on_test_model(
cudagraph_mode=CUDAGraphMode.NONE, # avoid piecewise warnings
custom_ops=custom_ops_list,
pass_config=PassConfig(
enable_sequence_parallelism=True,
enable_fusion=enable_fusion,
enable_noop=True,
enable_sp=True,
fuse_norm_quant=fuse_norm_quant,
eliminate_noops=True,
),
) # NoOp needed for fusion
device_config = DeviceConfig(device=torch.device("cuda"))
@@ -297,7 +297,7 @@ def sequence_parallelism_pass_on_test_model(
sequence_parallelism_pass,
]
if enable_fusion:
if fuse_norm_quant:
fusion_pass = RMSNormQuantFusionPass(vllm_config)
passes_for_backend.append(fusion_pass)