Signed-off-by: arpitkh101 <arpit5khandelwal@gmail.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -153,7 +153,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
]
|
||||
|
||||
def ops_in_model(self):
|
||||
if self.vllm_config.compilation_config.pass_config.enable_fusion:
|
||||
if self.vllm_config.compilation_config.pass_config.fuse_norm_quant:
|
||||
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
|
||||
elif RMSNorm.enabled():
|
||||
return [
|
||||
@@ -183,7 +183,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
@pytest.mark.parametrize("seq_len", [16])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("enable_fusion", [True, False])
|
||||
@pytest.mark.parametrize("fuse_norm_quant", [True, False])
|
||||
@pytest.mark.parametrize("dynamic", [False, True])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||
def test_sequence_parallelism_pass(
|
||||
@@ -193,7 +193,7 @@ def test_sequence_parallelism_pass(
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
enable_fusion: bool,
|
||||
fuse_norm_quant: bool,
|
||||
dynamic: bool,
|
||||
):
|
||||
num_processes = 2
|
||||
@@ -211,7 +211,7 @@ def test_sequence_parallelism_pass(
|
||||
seq_len,
|
||||
hidden_size,
|
||||
dtype,
|
||||
enable_fusion,
|
||||
fuse_norm_quant,
|
||||
dynamic,
|
||||
),
|
||||
nprocs=nprocs,
|
||||
@@ -229,7 +229,7 @@ def sequence_parallelism_pass_on_test_model(
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
enable_fusion: bool,
|
||||
fuse_norm_quant: bool,
|
||||
dynamic: bool,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
@@ -260,9 +260,9 @@ def sequence_parallelism_pass_on_test_model(
|
||||
cudagraph_mode=CUDAGraphMode.NONE, # avoid piecewise warnings
|
||||
custom_ops=custom_ops_list,
|
||||
pass_config=PassConfig(
|
||||
enable_sequence_parallelism=True,
|
||||
enable_fusion=enable_fusion,
|
||||
enable_noop=True,
|
||||
enable_sp=True,
|
||||
fuse_norm_quant=fuse_norm_quant,
|
||||
eliminate_noops=True,
|
||||
),
|
||||
) # NoOp needed for fusion
|
||||
device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
@@ -297,7 +297,7 @@ def sequence_parallelism_pass_on_test_model(
|
||||
sequence_parallelism_pass,
|
||||
]
|
||||
|
||||
if enable_fusion:
|
||||
if fuse_norm_quant:
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
passes_for_backend.append(fusion_pass)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user