diff --git a/.buildkite/test_areas/kernels.yaml b/.buildkite/test_areas/kernels.yaml index 8eba8da0b..69abc69b0 100644 --- a/.buildkite/test_areas/kernels.yaml +++ b/.buildkite/test_areas/kernels.yaml @@ -2,6 +2,16 @@ group: Kernels depends_on: - image-build steps: +- label: vLLM IR Tests + timeout_in_minutes: 10 + working_dir: "/vllm-workspace/" + source_file_dependencies: + - vllm/ir + - vllm/kernels + commands: + - pytest -v -s tests/ir + - pytest -v -s tests/kernels/ir + - label: Kernels Core Operation Test timeout_in_minutes: 75 source_file_dependencies: diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 6c8ed3d8b..a9c1c1712 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -13,6 +13,9 @@ /vllm/model_executor/layers/rotary_embedding.py @vadiklyutiy /vllm/model_executor/model_loader @22quinn /vllm/model_executor/layers/batch_invariant.py @yewentao256 +/vllm/ir @ProExpertProg +/vllm/kernels/ @ProExpertProg @tjtanaa +/vllm/kernels/helion @ProExpertProg @zou3519 /vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa /vllm/vllm_flash_attn @LucasWilkinson @MatthewBonanni CMakeLists.txt @tlrmchlsmth @LucasWilkinson @@ -74,6 +77,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche /tests/evals @mgoin @vadiklyutiy /tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256 +/tests/kernels/ir @ProExpertProg @tjtanaa /tests/models @DarkLight1337 @ywang96 /tests/multimodal @DarkLight1337 @ywang96 @NickLucche /tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 @pavanimajety diff --git a/tests/compile/backend.py b/tests/compile/backend.py index ec4685324..d61c128a5 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -8,7 +8,7 @@ from copy import deepcopy import depyf from torch import fx -from torch._ops import OpOverload +from torch._ops import OpOverload, OpOverloadPacket from torch.fx._utils import lazy_format_graph_code from 
vllm.compilation.passes.fx_utils import find_op_nodes @@ -90,7 +90,9 @@ class TestBackend: # assign by reference, will reflect the final state of the graph self.final_graph = graph - def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True): + def check_before_ops( + self, ops: Sequence[OpOverload | OpOverloadPacket], fully_replaced=True + ): for op in ops: num_pre = len(list(find_op_nodes(op, self.graph_pre_pass))) num_post = len(list(find_op_nodes(op, self.graph_post_pass))) @@ -99,13 +101,19 @@ class TestBackend: if fully_replaced: assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph" - def check_after_ops(self, ops: Sequence[OpOverload]): + def check_after_ops(self, ops: Sequence[OpOverload | OpOverloadPacket]): for op in ops: num_pre = len(list(find_op_nodes(op, self.graph_pre_pass))) num_post = len(list(find_op_nodes(op, self.graph_post_pass))) assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph" assert num_post > 0, f"Op {op.name()} not found in post-pass graph" - def op_count(self, op: OpOverload, before=False) -> int: + def op_count(self, op: OpOverload | OpOverloadPacket, before=False) -> int: graph = self.graph_pre_pass if before else self.graph_post_pass return len(list(find_op_nodes(op, graph))) + + def print_graphs(self): + print("=== Graph before custom passes ===") + print(self.graph_pre_pass.python_code(root_module="self", verbose=True).src) + print("=== Graph after custom passes ===") + print(self.graph_post_pass.python_code(root_module="self", verbose=True).src) diff --git a/tests/compile/fusions_e2e/test_tp1_quant.py b/tests/compile/fusions_e2e/test_tp1_quant.py index 8895dadce..8186ecbb4 100644 --- a/tests/compile/fusions_e2e/test_tp1_quant.py +++ b/tests/compile/fusions_e2e/test_tp1_quant.py @@ -99,6 +99,8 @@ def test_tp1_fp8_fusions( model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["load_format"] = "dummy" model_kwargs["max_model_len"] = 1024 + model_kwargs["kernel_config"] = 
{"enable_flashinfer_autotune": False} + compilation_config = dict( use_inductor_graph_partition=inductor_graph_partition, custom_ops=custom_ops.split(","), @@ -166,6 +168,7 @@ def test_tp1_fp4_fusions( model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["load_format"] = "dummy" model_kwargs["max_model_len"] = 1024 + model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False} compilation_config = dict( use_inductor_graph_partition=inductor_graph_partition, diff --git a/tests/compile/fusions_e2e/test_tp2_ar_rms.py b/tests/compile/fusions_e2e/test_tp2_ar_rms.py index 301409b2b..fa1ceb7f0 100644 --- a/tests/compile/fusions_e2e/test_tp2_ar_rms.py +++ b/tests/compile/fusions_e2e/test_tp2_ar_rms.py @@ -68,6 +68,7 @@ def test_tp2_ar_rms_fp8_fusions( model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["load_format"] = "dummy" model_kwargs["max_model_len"] = 1024 + model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False} compilation_config = dict( use_inductor_graph_partition=inductor_graph_partition, @@ -128,6 +129,7 @@ def test_tp2_ar_rms_fp4_fusions( model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["load_format"] = "dummy" model_kwargs["max_model_len"] = 1024 + model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False} compilation_config = dict( use_inductor_graph_partition=inductor_graph_partition, @@ -182,6 +184,7 @@ def test_tp2_ar_rms_fusions( model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["load_format"] = "dummy" model_kwargs["max_model_len"] = 1024 + model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False} compilation_config = dict( use_inductor_graph_partition=inductor_graph_partition, diff --git a/tests/compile/fusions_e2e/test_tp2_async_tp.py b/tests/compile/fusions_e2e/test_tp2_async_tp.py index 9657d64b8..609377e68 100644 --- a/tests/compile/fusions_e2e/test_tp2_async_tp.py +++ b/tests/compile/fusions_e2e/test_tp2_async_tp.py @@ -58,6 +58,7 @@ def 
test_tp2_async_tp_fp8_fusions( model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["load_format"] = "dummy" model_kwargs["max_model_len"] = 1024 + model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False} compilation_config = dict( use_inductor_graph_partition=inductor_graph_partition, @@ -121,6 +122,7 @@ def test_tp2_async_tp_fusions( model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["load_format"] = "dummy" model_kwargs["max_model_len"] = 1024 + model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False} compilation_config = dict( use_inductor_graph_partition=inductor_graph_partition, diff --git a/tests/compile/passes/distributed/test_sequence_parallelism.py b/tests/compile/passes/distributed/test_sequence_parallelism.py index e7bf330cc..8588e0501 100644 --- a/tests/compile/passes/distributed/test_sequence_parallelism.py +++ b/tests/compile/passes/distributed/test_sequence_parallelism.py @@ -9,7 +9,6 @@ from tests.compile.backend import TestBackend from tests.utils import TestFP8Layer, multi_gpu_test from vllm.compilation.passes.fusion.rms_quant_fusion import RMSNormQuantFusionPass from vllm.compilation.passes.fusion.sequence_parallelism import SequenceParallelismPass -from vllm.compilation.passes.fx_utils import find_auto_fn from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass @@ -86,13 +85,14 @@ class TestAllReduceRMSNormModel(torch.nn.Module): ] def ops_in_model(self): - if RMSNorm.enabled(): - return [ - torch.ops._C.rms_norm.default, + return ( + [torch.ops.vllm_ir.rms_norm] + + [ torch.ops._C.fused_add_rms_norm.default, ] - else: - return [] + if RMSNorm.enabled() + else [] + ) class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): @@ -321,4 +321,4 @@ def sequence_parallelism_pass_on_test_model( assert backend.op_count(op, 
before=False) == 4 for op in model.ops_in_model(): - find_auto_fn(backend.graph_post_pass.nodes, op) + assert backend.op_count(op, before=False) > 0 diff --git a/tests/compile/passes/ir/__init__.py b/tests/compile/passes/ir/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/compile/passes/ir/test_lowering.py b/tests/compile/passes/ir/test_lowering.py new file mode 100644 index 000000000..b7ca55e7d --- /dev/null +++ b/tests/compile/passes/ir/test_lowering.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +from torch import nn + +import vllm.kernels # noqa: F401 to register kernels +from vllm import ir +from vllm.compilation.passes.ir.lowering_pass import ( + VllmIRLoweringPass, +) +from vllm.config import get_current_vllm_config +from vllm.ir import ops +from vllm.platforms import current_platform + +from ...backend import TestBackend + + +class Model(nn.Module): + def __init__(self, hidden_size=16, *args, **kwargs): + super().__init__(*args, **kwargs) + self.hidden_size = hidden_size + self.weight = torch.ones(hidden_size, dtype=torch.bfloat16) + + def forward(self, x): + x1 = x + 4.0 + x2 = ops.rms_norm(x1, self.weight, 1e-5) + x3 = x2 * 5.0 + # no weight + x4 = ops.rms_norm(x3, None, 1e-5) + x5 = x4 / 2.0 + # dispatch to native due to variance_size parameter + x6 = ops.rms_norm(x5, self.weight, 1e-5, self.hidden_size // 2) + return x6 + 3.0 + + +@pytest.mark.parametrize("rms_provider", ops.rms_norm.supported_providers()) +def test_lowering_rms_norm(rms_provider, default_vllm_config): + torch.set_default_device(current_platform.device_type) + + lowering_pass = VllmIRLoweringPass(get_current_vllm_config()) + backend = TestBackend(lowering_pass) + backend_unlowered = TestBackend() + + model = Model() + x = torch.randn(8, 16, dtype=torch.bfloat16) + with ( + ops.rms_norm.set_priority([rms_provider, "native"]), + 
ir.enable_torch_wrap(True), + ): + compiled_model = torch.compile(model, backend=backend, fullgraph=True) + compiled_unlowered_model = torch.compile( + model, backend=backend_unlowered, fullgraph=True + ) + output = compiled_model(x) + output_unlowered = compiled_unlowered_model(x) + + selected = lowering_pass.selected_impls["rms_norm"] + assert len(selected) == 3 + assert selected["rms_norm"] == rms_provider + assert selected["rms_norm_1"] == rms_provider + assert selected["rms_norm_2"] == "native" + + # Compiled function guards on global value, avoid recompilation + with ir.enable_torch_wrap(True): + output2 = compiled_model(x) + + torch.testing.assert_close(output_unlowered, output) + torch.testing.assert_close(output_unlowered, output2) diff --git a/tests/compile/passes/test_fusion.py b/tests/compile/passes/test_fusion.py index 5df9424a5..368ddc8f3 100644 --- a/tests/compile/passes/test_fusion.py +++ b/tests/compile/passes/test_fusion.py @@ -6,6 +6,7 @@ import pytest import torch import vllm.config +import vllm.ir.ops import vllm.plugins from tests.compile.backend import TestBackend from tests.utils import TestBlockFP8Layer, TestFP8Layer @@ -51,7 +52,6 @@ from vllm.utils.deep_gemm import ( FP8_DTYPE = current_platform.fp8_dtype() -RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default # Kernel and group_shape combinations: (kernel, group_shape) @@ -246,10 +246,8 @@ class TestModel(torch.nn.Module): ] def ops_in_model_before_partial(self): - return ( - [RMS_OP, RMS_ADD_OP] - if self.enable_rms_norm_custom_op - else [torch.ops.aten.rsqrt] + return [torch.ops.vllm_ir.rms_norm] + ( + [RMS_ADD_OP] if self.enable_rms_norm_custom_op else [torch.ops.aten.rsqrt] ) @@ -340,7 +338,10 @@ def test_fusion_rmsnorm_quant( ), ) - with vllm.config.set_current_vllm_config(vllm_config): + with ( + vllm.config.set_current_vllm_config(vllm_config), + vllm_config.kernel_config.ir_op_priority.set_priority(), + ): # Setup device before model 
creation torch.set_default_device("cuda") torch.set_default_dtype(dtype) @@ -370,8 +371,9 @@ def test_fusion_rmsnorm_quant( # Hence, we check only 2 add nodes are left (final fused rmsnorm add). if not enable_rms_norm_custom_op: n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g)) - # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each) - assert n_add_nodes(backend.graph_pre_pass) == 7 + # rms_norm is IR, not included + # 6 = 3x2 (3xRMS_ADD, 2 each) + assert n_add_nodes(backend.graph_pre_pass) == 6 assert n_add_nodes(backend.graph_post_pass) == 2 diff --git a/tests/compile/passes/test_qk_norm_rope_fusion.py b/tests/compile/passes/test_qk_norm_rope_fusion.py index f9a86732c..25b8ea56f 100644 --- a/tests/compile/passes/test_qk_norm_rope_fusion.py +++ b/tests/compile/passes/test_qk_norm_rope_fusion.py @@ -3,11 +3,11 @@ import pytest import torch +from torch._ops import OpOverload, OpOverloadPacket from tests.compile.backend import TestBackend from vllm.compilation.passes.fusion.matcher_utils import ( FLASHINFER_ROTARY_OP, - RMS_OP, ROTARY_OP, ) from vllm.compilation.passes.fusion.qk_norm_rope_fusion import ( @@ -100,13 +100,8 @@ class QKNormRoPETestModel(torch.nn.Module): q, k = self.rotary_emb(positions, q, k) return q, k, v - def ops_in_model_before(self) -> list[torch._ops.OpOverload]: - ops = [] - if self.enable_rms_norm_custom_op: - ops.append(RMS_OP) - else: - ops.append(RSQRT_OP) - + def ops_in_model_before(self) -> list[OpOverload | OpOverloadPacket]: + ops: list[OpOverload | OpOverloadPacket] = [torch.ops.vllm_ir.rms_norm] if self.enable_rope_custom_op: if self.rotary_emb.use_flashinfer: ops.append(FLASHINFER_ROTARY_OP) @@ -116,7 +111,7 @@ class QKNormRoPETestModel(torch.nn.Module): ops.append(INDEX_SELECT_OP) return ops - def ops_in_model_after(self) -> list[torch._ops.OpOverload]: + def ops_in_model_after(self) -> list[OpOverload | OpOverloadPacket]: return [FUSED_QK_ROPE_OP] @@ -166,7 +161,10 @@ def test_qk_norm_rope_fusion( num_heads, 
num_kv_heads, head_dim = 16, 4, 128 T = 5 - with set_current_vllm_config(vllm_config): + with ( + set_current_vllm_config(vllm_config), + vllm_config.kernel_config.ir_op_priority.set_priority(), + ): model = QKNormRoPETestModel( num_heads=num_heads, num_kv_heads=num_kv_heads, diff --git a/tests/conftest.py b/tests/conftest.py index f3b22d898..38f2bc097 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1622,3 +1622,26 @@ def fresh_vllm_cache(monkeypatch, use_fresh_inductor_cache): def enable_pickle(monkeypatch): """`LLM.apply_model` requires pickling a function.""" monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + + +@pytest.fixture(scope="function") +def disable_log_dedup(monkeypatch): + """ + Disable log deduplication such that warning_once and info_once always print. + """ + + # Patch logger._print_warning_once to remove the lru_cache decorator + from vllm import logger + + original_print_warning_once = logger._print_warning_once + original_print_info_once = logger._print_info_once + original_print_debug_once = logger._print_debug_once + + logger._print_warning_once = original_print_warning_once.__wrapped__ + logger._print_info_once = original_print_info_once.__wrapped__ + logger._print_debug_once = original_print_debug_once.__wrapped__ + + yield + logger._print_warning_once = original_print_warning_once + logger._print_info_once = original_print_info_once + logger._print_debug_once = original_print_debug_once diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index d1986e0a4..a0f4bf970 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -523,3 +523,20 @@ def test_human_readable_model_len(): for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]: with pytest.raises(ArgumentError): parser.parse_args(["--max-model-len", invalid]) + + +def test_ir_op_priority(): + from vllm.config.kernel import IrOpPriorityConfig, KernelConfig + + ir_op_priority = IrOpPriorityConfig(rms_norm=["vllm_c"]) 
+ cfg1 = EngineArgs(ir_op_priority=ir_op_priority).create_engine_config() + cfg2 = EngineArgs( + kernel_config=KernelConfig(ir_op_priority=ir_op_priority) + ).create_engine_config() + assert cfg1.kernel_config.ir_op_priority == cfg2.kernel_config.ir_op_priority + + with pytest.raises(ValueError, match="rms_norm"): + _ = EngineArgs( + ir_op_priority=ir_op_priority, + kernel_config=KernelConfig(ir_op_priority=ir_op_priority), + ).create_engine_config() diff --git a/tests/ir/test_op.py b/tests/ir/test_op.py new file mode 100644 index 000000000..8d4245a04 --- /dev/null +++ b/tests/ir/test_op.py @@ -0,0 +1,497 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib.util +import logging +from pathlib import Path +from typing import Any + +import pytest +import torch +from torch import fx +from torch.fx.experimental.proxy_tensor import make_fx + +import vllm.ir.op +from vllm.ir.op import RESERVED_PROVIDERS, IrOp, IrOpImpl + +# This should not exist +assert "_custom_add" not in IrOp.registry + + +class CustomError(Exception): + pass + + +@vllm.ir.register_op +def _custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + +def test_registration_overloads(): + assert all( + n not in IrOp.registry for n in ["_custom_sub", "_custom_mul", "_custom_div"] + ) + + # Calling with decorator + @vllm.ir.register_op() + def _custom_sub(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x - y + + assert _custom_sub.name == "_custom_sub" + assert _custom_sub is IrOp.registry["_custom_sub"] + + # Custom name + @vllm.ir.register_op(name="_custom_mul") + def custom_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x * y + + assert custom_mul.name == "_custom_mul" + assert custom_mul is IrOp.registry["_custom_mul"] + + # Direct construction does not register directly + def _custom_div(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x / y + + custom_div = 
IrOp("_custom_div", _custom_div) + assert custom_div.name == "_custom_div" + assert "_custom_div" not in IrOp.registry + + # Duplicate op registration not allowed + with pytest.raises(AssertionError): + + @vllm.ir.register_op + def _custom_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x * y - 100 + + +def test_no_kw_only_args(): + # kw-only args not supported + with pytest.raises(ValueError, match="keyword-only arguments"): + + @vllm.ir.register_op + def _custom_kwarg_op( + x: torch.Tensor, y: torch.Tensor, *, kwarg: int = 0 + ) -> torch.Tensor: + return x + y + kwarg + + assert "_custom_kwarg_op" not in IrOp.registry + + +class TestIrOpCustomAdd: + # Registration invariants + def test_decorated_object(self): + """Make sure that referring directly to an op is correct""" + assert isinstance(_custom_add, IrOp) + assert "_custom_add" in IrOp.registry + assert _custom_add is IrOp.registry["_custom_add"] + + def test_torch_op_is_registered(self): + assert hasattr(torch.ops.vllm_ir, "_custom_add") + assert callable(torch.ops.vllm_ir._custom_add.default) + + # Semantic correctness + def test_semantics_match_native(self): + x = torch.randn(4, 5) + y = torch.randn(4, 5) + + # Calls native by default + out = _custom_add(x, y) + ref = x + y + + torch.testing.assert_close(out, ref) + + # ------------------------- + # Implementation registration + # ------------------------- + + def test_register_impl_is_non_intrusive(self): + @_custom_add.register_impl("dummy_provider") + def dummy_impl(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + 123 + + assert "dummy_provider" in _custom_add.impls + assert isinstance(_custom_add.impls["dummy_provider"], IrOpImpl) + + x = torch.ones(2, 2) + y = torch.ones(2, 2) + + # Native semantics must still hold + torch.testing.assert_close(_custom_add(x, y), x + y) + + def test_schema_contains_tensor_signature(self): + schema = _custom_add._schema_str + + assert "Tensor" in schema + assert "-> Tensor" in schema + 
+ # ------------------------- + # FX visibility + # ------------------------- + + @pytest.mark.parametrize("enable_torch_wrap", [True, False]) + @pytest.mark.parametrize("symbolic_trace", [True, False]) + def test_trace_sees_single_custom_op( + self, symbolic_trace: bool, enable_torch_wrap: bool + ): + def fn(x, y): + return _custom_add(x, y) + + def find_fn(target: Any, gm: fx.GraphModule): + return gm.graph.find_nodes(op="call_function", target=target) + + with pytest.raises(CustomError), vllm.ir.enable_torch_wrap(enable_torch_wrap): + if symbolic_trace: + gm = torch.fx.symbolic_trace(fn) + else: + gm = make_fx(fn)(torch.randn(2, 2), torch.randn(2, 2)) + + x1, y1 = torch.rand(5, 4), torch.rand(5, 4) + out_fx = gm(x1, y1) + out_eager = fn(x1, y1) + + # raise error to check enable_torch_wrap context restored correctly + raise CustomError + + # check behavior matches eager in all cases + torch.testing.assert_close(out_fx, out_eager) + + # check that IR nodes only appear if enable_torch_wrap=True + ir_nodes = find_fn(torch.ops.vllm_ir._custom_add.default, gm) + if enable_torch_wrap: + assert len(ir_nodes) == 1, gm.code + else: + assert len(ir_nodes) == 0, gm.code + + # with torch wrapping enabled (default), IR nodes appear + if symbolic_trace: + gm = torch.fx.symbolic_trace(fn) + else: + gm = make_fx(fn)(torch.randn(2, 2), torch.randn(2, 2)) + + ir_nodes = find_fn(torch.ops.vllm_ir._custom_add.default, gm) + assert len(ir_nodes) == 1, gm.code + + +@_custom_add.register_impl("impl_a") +def impl_a(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + 10 + + +@_custom_add.register_impl("impl_b") +def impl_b(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + 20 + + +@_custom_add.register_impl("impl_even", supports_args=lambda x, y: x.size(1) % 2 == 0) +def impl_even(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + 50 + + +class TestIrOpImplDispatch: + def test_register_impl(self): + assert "impl_a" in _custom_add.impls + 
impl = _custom_add.impls["impl_a"] + + assert impl is impl_a + assert impl.op is _custom_add + assert impl.provider == "impl_a" + assert callable(impl.impl_fn) + + # Test duplicate registration rejected + with pytest.raises(AssertionError): + + @_custom_add.register_impl("impl_a") + def impl_a_dup(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + 30 + + # Check the original impl is still intact + assert _custom_add.impls["impl_a"] is impl_a + + # Check support all args + assert impl_a.supports_all_args + assert impl_b.supports_all_args + assert not impl_even.supports_all_args + + def test_reserved_provider_rejected(self): + for provider in RESERVED_PROVIDERS: + with pytest.raises(AssertionError): + + @_custom_add.register_impl(provider) + def bad_impl(x, y): + return x + y + + def test_set_priority_scoped(self): + assert _custom_add.get_priority() == [] + + with _custom_add.set_priority(["impl_even", "impl_b"]): + assert _custom_add.get_priority() == ["impl_even", "impl_b"] + + # Check nesting + with _custom_add.set_priority(["impl_b"]): + assert _custom_add.get_priority() == ["impl_b"] + + # Restored + assert _custom_add.get_priority() == ["impl_even", "impl_b"] + + # Check that exception restores priority + with pytest.raises(CustomError), _custom_add.set_priority(["impl_a"]): + assert _custom_add.get_priority() == ["impl_a"] + raise CustomError + + # Restored again + assert _custom_add.get_priority() == ["impl_even", "impl_b"] + + # Restored to empty + assert _custom_add.get_priority() == [] + + def test_dispatch_priority_order(self): + x = torch.tensor(1, dtype=torch.int32) + y = torch.tensor(2, dtype=torch.int32) + + with _custom_add.set_priority(["impl_b", "impl_a"]): + assert _custom_add.dispatch(x, y) is impl_b + out1 = _custom_add(x, y) + out2 = torch.ops.vllm_ir._custom_add(x, y) + + with _custom_add.set_priority(["impl_a"]): + assert _custom_add.dispatch(x, y) is impl_a + out3 = _custom_add(x, y) + out4 = 
torch.ops.vllm_ir._custom_add(x, y) + + # impl_b + assert out1.item() == 1 + 2 + 20 + assert out2.item() == 1 + 2 + 20 + # impl_a + assert out3.item() == 1 + 2 + 10 + assert out4.item() == 1 + 2 + 10 + + def test_unsupported_impl_filtered(self): + @_custom_add.register_impl("unsupported", supported=False) + def impl_bad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + 999 + + x = torch.tensor(1, dtype=torch.int32) + y = torch.tensor(2, dtype=torch.int32) + + with _custom_add.set_priority(["unsupported", "impl_a"]): + assert _custom_add.get_priority() == ["impl_a"] + out = _custom_add(x, y) + + # impl_bad skipped → impl_a + assert out.item() == 1 + 2 + 10 + + def test_supports_args_runtime_dispatch_and_warning( + self, caplog_vllm: pytest.LogCaptureFixture + ): + x1 = torch.ones((2, 2), dtype=torch.int32) + y1 = torch.full((2, 2), 2, dtype=torch.int32) + + x2 = torch.ones((2, 3), dtype=torch.int32) + y2 = torch.full((2, 3), 2, dtype=torch.int32) + + with ( + caplog_vllm.at_level(logging.WARNING), + _custom_add.set_priority(["impl_even"]), + ): + # Test the warning about native fallback is logged (before even dispatching) + assert len(caplog_vllm.records) == 1 + message = caplog_vllm.records[0].message + assert "_custom_add" in message + assert "fallback to native" in message + assert "priority" in message + + # Check dispatching + assert _custom_add.get_priority() == ["impl_even", "native"] + assert _custom_add.dispatch(x1, y1) is impl_even + assert _custom_add.dispatch(x2, y2) is _custom_add.impls["native"] + + out1 = _custom_add(x1, y1) # size(1) == 2 → impl_even + out2 = _custom_add(x2, y2) # size(1) == 3 → native fallback + + # no other warnings + assert len(caplog_vllm.records) == 1 + assert torch.all(out1 == 1 + 2 + 50) + assert torch.all(out2 == 1 + 2) + + def test_default_priority( + self, caplog_vllm: pytest.LogCaptureFixture, disable_log_dedup + ): + # Make sure logs are not deduplicated to properly test the warning + x = 
torch.tensor([3], dtype=torch.int32) + y = torch.tensor([4], dtype=torch.int32) + + # No priority set → falls back to native + assert _custom_add.get_priority() == [] + with caplog_vllm.at_level(logging.WARNING): + # Native by default + assert _custom_add.dispatch(x, y) is _custom_add.impls["native"] + out = _custom_add(x, y) + + # Check dispatching to native by default + assert out.item() == 3 + 4 + + # Check warning + assert len(caplog_vllm.records) == 2 + message = caplog_vllm.records[0].message.lower() + assert "_custom_add" in message + assert "priority not set" in message + + +@vllm.ir.register_op +def _custom_mm( + x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None +) -> torch.Tensor: + tmp = x @ y + return tmp if bias is None else tmp + bias + + +def test_default_args(): + # Test that default args are properly applied when dispatching and calling + @_custom_mm.register_impl("impl_mm", supports_args=lambda x, y, bias=None: True) + def impl_mm( + x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None + ) -> torch.Tensor: + tmp = x @ y + return tmp + 50 if bias is None else tmp + bias + 100 + + x1 = torch.tensor([1, 2], dtype=torch.int32) + x2 = torch.tensor([3, 4], dtype=torch.int32) + + # Test that supports_args receives the defaulted args + assert impl_mm.supports_args(x1, x2) + with _custom_mm.set_priority(["impl_mm", "native"]): + assert _custom_mm.dispatch(x1, x2) is impl_mm + + +def test_bad_impl_registrations(): + # Check bad schema + with pytest.raises(ValueError, match="does not match native schema"): + + @_custom_mm.register_impl("impl_mm_bad_schema") + def impl_mm_bad_schema(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x @ y - 1 + + with pytest.raises(ValueError, match="does not match native schema"): + + @_custom_mm.register_impl("impl_mm_bad_schema_2") + def impl_mm_bad_schema_2( + x: torch.Tensor, y: torch.Tensor, b: torch.Tensor | None = None + ) -> torch.Tensor: + return x @ y + b - 2 + + with 
pytest.raises(ValueError, match="does not match native schema"): + + @_custom_mm.register_impl("impl_mm_bad_schema_3") + def impl_mm_bad_schema_3( + x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor + ) -> torch.Tensor: + return x @ y + bias - 5 + + # check supports_args with incorrect params + with pytest.raises(ValueError, match="supports_args must be a callable"): + + @_custom_mm.register_impl("impl_mm_bad_supports_args", supports_args=True) + def impl_mm_bad_supports_args( + x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None + ) -> torch.Tensor: + return x @ y + 10 + + with pytest.raises(ValueError, match="number of parameters"): + + @_custom_mm.register_impl( + "impl_mm_bad_supports_args_2", supports_args=lambda x, y: True + ) + def impl_mm_bad_supports_args( + x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None + ) -> torch.Tensor: + return x @ y + 10 + + with pytest.raises(ValueError, match="keyword-only parameters"): + + @_custom_mm.register_impl( + "impl_mm_bad_supports_args_3", supports_args=lambda x, y, *, b: True + ) + def impl_mm_bad_supports_args_2( + x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None + ) -> torch.Tensor: + return x @ y + 20 + + with pytest.raises(ValueError, match="does not match native parameter"): + + @_custom_mm.register_impl( + "impl_mm_bad_supports_args_4", supports_args=lambda x, y, b: True + ) + def impl_mm_bad_supports_args_4( + x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None + ) -> torch.Tensor: + return x @ y + 30 + + with pytest.raises(ValueError, match="does not match native default"): + + @_custom_mm.register_impl( + "impl_mm_bad_supports_args_5", supports_args=lambda x, y, bias=1: True + ) + def impl_mm_bad_supports_args_5( + x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None + ) -> torch.Tensor: + return x @ y + 40 + + assert set(_custom_mm.impls.keys()) == {"impl_mm", "native"} + + +IMPL_OOT_SRC = """ +import torch + 
+@_custom_mm.register_impl("impl_mm_oot") +def impl_mm_oot( + x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None +) -> torch.Tensor: + return x @ y - 99 +""" + + +def load_custom_mm_module(file_path: Path): + spec = importlib.util.spec_from_file_location("_custom_mm_oot", file_path) + assert spec is not None + module = importlib.util.module_from_spec(spec) + + # Inject the variable into the module's global namespace + # This allows the @_custom_mm.register_impl decorator to work + module._custom_mm = _custom_mm # type: ignore[attr-defined] + + # Execute the file; this triggers the decorator + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def test_uuid_and_oot(tmp_path: Path): + file_path = tmp_path / "_custom_mm_oot.py" + file_path.write_text(IMPL_OOT_SRC) + + assert "impl_mm_oot" not in _custom_mm.impls + _ = load_custom_mm_module(file_path) + assert "impl_mm_oot" in _custom_mm.impls + + uuid = _custom_mm.impls["impl_mm_oot"].uuid() + del _custom_mm.impls["impl_mm_oot"] + + # Replace file source + file_path.write_text(IMPL_OOT_SRC + " # added file source") + assert "impl_mm_oot" not in _custom_mm.impls + _ = load_custom_mm_module(file_path) + assert "impl_mm_oot" in _custom_mm.impls + + uuid1 = _custom_mm.impls["impl_mm_oot"].uuid() + assert uuid1 != uuid + del _custom_mm.impls["impl_mm_oot"] + + # Back to original + file_path.write_text(IMPL_OOT_SRC) + assert "impl_mm_oot" not in _custom_mm.impls + _ = load_custom_mm_module(file_path) + assert "impl_mm_oot" in _custom_mm.impls + + uuid2 = _custom_mm.impls["impl_mm_oot"].uuid() + assert uuid2 == uuid + assert uuid2 != uuid1 + del _custom_mm.impls["impl_mm_oot"] diff --git a/tests/kernels/ir/test_layernorm.py b/tests/kernels/ir/test_layernorm.py new file mode 100644 index 000000000..3d2116909 --- /dev/null +++ b/tests/kernels/ir/test_layernorm.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the 
vLLM project +import pytest +import torch + +# This registers op implementations +import vllm.kernels # noqa: F401 +from tests.kernels.allclose_default import get_default_rtol +from vllm import ir +from vllm.platforms import current_platform + + +def rms_norm_inputs(n_tokens: int, hidden_size: int, dtype: torch.dtype): + x = torch.randn(n_tokens, hidden_size, dtype=dtype) + weight = torch.rand(hidden_size, dtype=dtype) + return x, weight + + +rms_norm_native = ir.ops.rms_norm.impls["native"].impl_fn + + +@pytest.mark.skipif( + not current_platform.is_cuda_alike() and not current_platform.is_xpu(), + reason="Currently only kernels on CUDA, ROCm and XPU", +) +def test_rms_norm_registration(): + expected = { + "native": True, + "vllm_c": current_platform.is_cuda_alike(), + "aiter": current_platform.is_rocm(), + "oink": False, + "xpu_kernels": current_platform.is_xpu(), + } + + actual = { + provider: impl.supported for provider, impl in ir.ops.rms_norm.impls.items() + } + + assert actual == expected + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("n_tokens", [1, 8, 17]) +@pytest.mark.parametrize("hidden_size", [16, 4096, 8192]) +@pytest.mark.parametrize("epsilon", [1e-6, 1e-5]) +@pytest.mark.skipif( + not current_platform.is_cuda_alike() and not current_platform.is_xpu(), + reason="Currently only kernels on CUDA, ROCm and XPU", +) +class TestRMSNorm: + @classmethod + def setup_class(cls, **kwargs): + torch.set_default_device(current_platform.device_type) + + def test_native_semantics(self, dtype, n_tokens, hidden_size, epsilon): + x, weight = rms_norm_inputs(4, 8, dtype) + out = rms_norm_native(x, weight, epsilon=epsilon) + + # Check shape, dtype, device + assert out.shape == x.shape + assert out.dtype == x.dtype + assert out.device == x.device + + # Check the scaling property of rms norm + out2 = rms_norm_native(x * 2.0, weight, epsilon=epsilon) + torch.testing.assert_close(out2, out, 
rtol=get_default_rtol(out), atol=1e-3) + + # Check behavior with and without weight + weight1 = torch.ones_like(weight) + out3 = rms_norm_native(x, weight1, epsilon=epsilon) + out4 = rms_norm_native(x, None, epsilon=epsilon) + torch.testing.assert_close(out3, out4) + + @pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels"]) + def test_impls(self, dtype, n_tokens, hidden_size, epsilon, provider): + impl = ir.ops.rms_norm.impls[provider] + if not impl.supported: + pytest.skip(f"{provider} impl not supported on this platform") + + x, weight = rms_norm_inputs(n_tokens, hidden_size, dtype) + args = (x, weight, epsilon, None) + + assert impl.supported + + if provider == "aiter" and dtype not in [torch.float16, torch.bfloat16]: + assert not impl.supports_args(*args) + return + + assert impl.supports_args(*args) + + out_impl = impl.impl_fn(*args) + out_native = rms_norm_native(*args) + + torch.testing.assert_close( + out_impl, out_native, rtol=get_default_rtol(out_impl), atol=1e-3 + ) + + # check that dispatched call matches direct call + with ir.ops.rms_norm.set_priority([provider, "native"]): + out_impl2 = ir.ops.rms_norm(*args) + + # exact match + torch.testing.assert_close(out_impl2, out_impl, rtol=0.0, atol=0.0) + + # none of these support variance_size override + assert not impl.supports_args(x, weight, epsilon, 4) + assert not impl.supports_args(x, weight, epsilon, variance_size=4) + + # test weight=None behavior + out_impl_no_weight = impl.impl_fn(x, None, epsilon) + out_impl_unit_weight = impl.impl_fn(x, torch.ones_like(weight), epsilon) + torch.testing.assert_close( + out_impl_no_weight, + out_impl_unit_weight, + rtol=get_default_rtol(out_impl_no_weight), + atol=2e-4, + ) + + @pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels", "native"]) + def test_torch_opcheck(self, dtype, n_tokens, hidden_size, epsilon, provider): + if not ir.ops.rms_norm.impls[provider].supported: + pytest.skip(f"{provider} impl not supported on this 
platform") + + x, weight = rms_norm_inputs(n_tokens, hidden_size, dtype) + args = (x, weight, epsilon, None) + + # When checking the torch op, we have to set priority and use dispatch + with ir.ops.rms_norm.set_priority([provider, "native"]): + torch.library.opcheck(torch.ops.vllm_ir.rms_norm, args) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 36d7f5cc4..fc4f6f6b6 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -27,7 +27,6 @@ from vllm.model_executor.layers.layernorm import ( RMSNorm, dispatch_rocm_rmsnorm_func, fused_add_rms_norm, - rms_norm, ) from vllm.platforms import current_platform @@ -156,7 +155,7 @@ def test_topk_sigmoid_dispatch(use_rocm_aiter: bool): assert topk_func == vllm_topk_sigmoid -@pytest.mark.parametrize("add_residual", [True, False]) +@pytest.mark.parametrize("add_residual", [False]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("use_rocm_aiter", [True, False]) @pytest.mark.skipif( @@ -165,7 +164,7 @@ def test_topk_sigmoid_dispatch(use_rocm_aiter: bool): def test_rms_norm_dispatch( add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool ): - rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter) + rms_norm_func = dispatch_rocm_rmsnorm_func(dtype, use_rocm_aiter) should_use_rocm_aiter = ( current_platform.is_rocm() @@ -173,11 +172,7 @@ def test_rms_norm_dispatch( and dtype in RMS_NORM_SUPPORTED_DTYPES ) - if add_residual and should_use_rocm_aiter: + if should_use_rocm_aiter: assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add - elif should_use_rocm_aiter: - assert rms_norm_func == rocm_aiter_ops.rms_norm - elif add_residual: - assert rms_norm_func == fused_add_rms_norm else: - assert rms_norm_func == rms_norm + assert rms_norm_func == fused_add_rms_norm diff --git a/tests/test_config.py b/tests/test_config.py 
index f07a649ca..312f4ce5a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,12 +6,14 @@ import os from dataclasses import MISSING, Field, asdict, dataclass, field from unittest.mock import patch +import pydantic import pytest from pydantic import ValidationError from vllm.compilation.backends import VllmBackend from vllm.config import ( CompilationConfig, + KernelConfig, ModelConfig, ParallelConfig, PoolerConfig, @@ -21,6 +23,7 @@ from vllm.config import ( update_config, ) from vllm.config.compilation import CompilationMode, CUDAGraphMode +from vllm.config.kernel import IrOpPriorityConfig from vllm.config.load import LoadConfig from vllm.config.utils import get_field from vllm.config.vllm import ( @@ -1077,6 +1080,39 @@ def test_vllm_config_explicit_overrides(): assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE +def test_fusion_pass_op_priority(): + """This test checks that custom op enablement & IR op priority + correctly control default fusions""" + + # Default config, O2, rms_norm+quant fusion disabled + cfg1 = VllmConfig() + assert not cfg1.compilation_config.pass_config.fuse_norm_quant + + # rms_norm manually enabled, O1, rms_norm+quant fusion enabled + cfg2 = VllmConfig( + optimization_level=OptimizationLevel.O1, + compilation_config=CompilationConfig( + custom_ops=["+rms_norm"], + ), + ) + assert cfg2.compilation_config.pass_config.fuse_norm_quant + + # using custom kernel for RMSNorm via IR: + # Note that vLLM IR only supports the non-residual rms_norm for now; + # soon this will be resolved. 
+ cfg3 = VllmConfig( + kernel_config=KernelConfig( + ir_op_priority=IrOpPriorityConfig(rms_norm=["vllm_c"]) + ) + ) + assert cfg3.compilation_config.pass_config.fuse_norm_quant + + # block-fp8 model should enable quant_fp8 automatically + cfg4 = VllmConfig(model_config=ModelConfig("Qwen/Qwen3-4B-FP8")) + assert "+quant_fp8" in cfg4.compilation_config.custom_ops + assert cfg4.compilation_config.pass_config.fuse_norm_quant + + def test_scheduler_config_init(): with pytest.raises(ValidationError): # Positional InitVars missing @@ -1171,3 +1207,35 @@ def test_eagle_draft_model_config(): assert draft_model_config.hf_text_config.model_type == "eagle" assert draft_model_config.architectures == ["EagleLlamaForCausalLM"] assert draft_model_config.architecture == "EagleLlamaForCausalLM" + + +def test_ir_op_priority_default(): + """Test that IR op priority defaults are set correctly.""" + from vllm.config.kernel import IrOpPriorityConfig + + # Assert default is applied to ops + priority_config = IrOpPriorityConfig.with_default(["vllm_c", "native"]) + assert priority_config.rms_norm == ["vllm_c", "native"] + + # Assert single ops override the default + assert IrOpPriorityConfig.with_default( + ["vllm_c", "native"], rms_norm=["oink", "native"] + ) == IrOpPriorityConfig(rms_norm=["oink", "native"]) + + +def test_ir_op_priority_str(): + """Test that passing a comma-delimited string works""" + from vllm.config.kernel import IrOpPriorityConfig + + priority_config = IrOpPriorityConfig(rms_norm="vllm_c") + assert priority_config.rms_norm == ["vllm_c"] + + priority_config = IrOpPriorityConfig(rms_norm="vllm_c,native") + assert priority_config.rms_norm == ["vllm_c", "native"] + + priority_config = IrOpPriorityConfig(rms_norm=" native, vllm_c ") + assert priority_config.rms_norm == ["native", "vllm_c"] + + with pytest.raises(pydantic.ValidationError): + # must be list of only strings + priority_config = IrOpPriorityConfig(rms_norm=["vllm_c", 4, "native"]) diff --git 
a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index d55b30599..86cdd7c5e 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -3,6 +3,7 @@ import contextlib from importlib.util import find_spec from types import ModuleType +from typing import Any import torch import torch._inductor.pattern_matcher as pm @@ -10,6 +11,7 @@ import torch.fx as fx from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass +import vllm.ir.ops from vllm.config import VllmConfig from vllm.config.utils import Range from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce @@ -28,7 +30,7 @@ from vllm.utils.torch_utils import ( from ..inductor_pass import enable_fake_mode from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass -from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm +from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8 FP8_DTYPE = current_platform.fp8_dtype() @@ -258,6 +260,12 @@ class BasePattern: self.tp = get_tp_group() self.tp_size = get_tensor_model_parallel_world_size() + def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor: + return torch.empty(*args, dtype=self.dtype, device=self.device, **kwargs) + + def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor: + return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs) + class AllReduceRMSNormPattern(BasePattern): """ @@ -276,20 +284,17 @@ class AllReduceRMSNormPattern(BasePattern): super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params - self.rmsnorm_matcher = MatcherRMSNorm(epsilon) def get_inputs(self) -> list[torch.Tensor]: - input, weight = self.rmsnorm_matcher.inputs() - - # input goes through allreduce first, always 16-bit - return 
[input.to(self.dtype), weight] + # input, weight + return [self.empty(5, 16), self.empty(16)] def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( input: torch.Tensor, weight: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: allreduce_output = tensor_model_parallel_all_reduce(input) - rms = self.rmsnorm_matcher(allreduce_output, weight) + rms = vllm.ir.ops.rms_norm(allreduce_output, weight, self.epsilon) return rms, allreduce_output @@ -407,15 +412,13 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): self.epsilon = epsilon self.allreduce_params = allreduce_params self.quant_dtype = torch.float8_e4m3fn - self.rmsnorm_matcher = MatcherRMSNorm(epsilon) self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) def get_inputs(self) -> list[torch.Tensor]: - input, weight = self.rmsnorm_matcher.inputs() _, scale = self.quant_matcher.inputs() - # input goes through allreduce first, always 16-bit - return [input.to(self.dtype), weight, scale] + # input, weight + return [self.empty(5, 16), self.empty(16), scale] def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( @@ -424,7 +427,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = tensor_model_parallel_all_reduce(input) - rms = self.rmsnorm_matcher(all_reduce, weight) + rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon) quant, _ = self.quant_matcher(rms, scale) return quant, all_reduce @@ -553,7 +556,6 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params - self.rmsnorm_matcher = MatcherRMSNorm(epsilon) def get_inputs(self) -> list[torch.Tensor]: input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype) @@ -575,7 +577,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): output_scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, 
torch.Tensor]: all_reduce = tensor_model_parallel_all_reduce(input) - rms = self.rmsnorm_matcher(all_reduce, weight) + rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon) quant_out_tuple = auto_functionalized( STATIC_FP4_QUANT_OP, input=rms, diff --git a/vllm/compilation/passes/fusion/matcher_utils.py b/vllm/compilation/passes/fusion/matcher_utils.py index ec36c12d1..c2490d8a2 100644 --- a/vllm/compilation/passes/fusion/matcher_utils.py +++ b/vllm/compilation/passes/fusion/matcher_utils.py @@ -26,7 +26,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform -RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default ROTARY_OP = torch.ops._C.rotary_embedding.default FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default @@ -160,69 +159,6 @@ class MatcherRotaryEmbedding(MatcherCustomOp): return result -class MatcherRMSNorm(MatcherCustomOp): - def __init__( - self, - epsilon: float, - enabled: bool | None = None, - match_rocm_aiter: bool = False, - ) -> None: - if enabled is None: - enabled = RMSNorm.enabled() - - super().__init__(enabled) - self.epsilon = epsilon - self._rmsnorm_op = RMS_OP - self.match_rocm_aiter = match_rocm_aiter - - if match_rocm_aiter: - self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op() - - def inputs(self) -> list[torch.Tensor]: - input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) - weight = self.empty(16) - return [input, weight] - - def forward_rocm_aiter( - self, - input: torch.Tensor, - weight: torch.Tensor, - ) -> torch.Tensor: - return self._rmsnorm_op( - x=input, - weight=weight, - variance_epsilon=self.epsilon, - ) - - def forward_custom( - self, - input: torch.Tensor, - weight: torch.Tensor, - ) -> torch.Tensor: - if self.match_rocm_aiter: - return self.forward_rocm_aiter(input, weight) - - result = torch.empty_like(input) 
- _, result = auto_functionalized( - self._rmsnorm_op, - result=result, - input=input, - weight=weight, - epsilon=self.epsilon, - ) - - return result - - def forward_native( - self, - input: torch.Tensor, - weight: torch.Tensor, - ) -> torch.Tensor: - return RMSNorm.forward_static( - input, self.epsilon, input.size(-1), self.model_dtype, weight - ) - - class MatcherFusedAddRMSNorm(MatcherCustomOp): def __init__( self, diff --git a/vllm/compilation/passes/fusion/qk_norm_rope_fusion.py b/vllm/compilation/passes/fusion/qk_norm_rope_fusion.py index dd1f8245e..245119fa5 100644 --- a/vllm/compilation/passes/fusion/qk_norm_rope_fusion.py +++ b/vllm/compilation/passes/fusion/qk_norm_rope_fusion.py @@ -10,6 +10,7 @@ from torch import fx from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass +import vllm.ir.ops from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention @@ -17,7 +18,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from ..inductor_pass import enable_fake_mode from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass -from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding +from .matcher_utils import MatcherRotaryEmbedding from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64 logger = init_logger(__name__) @@ -64,7 +65,6 @@ class QkNormRopePattern: self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.eps = eps - self.rmsnorm_matcher = MatcherRMSNorm(eps) self.is_neox = is_neox self.rope_flashinfer = rope_flashinfer self.rope_matcher = MatcherRotaryEmbedding( @@ -129,14 +129,14 @@ class QkNormRopePattern: q_by_head = q.view( *q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim ) - q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight) + q_normed_by_head = 
vllm.ir.ops.rms_norm(q_by_head, q_weight, self.eps) q_flat = q_normed_by_head.view(q.shape) # K path: view -> RMS -> view back to k.shape k_by_head = k.view( *k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim ) - k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight) + k_normed_by_head = vllm.ir.ops.rms_norm(k_by_head, k_weight, self.eps) k_flat = k_normed_by_head.view(k.shape) # RoPE: apply to flattened q/k diff --git a/vllm/compilation/passes/fusion/rms_quant_fusion.py b/vllm/compilation/passes/fusion/rms_quant_fusion.py index 95ce7b22e..0e5121c78 100644 --- a/vllm/compilation/passes/fusion/rms_quant_fusion.py +++ b/vllm/compilation/passes/fusion/rms_quant_fusion.py @@ -9,6 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass from torch._ops import OpOverload +import vllm.ir.ops from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -30,7 +31,6 @@ from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass from .matcher_utils import ( MatcherFusedAddRMSNorm, MatcherQuantFP8, - MatcherRMSNorm, ) logger = init_logger(__name__) @@ -54,7 +54,6 @@ def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor: return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda") -RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default QUANT_OPS: dict[QuantKey, OpOverload] = { @@ -131,11 +130,9 @@ class RMSNormQuantPattern: assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] - self.rmsnorm_matcher = ( - MatcherRMSNorm(epsilon) - if not key.fused_add - else MatcherFusedAddRMSNorm(epsilon) - ) + if key.fused_add: + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) + self.quant_matcher = MatcherQuantFP8( key.quant, 
has_col_major_scales=has_col_major_scales, @@ -161,16 +158,12 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern): def pattern( input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor ) -> torch.Tensor: - result_rms = self.rmsnorm_matcher(input, weight) + result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon) return self.quant_matcher(result_rms, scale)[0] def replacement( input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor ) -> torch.Tensor: - # In case we're matching native rms-norm, conversions might be - # optimized out. We convert here just to be safe. - input = input.to(dtype=self.model_dtype) - result = torch.empty( input.shape, device=input.device, dtype=self.quant_dtype ) @@ -187,8 +180,8 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern): return at[1] inputs = [ - # input, weight - *self.rmsnorm_matcher.inputs(), + empty_bf16(5, 16), # input + empty_bf16(16), # weight self.quant_matcher.inputs()[1], # scale ] pattern(*inputs) @@ -391,7 +384,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern): def pattern( input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: - result_rms = self.rmsnorm_matcher(input, weight) + result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon) result = torch.empty( result_rms.shape, device=result_rms.device, @@ -442,12 +435,14 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern): # result, scale return at[1], at[2] - scale = self.quant_matcher.empty_f32(1, 1) - pm.register_replacement( pattern, replacement, - self.rmsnorm_matcher.inputs() + [scale], + [ + empty_bf16(5, 16), # input + empty_bf16(16), # weight + self.quant_matcher.empty_f32(1, 1), # scale + ], pm.fwd_only, pm_pass, ) @@ -472,7 +467,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern): def pattern( input: torch.Tensor, weight: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: - result_rms = self.rmsnorm_matcher(input, weight) + result_rms = 
vllm.ir.ops.rms_norm(input, weight, self.epsilon) # result, scale return self.quant_matcher(result_rms) # type: ignore[no-any-return] @@ -502,7 +497,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern): pm.register_replacement( pattern, replacement, - self.rmsnorm_matcher.inputs(), + [ + empty_bf16(5, 16), # input + empty_bf16(16), # weight + ], pm.fwd_only, pm_pass, ) diff --git a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py index 59c94db5e..9a9854723 100644 --- a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py +++ b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any import torch import torch._inductor.pattern_matcher as pm @@ -24,7 +25,6 @@ from .act_quant_fusion import ActivationQuantPattern from .matcher_utils import ( MatcherFusedAddRMSNorm, MatcherQuantFP8, - MatcherRMSNorm, MatcherSiluAndMul, ) from .rms_quant_fusion import ( @@ -41,17 +41,23 @@ class AiterRMSNormQuantPattern: ): self.epsilon = epsilon self.quant_dtype = key.quant.dtype + self.device = torch.device("cuda") - self.rmsnorm_matcher = ( - MatcherRMSNorm(epsilon, match_rocm_aiter=True) - if not key.fused_add - else MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True) - ) + if key.fused_add: + self.rmsnorm_matcher = MatcherFusedAddRMSNorm( + epsilon, match_rocm_aiter=True + ) self.quant_matcher = MatcherQuantFP8( key.quant, match_rocm_aiter=match_aiter_quant, ) + def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor: + return torch.empty(*args, dtype=torch.bfloat16, device=self.device, **kwargs) + + def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor: + return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs) + class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): """AITER RMSNorm + Dynamic Quantization pattern.""" @@ -79,7 +85,7 @@ 
class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): input: torch.Tensor, weight: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - result_rms = self.rmsnorm_matcher(input, weight) + result_rms = torch.ops.vllm_ir.rms_norm(input, weight, self.epsilon) result, scale = self.quant_matcher(result_rms) return result, scale @@ -99,7 +105,8 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): pm.register_replacement( pattern, replacement, - self.rmsnorm_matcher.inputs(), + # input, weight + [self.empty(5, 16), self.empty(16)], pm.fwd_only, pm_pass, ) @@ -188,7 +195,7 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): input: torch.Tensor, weight: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - result_rms = self.rmsnorm_matcher(input, weight) + result_rms = torch.ops.vllm_ir.rms_norm(input, weight, self.epsilon) result, scale = self.quant_matcher(result_rms) return result, scale @@ -206,7 +213,12 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): return at[0], at[1] pm.register_replacement( - pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass + pattern, + replacement, + # input, weight + [self.empty(5, 16), self.empty(16)], + pm.fwd_only, + pm_pass, ) diff --git a/vllm/compilation/passes/fusion/sequence_parallelism.py b/vllm/compilation/passes/fusion/sequence_parallelism.py index b7ae3dc62..e3cfb8a48 100644 --- a/vllm/compilation/passes/fusion/sequence_parallelism.py +++ b/vllm/compilation/passes/fusion/sequence_parallelism.py @@ -10,6 +10,7 @@ import torch._inductor.pattern_matcher as pm import torch.fx as fx from torch._inductor.pattern_matcher import PatternMatcherPass +import vllm.ir.ops from vllm.config import VllmConfig from vllm.config.utils import Range from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce @@ -22,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from ..inductor_pass import enable_fake_mode from 
..utility.noop_elimination import NoOpEliminationPass from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass -from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm +from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8 logger = init_logger(__name__) @@ -122,35 +123,38 @@ class _SequenceParallelPatternHelper: x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name ) + def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor: + return torch.empty(*args, dtype=self.dtype, device=self.device, **kwargs) + + def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor: + return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs) + class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None: super().__init__(epsilon, dtype, device) - self.rmsnorm_matcher = MatcherRMSNorm(epsilon) def get_inputs(self) -> list[torch.Tensor]: - input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) - arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype) - - return [input, arg3_1] + # input, weight + return [self.empty([1, 8, 4]), self.empty([4])] def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( input: torch.Tensor, - arg3_1: torch.Tensor, + weight: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = self._all_reduce(input) - rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1) + rmsnorm = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon) return rmsnorm, all_reduce def replacement( input: torch.Tensor, - arg3_1: torch.Tensor, + weight: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: reduce_scatter = self._reduce_scatter(input) - rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1) + rmsnorm = vllm.ir.ops.rms_norm(reduce_scatter, weight, self.epsilon) all_gather = self._all_gather(rmsnorm) return all_gather, reduce_scatter @@ -222,14 
+226,11 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): device: str | None, ) -> None: super().__init__(epsilon, dtype, device) - self.rmsnorm_matcher = MatcherRMSNorm(epsilon) self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) def get_inputs(self) -> list[torch.Tensor]: - input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) - weight = torch.empty([4], device=self.device, dtype=self.dtype) - scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) - return [input, weight, scale] + # input, weight, scale + return [self.empty([1, 8, 4]), self.empty([4]), self.empty_f32([1, 1])] def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( @@ -238,7 +239,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = self._all_reduce(input) - rms = self.rmsnorm_matcher(all_reduce, weight) + rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon) quant, _ = self.quant_matcher(rms, scale) return quant, all_reduce @@ -248,7 +249,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: reduce_scatter = self._reduce_scatter(input) - rms = self.rmsnorm_matcher(reduce_scatter, weight) + rms = vllm.ir.ops.rms_norm(reduce_scatter, weight, self.epsilon) quant, _ = self.quant_matcher(rms, scale) all_gather = self._all_gather(quant) diff --git a/vllm/compilation/passes/ir/__init__.py b/vllm/compilation/passes/ir/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/vllm/compilation/passes/ir/lowering_pass.py b/vllm/compilation/passes/ir/lowering_pass.py new file mode 100644 index 000000000..474a09ae2 --- /dev/null +++ b/vllm/compilation/passes/ir/lowering_pass.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections 
import defaultdict +from collections.abc import Iterable + +from torch import fx +from torch._inductor.pattern_matcher import ( + CallFunctionVarArgs, + Match, + PatternMatcherPass, + register_graph_pattern, +) +from torch._ops import OpOverload, OpOverloadPacket + +from vllm.config import VllmConfig +from vllm.ir.op import IrOp +from vllm.logger import init_logger +from vllm.logging_utils import lazy + +from ..vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +def get_default_overload(op: OpOverload | OpOverloadPacket) -> OpOverload: + if isinstance(op, OpOverloadPacket): + return op.default + assert isinstance(op, OpOverload), "Expected an OpOverload or OpOverloadPacket" + return op + + +def get_ir_op(node: fx.Node) -> IrOp | None: + if node.op != "call_function": + return None + + if not isinstance(node.target, (OpOverload, OpOverloadPacket)): + return None + + op_overload = get_default_overload(node.target) + if op_overload.namespace != "vllm_ir": + return None + + op_name = op_overload._opname + if op_name not in IrOp.registry: + logger.warning( + "Unknown vLLM IR op %s, there's likely an issue with torch registration, " + "or a torch custom op was registered in the vllm_ir namespace by mistake.", + op_name, + ) + return None + + ir_op = IrOp.registry[op_name] + return ir_op + + +class VllmIRLoweringPass(VllmInductorPass): + """ + This pass lowers vLLM IR ops to their implementations the priority list. + """ + + def __init__(self, vllm_config: VllmConfig) -> None: + super().__init__(vllm_config) + self.patterns = PatternMatcherPass(self.pass_name) + self.selected_impls: dict[str, dict[str, str]] = defaultdict(lambda: {}) + self.ops = [ir_op.torch_op for ir_op in IrOp.registry.values()] + + # Look for any call_function node where the target is a vLLM IR op. + # Then, lower_matched_op will select, trace, and insert the implementation. 
+ register_graph_pattern( + CallFunctionVarArgs(self.ops), + pass_dict=self.patterns, + )(self.lower_matched_op) + + def lower_matched_op(self, match: Match, *args, **kwargs): + # TODO(luka) I think args and kwargs are for the match, but just use the node? + + assert len(match.nodes) == 1, "Expected single node match" + node = match.nodes[0] + ir_op = get_ir_op(node) + assert ir_op is not None, "Expected vLLM IR op" + assert not node.kwargs # I think there should never be kwargs here + + # Select and record the implementation, using fake args + fake_args = fx.map_arg(node.args, lambda arg: arg.meta["val"]) + ir_op_impl = ir_op.dispatch(*fake_args) + self.selected_impls[ir_op.name][node.name] = ir_op_impl.provider + + # replace_by_example wants node args, not the fake tensors + # TODO(luka): Use aot_export_module to get functionalized graph + # TODO(luka): Cache the fx_replacement to avoid re-tracing the same impl + + # Defaults not present on node.args but required for replacement tracing + bound_args = ir_op._py_signature.bind(*node.args) + bound_args.apply_defaults() + match.replace_by_example(ir_op_impl.impl_fn, bound_args.args) + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph) -> None: + # clear at the beginning instead of end, so that tests can inspect + self.selected_impls.clear() + + count = self.patterns.apply(graph) + logger.debug("VllmIRLoweringPass lowered %d vLLM IR nodes", count) + + # TODO write self.selected_impls to depyf/tlparse dir + def count_items(impls: Iterable[str]) -> dict[str, int]: + counts: dict[str, int] = defaultdict(lambda: 0) + for impl in impls: + counts[impl] += 1 + return counts + + def print_count(counts: dict[str, int]) -> str: + # e.g., "impl1*3,impl2" + impl_count = lambda i, c: f"{i}" if c == 1 else f"{i}*{c}" + return ",".join(impl_count(impl, count) for impl, count in counts.items()) + + logger.debug( + "Selected implementations: %s", + lazy( + lambda: ", ".join( + 
f"{op}={print_count(count_items(impls_by_node.values()))}" + for op, impls_by_node in self.selected_impls.items() + ) + ), + ) + + failed_nodes: list[fx.Node] = [] + failed_ops: set[str] = set() + # Check no vllm_ir nodes were left in the graph + for node in graph.nodes: + if (ir_op := get_ir_op(node)) is None: + continue + + failed_nodes.append(node) + failed_ops.add(ir_op.name) + + if failed_nodes or failed_ops: + logger.warning("Failed to lower vLLM IR ops: %s", ",".join(failed_ops)) + logger.warning("Full node list: %s", failed_nodes) + + def uuid(self) -> str: + """ + IR op priority & impl sources affect lowering pass output, + so we include them in the cache key. + """ + priorities = {name: op.get_priority() for name, op in IrOp.registry.items()} + priorities_str = ";".join( + f"{name}={','.join(p)}" for name, p in priorities.items() + ) + + impl_uuids_str = ";".join( + f"{name}={ + ','.join(IrOp.registry[name].impls[provider].uuid() for provider in p) + }" + for name, p in priorities.items() + ) + + return f"{super().uuid()}|{priorities_str}|{impl_uuids_str}" diff --git a/vllm/compilation/passes/pass_manager.py b/vllm/compilation/passes/pass_manager.py index 5f75fc8db..057174141 100644 --- a/vllm/compilation/passes/pass_manager.py +++ b/vllm/compilation/passes/pass_manager.py @@ -14,6 +14,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.system_utils import set_env_var +from .ir.lowering_pass import VllmIRLoweringPass from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass if rocm_aiter_ops.is_enabled(): @@ -99,8 +100,17 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc] else: logger.debug("Skipping %s with compile range %s", pass_, compile_range) - # post-cleanup goes before fix_functionalization - # because it requires a functional graph + # perform the first post-cleanup before IR lowering to clean up fusion artifacts + # and make sure no dead IR ops are lowered. 
+ self.post_cleanup(graph) + VllmInductorPass.dump_prefix += 1 + + # lowering before cleanup so DCE can clean up lowered ops. + # DCE handles mutating ops correctly as well. + self.ir_lowering(graph) + VllmInductorPass.dump_prefix += 1 + + # clean up after lowering again self.post_cleanup(graph) VllmInductorPass.dump_prefix += 1 @@ -152,7 +162,7 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc] self.passes += [SplitCoalescingPass(config)] self.passes += [QKNormRoPEFusionPass(config)] - # needs a functional graph + self.ir_lowering = VllmIRLoweringPass(config) self.post_cleanup = PostCleanupPass(config) self.fix_functionalization = FixFunctionalizationPass(config) @@ -171,6 +181,10 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc] state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()} for pass_ in self.passes: passes.append(pass_.uuid()) + + passes.append(self.post_cleanup.uuid()) + passes.append(self.ir_lowering.uuid()) + passes.append(self.post_cleanup.uuid()) passes.append(self.fix_functionalization.uuid()) # Include the compile range in the uuid to ensure that inductor diff --git a/vllm/compilation/passes/vllm_inductor_pass.py b/vllm/compilation/passes/vllm_inductor_pass.py index 4eac620d1..46c3fe770 100644 --- a/vllm/compilation/passes/vllm_inductor_pass.py +++ b/vllm/compilation/passes/vllm_inductor_pass.py @@ -152,6 +152,7 @@ class VllmPatternMatcherPass(VllmInductorPass): f"auto_functionalized as auto_functionalized\n" f"from torch._inductor.pattern_matcher import *\n" - f"vllm = torch.ops.vllm", + f"vllm = torch.ops.vllm\n" + f"vllm_ir = torch.ops.vllm_ir", file=f, ) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 1d09e2b7d..916c5a002 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -466,6 +466,15 @@ class CompilationConfig: disabled when running with Inductor: mode>CompilationMode.NONE and backend="inductor".
Inductor generates (fused) Triton kernels for disabled custom ops.""" + + ir_enable_torch_wrap: bool = None # type: ignore[assignment] + """If True, enable vllm_ir torch custom op wrapping during the forward pass. + When False, torch custom op wrapping is disabled, allowing Dynamo to trace the + selected implementation directly or avoiding torch custom op overhead in eager mode. + Defaults to True when using Inductor with vllm-compile + (backend=="inductor" and mode == VLLM_COMPILE), False otherwise. + """ + splitting_ops: list[str] | None = None """A list of ops to exclude from cudagraphs, used in piecewise compilation. @@ -830,6 +839,7 @@ class CompilationConfig: "cudagraph_mode", "max_cudagraph_capture_size", "use_inductor_graph_partition", + "ir_enable_torch_wrap", mode="wrap", ) @classmethod diff --git a/vllm/config/kernel.py b/vllm/config/kernel.py index 4476cd125..f8494111d 100644 --- a/vllm/config/kernel.py +++ b/vllm/config/kernel.py @@ -1,13 +1,106 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import contextlib from collections.abc import Callable -from typing import Any, Literal +from dataclasses import asdict, fields +from typing import TYPE_CHECKING, Any, Literal -from pydantic import field_validator +from pydantic import Field, field_validator + +from vllm.config.utils import config, get_hash_factors, hash_factors +from vllm.logger import init_logger + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +@config +class IrOpPriorityConfig: + """ + Configuration for vLLM IR op priority for dispatching/lowering during the + forward pass. Each member is a list of strings, which will be passed to + vllm.ir.ops..set_priority() for the duration of the forward pass. + A single comma-separated string is accepted as well, + + If specified manually, platform defaults will be appended to the lists. + See KernelConfig.set_platform_defaults(). 
+ """ + + rms_norm: list[str] = Field(default_factory=list) + """Priority list for vllm.ir.ops.rms_norm""" + + def compute_hash(self) -> str: + """ + Produces a hash unique to the pass configuration. + Any new fields that affect compilation should be added to the hash. + Any future fields that don't affect compilation should be excluded. + + Also, manually add IR op impl UUIDs to make sure they affect the compile cache. + """ + factors = get_hash_factors(self, set()) + + # Implementations are hidden from Dynamo, + # so they don't show up in the traced files list. + from vllm.ir.op import IrOp + + assert "_impls" not in factors + factors["_impls"] = { + name: { + provider: IrOp.registry[name].impls[provider].uuid() for provider in p + } + for name, p in asdict(self).items() + } + + return hash_factors(factors) + + @field_validator("*", mode="before") + @classmethod + def _to_list_str(cls, value: str | list[str]): + if isinstance(value, str): + value = value.replace(" ", "").split(",") + + assert all(isinstance(v, str) for v in value) + return value + + @contextlib.contextmanager + def set_priority(self): + """ + Context manager to set the IR op priority for all op members. + It also imports vllm.kernels to ensure all implementations are made available. + """ + import vllm.kernels # noqa: F401, registers IR op implementations + from vllm.ir.op import IrOp + + with contextlib.ExitStack() as stack: + for field in fields(self): + op_priority = getattr(self, field.name) + assert op_priority is not None, ( + f"IR op priority for {field.name} must be set" + ) + logger.debug( + "Setting IR op priority for %s to %s", field.name, op_priority + ) + ir_op = IrOp.registry[field.name] + stack.enter_context(ir_op.set_priority(op_priority)) + + yield + + @classmethod + def with_default( + cls, default: list[str], /, **kwargs: list[str] + ) -> "IrOpPriorityConfig": + """ + A helper to create an IrOpPriorityConfig where fields not specified in kwargs + use the given default list. 
+ """ + for field in fields(cls): + if field.name not in kwargs: + kwargs[field.name] = list(default) + + return cls(**kwargs) -from vllm.config.utils import config -from vllm.utils.hashing import safe_hash MoEBackend = Literal[ "auto", @@ -26,6 +119,12 @@ MoEBackend = Literal[ class KernelConfig: """Configuration for kernel selection and warmup behavior.""" + ir_op_priority: IrOpPriorityConfig = Field(default_factory=IrOpPriorityConfig) + """ + vLLM IR op priority for dispatching/lowering during the forward pass. + Platform defaults appended automatically during VllmConfig.__post_init__. + """ + enable_flashinfer_autotune: bool = None # type: ignore[assignment] """If True, run FlashInfer autotuning during kernel warmup.""" @@ -51,21 +150,17 @@ class KernelConfig: def compute_hash(self) -> str: """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. + Produces a hash unique to the pass configuration. + Any new fields that affect compilation should be added to the hash. + Any future fields that don't affect compilation should be excluded. """ - # no factors to consider. - # this config will not affect the computation graph. 
- factors: list[Any] = [] - hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() - return hash_str + ignored_factors = { + "enable_flashinfer_autotune", + "ir_op_priority", # handled separately below + } + factors = get_hash_factors(self, ignored_factors) + factors["ir_op_priority"] = self.ir_op_priority.compute_hash() + return hash_factors(factors) @field_validator("enable_flashinfer_autotune", mode="wrap") @classmethod @@ -74,3 +169,31 @@ class KernelConfig: if value is None: return value return handler(value) + + def set_platform_defaults(self, vllm_config: "VllmConfig") -> None: + """Set platform-specific defaults for the kernel config.""" + from vllm.platforms import current_platform + + platform_op_priority = current_platform.get_default_ir_op_priority(vllm_config) + logger.debug( + "Setting platform-specific IR op priority defaults: %s, user-defined: %s", + platform_op_priority, + self.ir_op_priority, + ) + for op_name, op_priority in asdict(platform_op_priority).items(): + current_op_priority: list[str] = getattr(self.ir_op_priority, op_name) + if current_op_priority is None: + setattr(self.ir_op_priority, op_name, op_priority) + else: + # Append platform-specific priorities + # Must be idempotent because vllm_config.set_platform_defaults() may be + # called multiple times (due to VllmConfig.__post_init__ manual call). 
+ unique_op_priority = [ + op for op in op_priority if op not in current_op_priority + ] + current_op_priority.extend(unique_op_priority) + + logger.info( + "Final IR op priority after setting platform defaults: %s", + self.ir_op_priority, + ) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 5e7b3fcd5..fad3e0ed2 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -95,9 +95,11 @@ def enable_norm_fusion(cfg: "VllmConfig") -> bool: """Enable if either RMS norm or quant FP8 custom op is active; otherwise Inductor handles fusion.""" - return cfg.compilation_config.is_custom_op_enabled( - "rms_norm" - ) or cfg.compilation_config.is_custom_op_enabled("quant_fp8") + return ( + cfg.compilation_config.is_custom_op_enabled("rms_norm") + or cfg.compilation_config.is_custom_op_enabled("quant_fp8") + or cfg.kernel_config.ir_op_priority.rms_norm[0] != "native" + ) def enable_act_fusion(cfg: "VllmConfig") -> bool: @@ -417,6 +419,10 @@ class VllmConfig: vllm_factors.append(self.compilation_config.compute_hash()) else: vllm_factors.append("None") + if self.kernel_config: + vllm_factors.append(self.kernel_config.compute_hash()) + else: + vllm_factors.append("None") if self.kv_transfer_config: vllm_factors.append(self.kv_transfer_config.compute_hash()) else: @@ -890,6 +896,13 @@ class VllmConfig: else: self.compilation_config.mode = CompilationMode.NONE + # By default, enable torch wrapping only when using custom Inductor lowering + if self.compilation_config.ir_enable_torch_wrap is None: + self.compilation_config.ir_enable_torch_wrap = ( + self.compilation_config.mode == CompilationMode.VLLM_COMPILE + and self.compilation_config.backend == "inductor" + ) + if all(s not in self.compilation_config.custom_ops for s in ("all", "none")): if ( self.compilation_config.backend == "inductor" @@ -899,6 +912,11 @@ class VllmConfig: else: self.compilation_config.custom_ops.append("all") + # This populates IR op priorities, + # must happen after compilation mode and backend
are decided, + # but before fusion defaults are applied as those may depend on op priority. + self.kernel_config.set_platform_defaults(self) + default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level] self._apply_optimization_level_defaults(default_config) if self.kernel_config.enable_flashinfer_autotune is None: @@ -1706,7 +1724,8 @@ class VllmConfig: f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " f"enable_chunked_prefill={self.scheduler_config.enable_chunked_prefill}, " # noqa f"pooler_config={self.model_config.pooler_config!r}, " - f"compilation_config={self.compilation_config!r}" + f"compilation_config={self.compilation_config!r}, " + f"kernel_config={self.kernel_config!r}" ) def validate_block_size(self) -> None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0c9cf2ae9..d498135ce 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -8,7 +8,7 @@ import functools import json import sys from collections.abc import Callable -from dataclasses import MISSING, dataclass, fields, is_dataclass +from dataclasses import MISSING, asdict, dataclass, fields, is_dataclass from itertools import permutations from types import UnionType from typing import ( @@ -70,7 +70,7 @@ from vllm.config.cache import ( PrefixCachingHashAlgo, ) from vllm.config.device import Device -from vllm.config.kernel import MoEBackend +from vllm.config.kernel import IrOpPriorityConfig, MoEBackend from vllm.config.lora import MaxLoRARanks from vllm.config.model import ( ConvertOption, @@ -401,6 +401,7 @@ class EngineArgs: max_cudagraph_capture_size: int | None = get_field( CompilationConfig, "max_cudagraph_capture_size" ) + ir_op_priority: IrOpPriorityConfig = get_field(KernelConfig, "ir_op_priority") # Note: Specifying a custom executor backend by passing a class # is intended for expert use only. The API may change without # notice. 
@@ -657,6 +658,9 @@ class EngineArgs: self.weight_transfer_config = WeightTransferConfig( **self.weight_transfer_config ) + if isinstance(self.ir_op_priority, dict): + self.ir_op_priority = IrOpPriorityConfig(**self.ir_op_priority) + # Setup plugins from vllm.plugins import load_general_plugins @@ -1293,6 +1297,7 @@ class EngineArgs: title="KernelConfig", description=KernelConfig.__doc__, ) + kernel_group.add_argument("--ir-op-priority", **kernel_kwargs["ir_op_priority"]) kernel_group.add_argument( "--enable-flashinfer-autotune", **kernel_kwargs["enable_flashinfer_autotune"], @@ -1917,6 +1922,22 @@ class EngineArgs: if self.moe_backend != "auto": kernel_config.moe_backend = self.moe_backend + # Transfer top-level ir_op_priority into KernelConfig.ir_op_priority + for op_name, op_priority in asdict(self.ir_op_priority).items(): + # Empty means unset + if not op_priority: + continue + + # Priority cannot be set 2x for the same op + if getattr(kernel_config.ir_op_priority, op_name): + raise ValueError( + f"Op priority for {op_name} specified via both ir_op_priority " + f"and KernelConfig.ir_op_priority, only one allowed at a time." 
+ ) + + # Set the attribute + setattr(kernel_config.ir_op_priority, op_name, op_priority) + load_config = self.create_load_config() # Pass reasoning_parser into StructuredOutputsConfig diff --git a/vllm/forward_context.py b/vllm/forward_context.py index a7aaeff4f..fa568c33f 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -10,6 +10,7 @@ from typing import Any import torch import vllm.envs as envs +import vllm.ir from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform @@ -378,7 +379,13 @@ def set_forward_context( ) try: - with override_forward_context(forward_context): + with ( + override_forward_context(forward_context), + vllm_config.kernel_config.ir_op_priority.set_priority(), + vllm.ir.enable_torch_wrap( + vllm_config.compilation_config.ir_enable_torch_wrap + ), + ): yield finally: global last_logging_time, batchsize_logging_interval diff --git a/vllm/ir/__init__.py b/vllm/ir/__init__.py new file mode 100644 index 000000000..cef8df115 --- /dev/null +++ b/vllm/ir/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from . 
import ops +from .op import enable_torch_wrap, register_op + +__all__ = ["enable_torch_wrap", "register_op", "ops"] diff --git a/vllm/ir/op.py b/vllm/ir/op.py new file mode 100644 index 000000000..1cbd78e28 --- /dev/null +++ b/vllm/ir/op.py @@ -0,0 +1,414 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import inspect +from collections.abc import Callable +from pathlib import Path +from typing import Any, ClassVar, overload + +import torch +from torch.library import Library, infer_schema + +from vllm.ir.util import hash_source, weak_cache +from vllm.logger import init_logger +from vllm.logging_utils import lazy, tensors_str_no_data + +vllm_ir_lib = Library("vllm_ir", "FRAGMENT") + +logger = init_logger(__name__) + +RESERVED_PROVIDERS = ["native", "unfused"] +"""Providers that are reserved and cannot be used for custom implementations.""" + +_ENABLE_TORCH_WRAP: bool = True +"""Global override flag to control torch op layer wrapping.""" + + +@contextlib.contextmanager +def enable_torch_wrap(enable: bool = True): + """ + Context manager to enable/disable torch custom op wrapping for vLLM IR ops. + When torch wrapping is disabled, the torch custom op layer is skipped + and IR ops dispatch directly to the implementation. + Helpful for avoiding torch dispatch overhead in eager mode + and avoiding the need for lowering for platforms not using Inductor. + """ + + global _ENABLE_TORCH_WRAP + old = _ENABLE_TORCH_WRAP + try: + _ENABLE_TORCH_WRAP = enable + yield + finally: + _ENABLE_TORCH_WRAP = old + + +# 0-param decorator overload +@overload +def register_op(f: Callable[..., Any]) -> "IrOp": ... + + +# parametrized decorator overload +@overload +def register_op( + *, + name: str | None = None, +) -> Callable[[Callable[..., Any]], "IrOp"]: ... 
+ + +def register_op( + f: Callable | None = None, + *, + name: str | None = None, +) -> "IrOp | Callable[[Callable], IrOp]": + """ + Register a new vLLM IR op. + + :param f: the native implementation of the op + :param name: the name of the op, defaults to the function name + :return: the IrOp object if f is provided, otherwise a decorator + + Example usage: + ```python + @vllm.ir.register_op + def my_op(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + + @vllm.ir.register_op(name="custom_mul") + def multiply(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x * y""" + + def decorator(_f: Callable): + op_name: str = _f.__name__ if name is None else name + assert op_name not in IrOp.registry + op = IrOp(op_name, _f) + IrOp.registry[op_name] = op + return op + + if f is not None: + return decorator(f) + + return decorator + + +class IrOp: + registry: ClassVar[dict[str, "IrOp"]] = {} + + name: str + impls: dict[str, "IrOpImpl"] + + def __init__(self, name: str, native_impl: Callable): + self._py_signature = inspect.signature(native_impl) + if any( + p.kind == inspect.Parameter.KEYWORD_ONLY + for p in self._py_signature.parameters.values() + ): + raise ValueError( + f"Op {name} has keyword-only arguments which are not currently " + f"supported. That's because kwargs are not allowed during lowering." 
+ ) + + self.name = name + self.impls: dict[str, IrOpImpl] = {} + self._priority_impls: list[IrOpImpl] = [] + self._schema_str = infer_schema(native_impl, mutates_args=[]) + + # native implementation + self.impls["native"] = IrOpImpl( + self, "native", native_impl, supported=True, supports_args=None + ) + + # By default, fake routes directly to native, + # can be overridden by register_fake + self._fake_fn = native_impl + + # torch registration + vllm_ir_lib.define(self.name + self._schema_str) + # CompositeExplicitAutograd is not decomposed + # by ATen IR normalization in AOTAutograd + vllm_ir_lib.impl( + self.name, self._inner_call, dispatch_key="CompositeExplicitAutograd" + ) + vllm_ir_lib._register_fake(self.name, self._fake_call) + assert hasattr(torch.ops.vllm_ir, name) + self.torch_op: torch._ops.OpOverload = getattr(torch.ops.vllm_ir, name).default + + def register_fake(self, fn: Callable) -> Callable: + """ + Register a fake impl for the torch custom op. If this method is not called, + the native implementation is used directly for the fake implementation. + """ + self._fake_fn = fn + return fn + + def _fake_call(self, *args, **kwargs) -> Any: + """ + Call to the fake implementation of the op. We use indirection because we want + users to be able to register fake later but also want it to fall back to native + directly by default, instead of going through the dispatching mechanism. + """ + return self._fake_fn(*args, **kwargs) + + def register_impl( + self, + provider: str, + *, + supported: bool = True, + supports_args: Callable[..., bool] | None = None, + ): + """ + Register an implementation for this custom op. + :param provider: The name of the provider, must be unique. + :param supported: Static support check, use this to check platform support. + :param supports_args: Dynamic arg support check, used for types and shapes. + :return: A decorator that registers the implementation. 
+ + The decorated function must have the same semantics and signature as + the native implementation. + + The provider name must be unique and not one of the RESERVED_PROVIDERS. + The supported and supports_args parameters should not be used to implement + custom enablement logic based on global state (e.g. environment variables). + Instead, supported param should only be used to check for platform support + (e.g. whether a specific hardware or library is available). + supports_args should be used to check whether the provided arguments are + compatible with the implementation. + For custom enablement logic, set op impl priority. + + Example: + ```python + @my_op.register_impl("my_provider", supported=torch.cuda.is_available()) + def my_provider_impl(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ... + ``` + + """ + assert provider not in RESERVED_PROVIDERS, ( + f"Provider name {provider} is reserved." + ) + + def _register_impl(f: Callable): + impl = IrOpImpl(self, provider, f, supported, supports_args) + self.impls[provider] = impl + + if self.get_priority(): + logger.warning( + "Warning: registering new impl %s for op %s while priority is set.", + provider, + self.name, + ) + + return impl + + return _register_impl + + def _inner_call(self, *args, **kwargs) -> Any: + """ + Eager call to torch op lands here. When torch wrapping is disabled, + __call__ routes straight here instead of going through torch op dispatching. + """ + impl = self.dispatch(*args, **kwargs) + return impl.impl_fn(*args, **kwargs) + + def apply_arg_defaults(self, args) -> tuple: + """ + Return args with default values applied. + Defaults are taken from the native implementation signature. + + SHOULD NOT BE USED IN THE DISPATCH PATH (SLOW). + Only for Inductor lowering. 
+ """ + bound_args = self._py_signature.bind(*args) + bound_args.apply_defaults() + return bound_args.args + + def dispatch(self, *args, **kwargs) -> "IrOpImpl": + """ + Dispatch to the appropriate implementation based on current priority + and argument support checks. Returns the selected IrOpImpl. + + THIS FUNCTION IS ON THE HOT PATH (OP DISPATCH), MUST BE FAST. + """ + if not self._priority_impls: + if not torch.compiler.is_compiling(): + # Logging not compatible with Dynamo tracing + # (this code is exposed when torch wrapping is disabled) + logger.warning_once( + "Priority not set for op %s, using native implementation.", + self.name, + ) + return self.impls["native"] + + for impl in self._priority_impls: + if not impl.supported: + raise ValueError( + f"Implementation {impl.provider} for op {self.name} not supported. " + f"All implementations in priority list must be supported." + ) + if impl.supports_args(*args, **kwargs): + return impl + + if not torch.compiler.is_compiling(): + logger.debug( + "Skipping provider %s because it does not support " + "%s with args=%s kwargs=%s", + impl.provider, + self.name, + lazy(lambda: tensors_str_no_data(args)), + lazy(lambda: tensors_str_no_data(kwargs)), + ) + + raise RuntimeError( + "Priority set incorrectly: the last implementation must " + "support all args (can be native). This is likely an internal bug" + ) + + def __call__(self, *args, **kwargs) -> Any: + if not _ENABLE_TORCH_WRAP: + return self._inner_call(*args, **kwargs) + + return self.torch_op(*args, **kwargs) + + def get_priority(self) -> list[str]: + """Get the current dispatch priority for implementations for this op.""" + return [p.provider for p in self._priority_impls] + + @contextlib.contextmanager + def set_priority(self, priority: list[str]): + """ + Context manager to set the dispatch priority for implementations for this op. + """ + assert all(p in self.impls for p in priority), ( + "All providers in priority must be registered implementations." 
+ ) + + def filter_priority_impls(p_list: list[str]) -> list[IrOpImpl]: + filtered_impls = [] + for p in p_list: + impl = self.impls[p] + if not impl.supported: + # Skip unsupported implementations + continue + + filtered_impls.append(impl) + + # If all args are supported, skip other implementations + if impl.supports_all_args: + return filtered_impls + + logger.warning_once( + "Op %s: No implementation in priority list supports all args, " + "execution fallback to native is possible. To silence this warning, " + "explicitly add 'native' to the end of the priority list", + self.name, + ) + filtered_impls.append(self.impls["native"]) + return filtered_impls + + # Temporarily set priority + old_priority_impls = self._priority_impls + try: + self._priority_impls = filter_priority_impls(priority) + yield + finally: + self._priority_impls = old_priority_impls + + def supported_providers(self) -> list[str]: + return [p.provider for p in self.impls.values() if p.supported] + + +class IrOpImpl: + def __init__( + self, + op: IrOp, + provider: str, + impl_fn: Callable, + supported: bool, + supports_args: Callable[..., bool] | None, + ): + assert provider not in op.impls, ( + f"Implementation for provider {provider} already registered." + ) + # Native also uses this path, so we allow it here. + assert provider == "native" or provider not in RESERVED_PROVIDERS + + # Enforce the exact same schema as the native implementation. + # This takes care of names, types, and defaults. + schema = infer_schema(impl_fn, mutates_args=[]) + if schema != op._schema_str: + raise ValueError( + f"Implementation for provider {provider} has schema '{schema}' which " + f"does not match native schema '{op._schema_str}' for op {op.name}." + ) + + if supports_args is not None: + if not callable(supports_args): + raise ValueError( + f"supports_args for provider {provider} must be a callable" + ) + + # We also manually validate the supports_args signature. 
+ # Matching signatures allow faster dispatch on the hotpath. + + # Check that supports_args does not have keyword-only parameters + supports_args_signature = inspect.signature(supports_args) + params = supports_args_signature.parameters + if any(p.kind == inspect.Parameter.KEYWORD_ONLY for p in params.values()): + raise ValueError( + f"supports_args for provider {provider} " + f"cannot have keyword-only parameters" + ) + + # Check that supports_args has the same total number of parameters + op_params = op._py_signature.parameters + if len(params) != len(op_params): + raise ValueError( + f"supports_args for provider {provider} must have the same number " + f"of parameters ({len(params)}) as the native implementation " + f"({len(op_params)})" + ) + + # Check that names and defaults match for supports_args + for p, op_p in zip(params.values(), op_params.values()): + if p.name != op_p.name: + raise ValueError( + f"supports_args for provider {provider} has parameter " + f"'{p.name}' which does not match native parameter " + f"'{op_p.name}'" + ) + if p.default != op_p.default: + raise ValueError( + f"supports_args for provider {provider} has parameter " + f"'{p.name}' with default {p.default} which does not match " + f"native default {op_p.default}'" + ) + + self.op = op + self.provider = provider + self.impl_fn = impl_fn + self.supported = supported + self._supports_args = supports_args + + @property + def supports_all_args(self) -> bool: + """Check if this implementation supports all args unconditionally.""" + return self._supports_args is None + + def supports_args(self, *args, **kwargs) -> bool: + if self._supports_args is None: + return True + + return self._supports_args(*args, **kwargs) + + @weak_cache + def uuid(self): + """ + Compile-time hash to uniquely determine whether the implementation has changed. + Used by vllm-compile hash mechanism and torch.compile lowering pass uuid to + control the vLLM compile cache and AOTAutograd/Inductor caches respectively. 
+ + Source file contents do not change so we cache uuid. + TODO(luka): Cache the file hash as multiple impls are likely in the same file. + """ + sources = [Path(inspect.getfile(self.impl_fn))] + return hash_source(*sources) diff --git a/vllm/ir/ops/__init__.py b/vllm/ir/ops/__init__.py new file mode 100644 index 000000000..25ad27c8a --- /dev/null +++ b/vllm/ir/ops/__init__.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from .layernorm import rms_norm + +__all__ = ["rms_norm"] diff --git a/vllm/ir/ops/layernorm.py b/vllm/ir/ops/layernorm.py new file mode 100644 index 000000000..8471aa043 --- /dev/null +++ b/vllm/ir/ops/layernorm.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from torch import Tensor + +from ..op import register_op + + +@register_op +def rms_norm( + x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None +) -> Tensor: + """Weighted root-mean-square layer normalization""" + orig_dtype = x.dtype + x = x.to(torch.float32) + x_var = x if variance_size is None else x[..., :variance_size] + variance = x_var.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + epsilon) + x = x.to(orig_dtype) + if weight is not None: + x = x * weight + return x diff --git a/vllm/ir/util.py b/vllm/ir/util.py new file mode 100644 index 000000000..ac8a06155 --- /dev/null +++ b/vllm/ir/util.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +import hashlib +import inspect +import types +import weakref +from pathlib import Path +from typing import Any + + +def hash_source(*srcs: str | Any) -> str: + """ + Utility method to hash the sources of functions or objects. + :param srcs: strings or objects to add to the hash. + Objects and functions have their source inspected. 
+ :return: + """ + hasher = hashlib.sha256() + for src in srcs: + if src is None: + src_str = "None" + elif isinstance(src, str): + src_str = src + elif isinstance(src, Path): + src_str = src.read_text() + elif isinstance(src, (types.FunctionType, type)): + src_str = inspect.getsource(src) + else: + # object instance + src_str = inspect.getsource(src.__class__) + hasher.update(src_str.encode("utf-8")) + return hasher.hexdigest() + + +def weak_lru_cache(maxsize: int | None = 128, typed: bool = False): + """ + LRU Cache decorator that keeps a weak reference to 'self'. + This avoids memory leakage, which happens when functools.lru_cache + stores a reference to self in the global cache. + + Taken from: https://stackoverflow.com/a/68052994/5082708 + """ + + def wrapper(func): + @functools.lru_cache(maxsize, typed) + def _func(_self, *args, **kwargs): + return func(_self(), *args, **kwargs) + + @functools.wraps(func) + def inner(self, *args, **kwargs): + return _func(weakref.ref(self), *args, **kwargs) + + return inner + + return wrapper + + +def weak_cache(user_function, /): + """Simple weak equivalent to functools.cache""" + return weak_lru_cache(maxsize=None)(user_function) diff --git a/vllm/kernels/__init__.py b/vllm/kernels/__init__.py index 3d0c9805e..075bc01f3 100644 --- a/vllm/kernels/__init__.py +++ b/vllm/kernels/__init__.py @@ -1,3 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Kernel implementations for vLLM.""" + +from . 
import aiter_ops, oink_ops, vllm_c, xpu_ops + +__all__ = ["vllm_c", "aiter_ops", "oink_ops", "xpu_ops"] diff --git a/vllm/kernels/aiter_ops.py b/vllm/kernels/aiter_ops.py new file mode 100644 index 000000000..1980051dd --- /dev/null +++ b/vllm/kernels/aiter_ops.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools + +import torch +from torch import Tensor +from torch.library import Library + +from vllm import ir +from vllm.platforms import current_platform +from vllm.utils.torch_utils import direct_register_custom_op + +current_platform.import_kernels() + + +def is_aiter_found() -> bool: + from importlib.util import find_spec + + return find_spec("aiter") is not None + + +aiter_lib = Library("vllm_aiter", "FRAGMENT") +""" +This library holds torch custom ops for wrapped AITER ops. +Many AITER ops want to remain invisible to torch.compile even after lowering. +They are thus wrapped into torch custom ops inside the IR op implementations. 
+""" + +direct_register_aiter_op = functools.partial( + direct_register_custom_op, target_lib=aiter_lib +) +"""Syntactic sugar for registering AITER custom ops.""" + +AITER_SUPPORTED = is_aiter_found() +"""Most kernels in this file are supported if AITER is installed.""" + +rms_no_var_16bit_only = ( + lambda x, weight, epsilon, variance_size=None: variance_size is None + and x.dtype + in ( + torch.float16, + torch.bfloat16, + ) +) +"""AITER rms_norm only supports float16 and bfloat16 acts and no var_size override.""" + + +@ir.ops.rms_norm.register_impl( + "aiter", supports_args=rms_no_var_16bit_only, supported=AITER_SUPPORTED +) +def rms_norm( + x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None +) -> Tensor: + assert variance_size is None + assert x.dtype in (torch.float16, torch.bfloat16) + if weight is None: + weight = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype) + return torch.ops.vllm_aiter.rms_norm(x, weight, epsilon) + + +def _rms_norm_impl(x: Tensor, weight: Tensor, variance_epsilon: float) -> Tensor: + from aiter import rms_norm + + if x.dim() > 2: + x_original_shape = x.shape + x = x.reshape(-1, x_original_shape[-1]) + x = rms_norm(x, weight, variance_epsilon) + return x.reshape(x_original_shape) + + return rms_norm(x, weight, variance_epsilon) + + +def _rms_norm_fake(x: Tensor, weight: Tensor, variance_epsilon: float) -> Tensor: + return torch.empty_like(x) + + +direct_register_aiter_op( + op_name="rms_norm", op_func=_rms_norm_impl, fake_impl=_rms_norm_fake +) diff --git a/vllm/kernels/oink_ops.py b/vllm/kernels/oink_ops.py new file mode 100644 index 000000000..e8e3cb91f --- /dev/null +++ b/vllm/kernels/oink_ops.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm import ir +from vllm.platforms import current_platform + +OINK_AVAILABLE = current_platform.has_device_capability(100) and hasattr( + torch.ops, 
"oink" +) + + +def has_oink_op(name: str) -> bool: + """Check if a specific oink op is registered.""" + return OINK_AVAILABLE and hasattr(torch.ops.oink, name) + + +def _can_view_as_2d(x: torch.Tensor) -> bool: + """Return True if x.view(-1, x.shape[-1]) is viewable (no copy).""" + if x.dim() < 2: + return False + if x.dim() == 2: + return True + # For a view(-1, N) to be valid, all leading dims must be contiguous with + # respect to each other (size-1 dims are ignored). + for dim in range(x.dim() - 1): + # Strides for size-1 dims are irrelevant and can be arbitrary. + if x.size(dim + 1) != 1 and x.stride(dim) != x.stride(dim + 1) * x.size( + dim + 1 + ): + return False + return True + + +def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool: + """Return True if x_2d meets Oink's pointer-path stride constraints.""" + if x_2d.dim() != 2: + return False + if x_2d.stride(1) != 1: + return False + # Match Oink's vectorization constraint: stride(0) divisible by 256b. + if x_2d.dtype in (torch.float16, torch.bfloat16): + divby = 16 + elif x_2d.dtype == torch.float32: + divby = 8 + else: + return False + return (x_2d.stride(0) % divby) == 0 + + +oink_rms_supported = ( + lambda x, weight, epsilon, variance_size=None: variance_size is None + and weight is not None + and x.dim() >= 2 + and x.dtype == weight.dtype + and weight.is_contiguous() + and _can_view_as_2d(x) + and _is_oink_stride_compatible_2d(x.view(-1, x.shape[-1])) +) +""" +Oink rms only supports 2d-like inputs with contiguous weight +and no variance_size override. 
+""" + + +@ir.ops.rms_norm.register_impl( + "oink", supports_args=oink_rms_supported, supported=has_oink_op("rmsnorm") +) +def rms_norm( + x: torch.Tensor, + weight: torch.Tensor | None, + epsilon: float, + variance_size: int | None = None, +) -> torch.Tensor: + assert variance_size is None + x_2d = x.view(-1, x.shape[-1]) + return torch.ops.oink.rmsnorm(x_2d, weight, epsilon).view_as(x) diff --git a/vllm/kernels/vllm_c.py b/vllm/kernels/vllm_c.py new file mode 100644 index 000000000..fabb36d7b --- /dev/null +++ b/vllm/kernels/vllm_c.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from torch import Tensor + +from vllm import ir +from vllm.platforms import current_platform + +current_platform.import_kernels() + +CUDA_ALIKE = current_platform.is_cuda_alike() +"""Most kernels in this file are supported on all CUDA-alike platforms.""" + +rms_no_var_size = lambda x, weight, epsilon, variance_size=None: variance_size is None +"""vLLM kernel does not support variance_size parameter.""" + + +@ir.ops.rms_norm.register_impl( + "vllm_c", supports_args=rms_no_var_size, supported=CUDA_ALIKE +) +def rms_norm( + x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None +) -> Tensor: + if weight is None: + # Kernel requires weight tensor, pass ones + weight = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype) + assert variance_size is None + output = torch.empty(x.shape, device=x.device, dtype=x.dtype) + torch.ops._C.rms_norm(output, x, weight, epsilon) + return output diff --git a/vllm/kernels/xpu_ops.py b/vllm/kernels/xpu_ops.py new file mode 100644 index 000000000..3548fb868 --- /dev/null +++ b/vllm/kernels/xpu_ops.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from torch import Tensor + +from vllm import ir +from vllm.platforms import current_platform + 
+current_platform.import_kernels() + + +def is_xpu_kernels_found() -> bool: + from importlib.util import find_spec + + return find_spec("vllm_xpu_kernels") is not None + + +XPU_KERNELS_SUPPORTED = is_xpu_kernels_found() +"""Kernels in this file are supported if vLLM XPU kernels are installed.""" + +rms_no_var = lambda x, weight, epsilon, variance_size=None: variance_size is None + + +@ir.ops.rms_norm.register_impl( + "xpu_kernels", supports_args=rms_no_var, supported=XPU_KERNELS_SUPPORTED +) +def rms_norm( + x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None +) -> Tensor: + if weight is None: + # Kernel requires weight tensor, pass ones + weight = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype) + assert variance_size is None + output = torch.empty(x.shape, device=x.device, dtype=x.dtype) + torch.ops._C.rms_norm(output, x, weight, epsilon) + return output diff --git a/vllm/logging_utils/__init__.py b/vllm/logging_utils/__init__.py index 94dee07ed..b83d499db 100644 --- a/vllm/logging_utils/__init__.py +++ b/vllm/logging_utils/__init__.py @@ -8,6 +8,7 @@ from vllm.logging_utils.access_log_filter import ( from vllm.logging_utils.formatter import ColoredFormatter, NewLineFormatter from vllm.logging_utils.lazy import lazy from vllm.logging_utils.log_time import logtime +from vllm.logging_utils.torch_tensor import tensors_str_no_data __all__ = [ "NewLineFormatter", @@ -16,4 +17,5 @@ __all__ = [ "create_uvicorn_log_config", "lazy", "logtime", + "tensors_str_no_data", ] diff --git a/vllm/logging_utils/torch_tensor.py b/vllm/logging_utils/torch_tensor.py new file mode 100644 index 000000000..7af4326ba --- /dev/null +++ b/vllm/logging_utils/torch_tensor.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + + +def tensors_str_no_data(arg: Any): + from torch._tensor_str import printoptions + + with printoptions(threshold=1, edgeitems=0): + 
return str(arg) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 7fa804587..500370d9f 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -6,7 +6,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from vllm import _oink_ops, envs +# Import kernels +import vllm.kernels # noqa: F401 +from vllm import _oink_ops, envs, ir from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp @@ -51,23 +53,6 @@ def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool: return (x_2d.stride(0) % divby) == 0 -def rms_norm( - x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float -) -> torch.Tensor: - from vllm import _custom_ops as ops - - if envs.VLLM_BATCH_INVARIANT: - return rms_norm_batch_invariant(x, weight, variance_epsilon) - out = torch.empty_like(x) - ops.rms_norm( - out, - x, - weight, - variance_epsilon, - ) - return out - - def fused_add_rms_norm( x: torch.Tensor, residual: torch.Tensor, @@ -105,23 +90,16 @@ def poly_norm( return out -def dispatch_rocm_rmsnorm_func( - with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False -): +def dispatch_rocm_rmsnorm_func(dtype: torch.dtype, use_aiter: bool = False): use_aiter = use_aiter and dtype in [ torch.float16, torch.bfloat16, ] - if use_aiter and with_fused_add: - return rocm_aiter_ops.rms_norm2d_with_add if use_aiter: - return rocm_aiter_ops.rms_norm - - # fall back to CUDA implementation - if with_fused_add: + return rocm_aiter_ops.rms_norm2d_with_add + else: return fused_add_rms_norm - return rms_norm # --8<-- [start:rms_norm] @@ -158,20 +136,14 @@ class RMSNorm(CustomOp): if current_platform.is_rocm(): aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled() - self.rocm_norm_func = dispatch_rocm_rmsnorm_func( - with_fused_add=False, - dtype=weight_dtype, - use_aiter=aiter_rmsnorm_enabled, - ) 
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func( - with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled + dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled ) # Optional: enable Oink Blackwell RMSNorm custom-op fast path on # compatible CUDA devices (e.g., SM100) when the external Oink # package is available. This is detected once at construction time # to avoid per-call device queries in the hot path. - self._use_oink_rmsnorm = False self._use_oink_fused_add_rmsnorm = False if ( not current_platform.is_rocm() @@ -203,7 +175,6 @@ class RMSNorm(CustomOp): try: device_index = torch.accelerator.current_device_index() if _oink_ops.is_oink_available_for_device(device_index): - self._use_oink_rmsnorm = True self._use_oink_fused_add_rmsnorm = ( _oink_ops.has_fused_add_rms_norm() ) @@ -215,7 +186,6 @@ class RMSNorm(CustomOp): "RMSNorm; falling back to vLLM RMSNorm. Error: %s", e, ) - self._use_oink_rmsnorm = False self._use_oink_fused_add_rmsnorm = False @staticmethod @@ -270,6 +240,10 @@ class RMSNorm(CustomOp): residual: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """PyTorch-native implementation equivalent to forward().""" + if residual is None: + return ir.ops.rms_norm( + x, self.weight.data, self.variance_epsilon, self.variance_size_override + ) return self.forward_static( x, @@ -286,35 +260,14 @@ class RMSNorm(CustomOp): x: torch.Tensor, residual: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if residual is None and not envs.VLLM_BATCH_INVARIANT: + return ir.ops.rms_norm( + x, self.weight.data, self.variance_epsilon, self.variance_size_override + ) + if self.variance_size_override is not None: return self.forward_native(x, residual) - # Optional Oink SM100 fast path (no residual). This path is - # torch.compile-friendly via torch.ops.oink.rmsnorm and preserves - # 2D layouts (including padded rows) when using the Oink - # pointer-based kernel. 
- if ( - residual is None - and getattr(self, "_use_oink_rmsnorm", False) - and x.is_cuda - and x.dim() >= 2 - and self.has_weight - and not envs.VLLM_BATCH_INVARIANT - and self.weight.data.dtype == x.dtype - and self.weight.data.is_contiguous() - ): - orig_shape = x.shape - hidden_size = orig_shape[-1] - if _can_view_as_2d(x): - x_2d = x.view(-1, hidden_size) - if _is_oink_stride_compatible_2d(x_2d): - y_2d = _oink_ops.rmsnorm( - x_2d, - self.weight.data, - self.variance_epsilon, - ) - return y_2d.view(orig_shape) - # Optional Oink SM100 fast path (fused residual-add + RMSNorm, in-place). # This mirrors vLLM's fused_add_rms_norm semantics by mutating both # `x` (normalized output) and `residual` (residual-out buffer). @@ -356,29 +309,34 @@ class RMSNorm(CustomOp): ) return x, residual - add_residual = residual is not None - if add_residual: + if residual is not None: return fused_add_rms_norm( x, residual, self.weight.data, self.variance_epsilon ) else: - return rms_norm(x, self.weight.data, self.variance_epsilon) + assert envs.VLLM_BATCH_INVARIANT + return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon) def forward_hip( self, x: torch.Tensor, residual: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if residual is None and not envs.VLLM_BATCH_INVARIANT: + return ir.ops.rms_norm( + x, self.weight.data, self.variance_epsilon, self.variance_size_override + ) + if self.variance_size_override is not None: return self.forward_native(x, residual) - add_residual = residual is not None - if add_residual: + if residual is not None: return self.rocm_norm_func_with_add( x, residual, self.weight.data, self.variance_epsilon ) else: - return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon) + assert envs.VLLM_BATCH_INVARIANT + return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon) def forward_xpu( self, diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 
71c8e3a35..73bfbeef1 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -30,6 +30,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: from vllm.config import VllmConfig from vllm.config.cache import CacheDType + from vllm.config.kernel import IrOpPriorityConfig from vllm.v1.attention.selector import AttentionSelectorConfig else: VllmConfig = None @@ -550,6 +551,26 @@ class CudaPlatformBase(Platform): def use_custom_op_collectives(cls) -> bool: return True + @classmethod + def get_default_ir_op_priority(cls, vllm_config: VllmConfig) -> IrOpPriorityConfig: + from vllm.config.compilation import CompilationMode + from vllm.config.kernel import IrOpPriorityConfig + + # Native used by default when compiling, + # use vllm_c kernels where available when no codegen + cc = vllm_config.compilation_config + using_inductor = cc.backend == "inductor" and cc.mode != CompilationMode.NONE + default = ["native"] if using_inductor else ["vllm_c", "native"] + + # Use oink if enabled for rms_norm + # TODO(Laurawly/luka): remove this env var, + # users can just use IR op priority directly + rms_norm = default + if envs.VLLM_USE_OINK_OPS: + rms_norm = ["oink"] + default + + return IrOpPriorityConfig.with_default(default, rms_norm=rms_norm) + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index fae37442e..7fba7a65f 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from torch.distributed import PrefixStore, ProcessGroup from vllm.config import VllmConfig + from vllm.config.kernel import IrOpPriorityConfig from vllm.inputs import EngineInput from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -931,6 +932,16 @@ class Platform: "num_compute_units is not implemented for the current platform." 
) + @classmethod + def get_default_ir_op_priority( + cls, vllm_config: "VllmConfig" + ) -> "IrOpPriorityConfig": + """Get the default IR op priority for the current platform.""" + from vllm.config.kernel import IrOpPriorityConfig + + # Native always used by default. Platforms can override this behavior. + return IrOpPriorityConfig.with_default(["native"]) + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 13ff8821f..26b081b47 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -19,6 +19,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: from vllm.config import VllmConfig + from vllm.config.kernel import IrOpPriorityConfig from vllm.v1.attention.selector import AttentionSelectorConfig logger = init_logger(__name__) @@ -903,3 +904,32 @@ class RocmPlatform(Platform): @classmethod def use_custom_op_collectives(cls) -> bool: return True + + @classmethod + def get_default_ir_op_priority( + cls, vllm_config: "VllmConfig" + ) -> "IrOpPriorityConfig": + from vllm.config.compilation import CompilationMode + from vllm.config.kernel import IrOpPriorityConfig + + # Native used by default when compiling, + # use vllm_c kernels where available when no codegen + # TODO(luka/TJ) use aiter, vllm_c, native by default on ROCm + cc = vllm_config.compilation_config + using_inductor = cc.backend == "inductor" and cc.mode != CompilationMode.NONE + default = ["native"] if using_inductor else ["vllm_c", "native"] + + # This (mostly) preserves previous CustomOp behavior + # Necessary on ROCm because it's common that users + # enable rms_norm to use the aiter kernel. 
+ # TODO(luka/TJ) remove env vars completely + if ( + cc.is_custom_op_enabled("rms_norm") + and envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER_RMSNORM + ): + rms_norm = ["aiter"] + default + else: + rms_norm = default + + return IrOpPriorityConfig.with_default(default, rms_norm=rms_norm) diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index bf96b94af..2a56ff5c6 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -21,6 +21,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: from vllm.config import VllmConfig + from vllm.config.kernel import IrOpPriorityConfig from vllm.v1.attention.selector import AttentionSelectorConfig else: VllmConfig = None @@ -257,6 +258,21 @@ class XPUPlatform(Platform): ) return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa + @classmethod + def get_default_ir_op_priority( + cls, vllm_config: "VllmConfig" + ) -> "IrOpPriorityConfig": + from vllm.config.compilation import CompilationMode + from vllm.config.kernel import IrOpPriorityConfig + + # Native used by default when compiling, + # use fused kernels where available when no codegen + cc = vllm_config.compilation_config + using_inductor = cc.backend == "inductor" and cc.mode != CompilationMode.NONE + default = ["native"] if using_inductor else ["xpu_kernels", "native"] + + return IrOpPriorityConfig.with_default(default) + @classmethod def device_count(cls) -> int: return torch.xpu.device_count()