[vLLM IR] 1/N Implement IR skeleton and rms_norm op (#33825)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com> Signed-off-by: chzhang <chaojun.zhang@intel.com> Signed-off-by: Luka Govedic <luka.govedic@gmail.com> Co-authored-by: Xinyu Chen <xinyu1.chen@intel.com> Co-authored-by: Chaojun Zhang <chaojun.zhang@intel.com> Co-authored-by: Luka Govedič <ProExpertProg@h100-01.nemg-001.lab.rdu2.dc.redhat.com>
This commit is contained in:
@@ -6,12 +6,14 @@ import os
|
||||
from dataclasses import MISSING, Field, asdict, dataclass, field
|
||||
from unittest.mock import patch
|
||||
|
||||
import pydantic
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
KernelConfig,
|
||||
ModelConfig,
|
||||
ParallelConfig,
|
||||
PoolerConfig,
|
||||
@@ -21,6 +23,7 @@ from vllm.config import (
|
||||
update_config,
|
||||
)
|
||||
from vllm.config.compilation import CompilationMode, CUDAGraphMode
|
||||
from vllm.config.kernel import IrOpPriorityConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.config.utils import get_field
|
||||
from vllm.config.vllm import (
|
||||
@@ -1077,6 +1080,39 @@ def test_vllm_config_explicit_overrides():
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
|
||||
|
||||
|
||||
def test_fusion_pass_op_priority():
    """Verify that the default rms_norm+quant fusion setting follows
    custom op enablement and IR op priority."""

    # Out-of-the-box config (O2): rms_norm+quant fusion stays off.
    default_config = VllmConfig()
    default_passes = default_config.compilation_config.pass_config
    assert not default_passes.fuse_norm_quant

    # O1 with the rms_norm custom op switched on by hand: fusion turns on.
    rms_norm_compilation = CompilationConfig(custom_ops=["+rms_norm"])
    custom_op_config = VllmConfig(
        optimization_level=OptimizationLevel.O1,
        compilation_config=rms_norm_compilation,
    )
    custom_op_passes = custom_op_config.compilation_config.pass_config
    assert custom_op_passes.fuse_norm_quant

    # Selecting the custom RMSNorm kernel through the IR op priority.
    # Note: vLLM IR only supports the non-residual rms_norm for now;
    # this is expected to be resolved soon.
    rms_norm_priority = IrOpPriorityConfig(rms_norm=["vllm_c"])
    ir_priority_config = VllmConfig(
        kernel_config=KernelConfig(ir_op_priority=rms_norm_priority)
    )
    ir_priority_passes = ir_priority_config.compilation_config.pass_config
    assert ir_priority_passes.fuse_norm_quant

    # A block-fp8 model should enable quant_fp8 automatically.
    fp8_model = ModelConfig("Qwen/Qwen3-4B-FP8")
    fp8_compilation = VllmConfig(model_config=fp8_model).compilation_config
    assert "+quant_fp8" in fp8_compilation.custom_ops
    assert fp8_compilation.pass_config.fuse_norm_quant
|
||||
|
||||
|
||||
def test_scheduler_config_init():
|
||||
with pytest.raises(ValidationError):
|
||||
# Positional InitVars missing
|
||||
@@ -1171,3 +1207,35 @@ def test_eagle_draft_model_config():
|
||||
assert draft_model_config.hf_text_config.model_type == "eagle"
|
||||
assert draft_model_config.architectures == ["EagleLlamaForCausalLM"]
|
||||
assert draft_model_config.architecture == "EagleLlamaForCausalLM"
|
||||
|
||||
|
||||
def test_ir_op_priority_default():
    """Check how `IrOpPriorityConfig.with_default` fills in op priorities."""
    from vllm.config.kernel import IrOpPriorityConfig

    # With no per-op argument, rms_norm receives the default priority list.
    defaulted = IrOpPriorityConfig.with_default(["vllm_c", "native"])
    assert defaulted.rms_norm == ["vllm_c", "native"]

    # An explicit per-op priority takes precedence over the default.
    overridden = IrOpPriorityConfig.with_default(
        ["vllm_c", "native"], rms_norm=["oink", "native"]
    )
    expected = IrOpPriorityConfig(rms_norm=["oink", "native"])
    assert overridden == expected
|
||||
|
||||
|
||||
def test_ir_op_priority_str():
    """Check that a comma-delimited string is accepted as an op priority."""
    from vllm.config.kernel import IrOpPriorityConfig

    # (raw string input, expected parsed priority list)
    string_cases = (
        ("vllm_c", ["vllm_c"]),
        ("vllm_c,native", ["vllm_c", "native"]),
        (" native, vllm_c ", ["native", "vllm_c"]),
    )
    for raw_value, expected_priority in string_cases:
        parsed = IrOpPriorityConfig(rms_norm=raw_value)
        assert parsed.rms_norm == expected_priority

    # A list input must contain only strings.
    with pytest.raises(pydantic.ValidationError):
        IrOpPriorityConfig(rms_norm=["vllm_c", 4, "native"])
|
||||
|
||||
Reference in New Issue
Block a user