[vLLM IR] 1/N Implement IR skeleton and rms_norm op (#33825)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
Signed-off-by: chzhang <chaojun.zhang@intel.com>
Signed-off-by: Luka Govedic <luka.govedic@gmail.com>
Co-authored-by: Xinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: Chaojun Zhang <chaojun.zhang@intel.com>
Co-authored-by: Luka Govedič <ProExpertProg@h100-01.nemg-001.lab.rdu2.dc.redhat.com>
This commit is contained in:
Luka Govedič
2026-03-31 22:15:05 -04:00
committed by GitHub
parent 0fab52f0aa
commit 40bb175027
49 changed files with 2177 additions and 265 deletions

View File

@@ -2,6 +2,16 @@ group: Kernels
depends_on: depends_on:
- image-build - image-build
steps: steps:
- label: vLLM IR Tests
timeout_in_minutes: 10
working_dir: "/vllm-workspace/"
source_file_dependencies:
- vllm/ir
- vllm/kernels
commands:
- pytest -v -s tests/ir
- pytest -v -s tests/kernels/ir
- label: Kernels Core Operation Test - label: Kernels Core Operation Test
timeout_in_minutes: 75 timeout_in_minutes: 75
source_file_dependencies: source_file_dependencies:

4
.github/CODEOWNERS vendored
View File

@@ -13,6 +13,9 @@
/vllm/model_executor/layers/rotary_embedding.py @vadiklyutiy /vllm/model_executor/layers/rotary_embedding.py @vadiklyutiy
/vllm/model_executor/model_loader @22quinn /vllm/model_executor/model_loader @22quinn
/vllm/model_executor/layers/batch_invariant.py @yewentao256 /vllm/model_executor/layers/batch_invariant.py @yewentao256
/vllm/ir @ProExpertProg
/vllm/kernels/ @ProExpertProg @tjtanaa
/vllm/kernels/helion @ProExpertProg @zou3519
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa /vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa
/vllm/vllm_flash_attn @LucasWilkinson @MatthewBonanni /vllm/vllm_flash_attn @LucasWilkinson @MatthewBonanni
CMakeLists.txt @tlrmchlsmth @LucasWilkinson CMakeLists.txt @tlrmchlsmth @LucasWilkinson
@@ -74,6 +77,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche /tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche
/tests/evals @mgoin @vadiklyutiy /tests/evals @mgoin @vadiklyutiy
/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256 /tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256
/tests/kernels/ir @ProExpertProg @tjtanaa
/tests/models @DarkLight1337 @ywang96 /tests/models @DarkLight1337 @ywang96
/tests/multimodal @DarkLight1337 @ywang96 @NickLucche /tests/multimodal @DarkLight1337 @ywang96 @NickLucche
/tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 @pavanimajety /tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 @pavanimajety

View File

@@ -8,7 +8,7 @@ from copy import deepcopy
import depyf import depyf
from torch import fx from torch import fx
from torch._ops import OpOverload from torch._ops import OpOverload, OpOverloadPacket
from torch.fx._utils import lazy_format_graph_code from torch.fx._utils import lazy_format_graph_code
from vllm.compilation.passes.fx_utils import find_op_nodes from vllm.compilation.passes.fx_utils import find_op_nodes
@@ -90,7 +90,9 @@ class TestBackend:
# assign by reference, will reflect the final state of the graph # assign by reference, will reflect the final state of the graph
self.final_graph = graph self.final_graph = graph
def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True): def check_before_ops(
self, ops: Sequence[OpOverload | OpOverloadPacket], fully_replaced=True
):
for op in ops: for op in ops:
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass))) num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
num_post = len(list(find_op_nodes(op, self.graph_post_pass))) num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
@@ -99,13 +101,19 @@ class TestBackend:
if fully_replaced: if fully_replaced:
assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph" assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"
def check_after_ops(self, ops: Sequence[OpOverload]): def check_after_ops(self, ops: Sequence[OpOverload | OpOverloadPacket]):
for op in ops: for op in ops:
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass))) num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
num_post = len(list(find_op_nodes(op, self.graph_post_pass))) num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph" assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
assert num_post > 0, f"Op {op.name()} not found in post-pass graph" assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
def op_count(self, op: OpOverload, before=False) -> int: def op_count(self, op: OpOverload | OpOverloadPacket, before=False) -> int:
graph = self.graph_pre_pass if before else self.graph_post_pass graph = self.graph_pre_pass if before else self.graph_post_pass
return len(list(find_op_nodes(op, graph))) return len(list(find_op_nodes(op, graph)))
def print_graphs(self):
print("=== Graph before custom passes ===")
print(self.graph_pre_pass.python_code(root_module="self", verbose=True).src)
print("=== Graph after custom passes ===")
print(self.graph_post_pass.python_code(root_module="self", verbose=True).src)

View File

@@ -99,6 +99,8 @@ def test_tp1_fp8_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy" model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024 model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict( compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition, use_inductor_graph_partition=inductor_graph_partition,
custom_ops=custom_ops.split(","), custom_ops=custom_ops.split(","),
@@ -166,6 +168,7 @@ def test_tp1_fp4_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy" model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024 model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict( compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition, use_inductor_graph_partition=inductor_graph_partition,

View File

@@ -68,6 +68,7 @@ def test_tp2_ar_rms_fp8_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy" model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024 model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict( compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition, use_inductor_graph_partition=inductor_graph_partition,
@@ -128,6 +129,7 @@ def test_tp2_ar_rms_fp4_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy" model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024 model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict( compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition, use_inductor_graph_partition=inductor_graph_partition,
@@ -182,6 +184,7 @@ def test_tp2_ar_rms_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy" model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024 model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict( compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition, use_inductor_graph_partition=inductor_graph_partition,

View File

@@ -58,6 +58,7 @@ def test_tp2_async_tp_fp8_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy" model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024 model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict( compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition, use_inductor_graph_partition=inductor_graph_partition,
@@ -121,6 +122,7 @@ def test_tp2_async_tp_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy" model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024 model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict( compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition, use_inductor_graph_partition=inductor_graph_partition,

View File

@@ -9,7 +9,6 @@ from tests.compile.backend import TestBackend
from tests.utils import TestFP8Layer, multi_gpu_test from tests.utils import TestFP8Layer, multi_gpu_test
from vllm.compilation.passes.fusion.rms_quant_fusion import RMSNormQuantFusionPass from vllm.compilation.passes.fusion.rms_quant_fusion import RMSNormQuantFusionPass
from vllm.compilation.passes.fusion.sequence_parallelism import SequenceParallelismPass from vllm.compilation.passes.fusion.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.passes.fx_utils import find_auto_fn
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
@@ -86,13 +85,14 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
] ]
def ops_in_model(self): def ops_in_model(self):
if RMSNorm.enabled(): return (
return [ [torch.ops.vllm_ir.rms_norm]
torch.ops._C.rms_norm.default, + [
torch.ops._C.fused_add_rms_norm.default, torch.ops._C.fused_add_rms_norm.default,
] ]
else: if RMSNorm.enabled()
return [] else []
)
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
@@ -321,4 +321,4 @@ def sequence_parallelism_pass_on_test_model(
assert backend.op_count(op, before=False) == 4 assert backend.op_count(op, before=False) == 4
for op in model.ops_in_model(): for op in model.ops_in_model():
find_auto_fn(backend.graph_post_pass.nodes, op) assert backend.op_count(op, before=False) > 0

View File

View File

@@ -0,0 +1,69 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from torch import nn
import vllm.kernels # noqa: F401 to register kernels
from vllm import ir
from vllm.compilation.passes.ir.lowering_pass import (
VllmIRLoweringPass,
)
from vllm.config import get_current_vllm_config
from vllm.ir import ops
from vllm.platforms import current_platform
from ...backend import TestBackend
class Model(nn.Module):
def __init__(self, hidden_size=16, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hidden_size = hidden_size
self.weight = torch.ones(hidden_size, dtype=torch.bfloat16)
def forward(self, x):
x1 = x + 4.0
x2 = ops.rms_norm(x1, self.weight, 1e-5)
x3 = x2 * 5.0
# no weight
x4 = ops.rms_norm(x3, None, 1e-5)
x5 = x4 / 2.0
# dispatch to native due to variance_size parameter
x6 = ops.rms_norm(x5, self.weight, 1e-5, self.hidden_size // 2)
return x6 + 3.0
@pytest.mark.parametrize("rms_provider", ops.rms_norm.supported_providers())
def test_lowering_rms_norm(rms_provider, default_vllm_config):
torch.set_default_device(current_platform.device_type)
lowering_pass = VllmIRLoweringPass(get_current_vllm_config())
backend = TestBackend(lowering_pass)
backend_unlowered = TestBackend()
model = Model()
x = torch.randn(8, 16, dtype=torch.bfloat16)
with (
ops.rms_norm.set_priority([rms_provider, "native"]),
ir.enable_torch_wrap(True),
):
compiled_model = torch.compile(model, backend=backend, fullgraph=True)
compiled_unlowered_model = torch.compile(
model, backend=backend_unlowered, fullgraph=True
)
output = compiled_model(x)
output_unlowered = compiled_unlowered_model(x)
selected = lowering_pass.selected_impls["rms_norm"]
assert len(selected) == 3
assert selected["rms_norm"] == rms_provider
assert selected["rms_norm_1"] == rms_provider
assert selected["rms_norm_2"] == "native"
# Compiled function guards on global value, avoid recompilation
with ir.enable_torch_wrap(True):
output2 = compiled_model(x)
torch.testing.assert_close(output_unlowered, output)
torch.testing.assert_close(output_unlowered, output2)

View File

@@ -6,6 +6,7 @@ import pytest
import torch import torch
import vllm.config import vllm.config
import vllm.ir.ops
import vllm.plugins import vllm.plugins
from tests.compile.backend import TestBackend from tests.compile.backend import TestBackend
from tests.utils import TestBlockFP8Layer, TestFP8Layer from tests.utils import TestBlockFP8Layer, TestFP8Layer
@@ -51,7 +52,6 @@ from vllm.utils.deep_gemm import (
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
# Kernel and group_shape combinations: (kernel, group_shape) # Kernel and group_shape combinations: (kernel, group_shape)
@@ -246,10 +246,8 @@ class TestModel(torch.nn.Module):
] ]
def ops_in_model_before_partial(self): def ops_in_model_before_partial(self):
return ( return [torch.ops.vllm_ir.rms_norm] + (
[RMS_OP, RMS_ADD_OP] [RMS_ADD_OP] if self.enable_rms_norm_custom_op else [torch.ops.aten.rsqrt]
if self.enable_rms_norm_custom_op
else [torch.ops.aten.rsqrt]
) )
@@ -340,7 +338,10 @@ def test_fusion_rmsnorm_quant(
), ),
) )
with vllm.config.set_current_vllm_config(vllm_config): with (
vllm.config.set_current_vllm_config(vllm_config),
vllm_config.kernel_config.ir_op_priority.set_priority(),
):
# Setup device before model creation # Setup device before model creation
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
@@ -370,8 +371,9 @@ def test_fusion_rmsnorm_quant(
# Hence, we check only 2 add nodes are left (final fused rmsnorm add). # Hence, we check only 2 add nodes are left (final fused rmsnorm add).
if not enable_rms_norm_custom_op: if not enable_rms_norm_custom_op:
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g)) n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each) # rms_norm is IR, not included
assert n_add_nodes(backend.graph_pre_pass) == 7 # 6 = 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 6
assert n_add_nodes(backend.graph_post_pass) == 2 assert n_add_nodes(backend.graph_post_pass) == 2

View File

@@ -3,11 +3,11 @@
import pytest import pytest
import torch import torch
from torch._ops import OpOverload, OpOverloadPacket
from tests.compile.backend import TestBackend from tests.compile.backend import TestBackend
from vllm.compilation.passes.fusion.matcher_utils import ( from vllm.compilation.passes.fusion.matcher_utils import (
FLASHINFER_ROTARY_OP, FLASHINFER_ROTARY_OP,
RMS_OP,
ROTARY_OP, ROTARY_OP,
) )
from vllm.compilation.passes.fusion.qk_norm_rope_fusion import ( from vllm.compilation.passes.fusion.qk_norm_rope_fusion import (
@@ -100,13 +100,8 @@ class QKNormRoPETestModel(torch.nn.Module):
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
return q, k, v return q, k, v
def ops_in_model_before(self) -> list[torch._ops.OpOverload]: def ops_in_model_before(self) -> list[OpOverload | OpOverloadPacket]:
ops = [] ops: list[OpOverload | OpOverloadPacket] = [torch.ops.vllm_ir.rms_norm]
if self.enable_rms_norm_custom_op:
ops.append(RMS_OP)
else:
ops.append(RSQRT_OP)
if self.enable_rope_custom_op: if self.enable_rope_custom_op:
if self.rotary_emb.use_flashinfer: if self.rotary_emb.use_flashinfer:
ops.append(FLASHINFER_ROTARY_OP) ops.append(FLASHINFER_ROTARY_OP)
@@ -116,7 +111,7 @@ class QKNormRoPETestModel(torch.nn.Module):
ops.append(INDEX_SELECT_OP) ops.append(INDEX_SELECT_OP)
return ops return ops
def ops_in_model_after(self) -> list[torch._ops.OpOverload]: def ops_in_model_after(self) -> list[OpOverload | OpOverloadPacket]:
return [FUSED_QK_ROPE_OP] return [FUSED_QK_ROPE_OP]
@@ -166,7 +161,10 @@ def test_qk_norm_rope_fusion(
num_heads, num_kv_heads, head_dim = 16, 4, 128 num_heads, num_kv_heads, head_dim = 16, 4, 128
T = 5 T = 5
with set_current_vllm_config(vllm_config): with (
set_current_vllm_config(vllm_config),
vllm_config.kernel_config.ir_op_priority.set_priority(),
):
model = QKNormRoPETestModel( model = QKNormRoPETestModel(
num_heads=num_heads, num_heads=num_heads,
num_kv_heads=num_kv_heads, num_kv_heads=num_kv_heads,

View File

@@ -1622,3 +1622,26 @@ def fresh_vllm_cache(monkeypatch, use_fresh_inductor_cache):
def enable_pickle(monkeypatch): def enable_pickle(monkeypatch):
"""`LLM.apply_model` requires pickling a function.""" """`LLM.apply_model` requires pickling a function."""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@pytest.fixture(scope="function")
def disable_log_dedup(monkeypatch):
"""
Disable log deduplication such that warning_once and info_once always print.
"""
# Patch logger._print_warning_once to remove the lru_cache decorator
from vllm import logger
original_print_warning_once = logger._print_warning_once
original_print_info_once = logger._print_info_once
original_print_debug_once = logger._print_debug_once
logger._print_warning_once = original_print_warning_once.__wrapped__
logger._print_info_once = original_print_info_once.__wrapped__
logger._print_debug_once = original_print_debug_once.__wrapped__
yield
logger._print_warning_once = original_print_warning_once
logger._print_info_once = original_print_info_once
logger._print_debug_once = original_print_debug_once

View File

@@ -523,3 +523,20 @@ def test_human_readable_model_len():
for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]: for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]:
with pytest.raises(ArgumentError): with pytest.raises(ArgumentError):
parser.parse_args(["--max-model-len", invalid]) parser.parse_args(["--max-model-len", invalid])
def test_ir_op_priority():
from vllm.config.kernel import IrOpPriorityConfig, KernelConfig
ir_op_priority = IrOpPriorityConfig(rms_norm=["vllm_c"])
cfg1 = EngineArgs(ir_op_priority=ir_op_priority).create_engine_config()
cfg2 = EngineArgs(
kernel_config=KernelConfig(ir_op_priority=ir_op_priority)
).create_engine_config()
assert cfg1.kernel_config.ir_op_priority == cfg2.kernel_config.ir_op_priority
with pytest.raises(ValueError, match="rms_norm"):
_ = EngineArgs(
ir_op_priority=ir_op_priority,
kernel_config=KernelConfig(ir_op_priority=ir_op_priority),
).create_engine_config()

497
tests/ir/test_op.py Normal file
View File

@@ -0,0 +1,497 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.util
import logging
from pathlib import Path
from typing import Any
import pytest
import torch
from torch import fx
from torch.fx.experimental.proxy_tensor import make_fx
import vllm.ir.op
from vllm.ir.op import RESERVED_PROVIDERS, IrOp, IrOpImpl
# This should not exist
assert "_custom_add" not in IrOp.registry
class CustomError(Exception):
pass
@vllm.ir.register_op
def _custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
def test_registration_overloads():
assert all(
n not in IrOp.registry for n in ["_custom_sub", "_custom_mul", "_custom_div"]
)
# Calling with decorator
@vllm.ir.register_op()
def _custom_sub(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x - y
assert _custom_sub.name == "_custom_sub"
assert _custom_sub is IrOp.registry["_custom_sub"]
# Custom name
@vllm.ir.register_op(name="_custom_mul")
def custom_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x * y
assert custom_mul.name == "_custom_mul"
assert custom_mul is IrOp.registry["_custom_mul"]
# Direct construction does not register directly
def _custom_div(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x / y
custom_div = IrOp("_custom_div", _custom_div)
assert custom_div.name == "_custom_div"
assert "_custom_div" not in IrOp.registry
# Duplicate op registration not allowed
with pytest.raises(AssertionError):
@vllm.ir.register_op
def _custom_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x * y - 100
def test_no_kw_only_args():
# kw-only args not supported
with pytest.raises(ValueError, match="keyword-only arguments"):
@vllm.ir.register_op
def _custom_kwarg_op(
x: torch.Tensor, y: torch.Tensor, *, kwarg: int = 0
) -> torch.Tensor:
return x + y + kwarg
assert "_custom_kwarg_op" not in IrOp.registry
class TestIrOpCustomAdd:
# Registration invariants
def test_decorated_object(self):
"""Make sure that referring directly to an op is correct"""
assert isinstance(_custom_add, IrOp)
assert "_custom_add" in IrOp.registry
assert _custom_add is IrOp.registry["_custom_add"]
def test_torch_op_is_registered(self):
assert hasattr(torch.ops.vllm_ir, "_custom_add")
assert callable(torch.ops.vllm_ir._custom_add.default)
# Semantic correctness
def test_semantics_match_native(self):
x = torch.randn(4, 5)
y = torch.randn(4, 5)
# Calls native by default
out = _custom_add(x, y)
ref = x + y
torch.testing.assert_close(out, ref)
# -------------------------
# Implementation registration
# -------------------------
def test_register_impl_is_non_intrusive(self):
@_custom_add.register_impl("dummy_provider")
def dummy_impl(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y + 123
assert "dummy_provider" in _custom_add.impls
assert isinstance(_custom_add.impls["dummy_provider"], IrOpImpl)
x = torch.ones(2, 2)
y = torch.ones(2, 2)
# Native semantics must still hold
torch.testing.assert_close(_custom_add(x, y), x + y)
def test_schema_contains_tensor_signature(self):
schema = _custom_add._schema_str
assert "Tensor" in schema
assert "-> Tensor" in schema
# -------------------------
# FX visibility
# -------------------------
@pytest.mark.parametrize("enable_torch_wrap", [True, False])
@pytest.mark.parametrize("symbolic_trace", [True, False])
def test_trace_sees_single_custom_op(
self, symbolic_trace: bool, enable_torch_wrap: bool
):
def fn(x, y):
return _custom_add(x, y)
def find_fn(target: Any, gm: fx.GraphModule):
return gm.graph.find_nodes(op="call_function", target=target)
with pytest.raises(CustomError), vllm.ir.enable_torch_wrap(enable_torch_wrap):
if symbolic_trace:
gm = torch.fx.symbolic_trace(fn)
else:
gm = make_fx(fn)(torch.randn(2, 2), torch.randn(2, 2))
x1, y1 = torch.rand(5, 4), torch.rand(5, 4)
out_fx = gm(x1, y1)
out_eager = fn(x1, y1)
# raise error to check enable_torch_wrap context restored correctly
raise CustomError
# check behavior matches eager in all cases
torch.testing.assert_close(out_fx, out_eager)
# check that IR nodes only appear if enable_torch_wrap=True
ir_nodes = find_fn(torch.ops.vllm_ir._custom_add.default, gm)
if enable_torch_wrap:
assert len(ir_nodes) == 1, gm.code
else:
assert len(ir_nodes) == 0, gm.code
# with torch wrapping enabled (default), IR nodes appear
if symbolic_trace:
gm = torch.fx.symbolic_trace(fn)
else:
gm = make_fx(fn)(torch.randn(2, 2), torch.randn(2, 2))
ir_nodes = find_fn(torch.ops.vllm_ir._custom_add.default, gm)
assert len(ir_nodes) == 1, gm.code
@_custom_add.register_impl("impl_a")
def impl_a(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y + 10
@_custom_add.register_impl("impl_b")
def impl_b(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y + 20
@_custom_add.register_impl("impl_even", supports_args=lambda x, y: x.size(1) % 2 == 0)
def impl_even(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y + 50
class TestIrOpImplDispatch:
def test_register_impl(self):
assert "impl_a" in _custom_add.impls
impl = _custom_add.impls["impl_a"]
assert impl is impl_a
assert impl.op is _custom_add
assert impl.provider == "impl_a"
assert callable(impl.impl_fn)
# Test duplicate registration rejected
with pytest.raises(AssertionError):
@_custom_add.register_impl("impl_a")
def impl_a_dup(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y + 30
# Check the original impl is still intact
assert _custom_add.impls["impl_a"] is impl_a
# Check support all args
assert impl_a.supports_all_args
assert impl_b.supports_all_args
assert not impl_even.supports_all_args
def test_reserved_provider_rejected(self):
for provider in RESERVED_PROVIDERS:
with pytest.raises(AssertionError):
@_custom_add.register_impl(provider)
def bad_impl(x, y):
return x + y
def test_set_priority_scoped(self):
assert _custom_add.get_priority() == []
with _custom_add.set_priority(["impl_even", "impl_b"]):
assert _custom_add.get_priority() == ["impl_even", "impl_b"]
# Check nesting
with _custom_add.set_priority(["impl_b"]):
assert _custom_add.get_priority() == ["impl_b"]
# Restored
assert _custom_add.get_priority() == ["impl_even", "impl_b"]
# Check that exception restores priority
with pytest.raises(CustomError), _custom_add.set_priority(["impl_a"]):
assert _custom_add.get_priority() == ["impl_a"]
raise CustomError
# Restored again
assert _custom_add.get_priority() == ["impl_even", "impl_b"]
# Restored to empty
assert _custom_add.get_priority() == []
def test_dispatch_priority_order(self):
x = torch.tensor(1, dtype=torch.int32)
y = torch.tensor(2, dtype=torch.int32)
with _custom_add.set_priority(["impl_b", "impl_a"]):
assert _custom_add.dispatch(x, y) is impl_b
out1 = _custom_add(x, y)
out2 = torch.ops.vllm_ir._custom_add(x, y)
with _custom_add.set_priority(["impl_a"]):
assert _custom_add.dispatch(x, y) is impl_a
out3 = _custom_add(x, y)
out4 = torch.ops.vllm_ir._custom_add(x, y)
# impl_b
assert out1.item() == 1 + 2 + 20
assert out2.item() == 1 + 2 + 20
# impl_a
assert out3.item() == 1 + 2 + 10
assert out4.item() == 1 + 2 + 10
def test_unsupported_impl_filtered(self):
@_custom_add.register_impl("unsupported", supported=False)
def impl_bad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y + 999
x = torch.tensor(1, dtype=torch.int32)
y = torch.tensor(2, dtype=torch.int32)
with _custom_add.set_priority(["unsupported", "impl_a"]):
assert _custom_add.get_priority() == ["impl_a"]
out = _custom_add(x, y)
# impl_bad skipped → impl_a
assert out.item() == 1 + 2 + 10
def test_supports_args_runtime_dispatch_and_warning(
self, caplog_vllm: pytest.LogCaptureFixture
):
x1 = torch.ones((2, 2), dtype=torch.int32)
y1 = torch.full((2, 2), 2, dtype=torch.int32)
x2 = torch.ones((2, 3), dtype=torch.int32)
y2 = torch.full((2, 3), 2, dtype=torch.int32)
with (
caplog_vllm.at_level(logging.WARNING),
_custom_add.set_priority(["impl_even"]),
):
# Test the warning about native fallback is logged (before even dispatching)
assert len(caplog_vllm.records) == 1
message = caplog_vllm.records[0].message
assert "_custom_add" in message
assert "fallback to native" in message
assert "priority" in message
# Check dispatching
assert _custom_add.get_priority() == ["impl_even", "native"]
assert _custom_add.dispatch(x1, y1) is impl_even
assert _custom_add.dispatch(x2, y2) is _custom_add.impls["native"]
out1 = _custom_add(x1, y1) # size(1) == 2 → impl_even
out2 = _custom_add(x2, y2) # size(1) == 3 → native fallback
# no other warnings
assert len(caplog_vllm.records) == 1
assert torch.all(out1 == 1 + 2 + 50)
assert torch.all(out2 == 1 + 2)
def test_default_priority(
self, caplog_vllm: pytest.LogCaptureFixture, disable_log_dedup
):
# Make sure logs are not deduplicated to properly test the warning
x = torch.tensor([3], dtype=torch.int32)
y = torch.tensor([4], dtype=torch.int32)
# No priority set → falls back to native
assert _custom_add.get_priority() == []
with caplog_vllm.at_level(logging.WARNING):
# Native by default
assert _custom_add.dispatch(x, y) is _custom_add.impls["native"]
out = _custom_add(x, y)
# Check dispatching to native by default
assert out.item() == 3 + 4
# Check warning
assert len(caplog_vllm.records) == 2
message = caplog_vllm.records[0].message.lower()
assert "_custom_add" in message
assert "priority not set" in message
@vllm.ir.register_op
def _custom_mm(
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
tmp = x @ y
return tmp if bias is None else tmp + bias
def test_default_args():
# Test that default args are properly applied when dispatching and calling
@_custom_mm.register_impl("impl_mm", supports_args=lambda x, y, bias=None: True)
def impl_mm(
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
tmp = x @ y
return tmp + 50 if bias is None else tmp + bias + 100
x1 = torch.tensor([1, 2], dtype=torch.int32)
x2 = torch.tensor([3, 4], dtype=torch.int32)
# Test that supports_args receives the defaulted args
assert impl_mm.supports_args(x1, x2)
with _custom_mm.set_priority(["impl_mm", "native"]):
assert _custom_mm.dispatch(x1, x2) is impl_mm
def test_bad_impl_registrations():
# Check bad schema
with pytest.raises(ValueError, match="does not match native schema"):
@_custom_mm.register_impl("impl_mm_bad_schema")
def impl_mm_bad_schema(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x @ y - 1
with pytest.raises(ValueError, match="does not match native schema"):
@_custom_mm.register_impl("impl_mm_bad_schema_2")
def impl_mm_bad_schema_2(
x: torch.Tensor, y: torch.Tensor, b: torch.Tensor | None = None
) -> torch.Tensor:
return x @ y + b - 2
with pytest.raises(ValueError, match="does not match native schema"):
@_custom_mm.register_impl("impl_mm_bad_schema_3")
def impl_mm_bad_schema_3(
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
return x @ y + bias - 5
# check supports_args with incorrect params
with pytest.raises(ValueError, match="supports_args must be a callable"):
@_custom_mm.register_impl("impl_mm_bad_supports_args", supports_args=True)
def impl_mm_bad_supports_args(
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
return x @ y + 10
with pytest.raises(ValueError, match="number of parameters"):
@_custom_mm.register_impl(
"impl_mm_bad_supports_args_2", supports_args=lambda x, y: True
)
def impl_mm_bad_supports_args(
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
return x @ y + 10
with pytest.raises(ValueError, match="keyword-only parameters"):
@_custom_mm.register_impl(
"impl_mm_bad_supports_args_3", supports_args=lambda x, y, *, b: True
)
def impl_mm_bad_supports_args_2(
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
return x @ y + 20
with pytest.raises(ValueError, match="does not match native parameter"):
@_custom_mm.register_impl(
"impl_mm_bad_supports_args_4", supports_args=lambda x, y, b: True
)
def impl_mm_bad_supports_args_4(
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
return x @ y + 30
with pytest.raises(ValueError, match="does not match native default"):
@_custom_mm.register_impl(
"impl_mm_bad_supports_args_5", supports_args=lambda x, y, bias=1: True
)
def impl_mm_bad_supports_args_5(
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
return x @ y + 40
assert set(_custom_mm.impls.keys()) == {"impl_mm", "native"}
IMPL_OOT_SRC = """
import torch
@_custom_mm.register_impl("impl_mm_oot")
def impl_mm_oot(
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
return x @ y - 99
"""
def load_custom_mm_module(file_path: Path):
spec = importlib.util.spec_from_file_location("_custom_mm_oot", file_path)
assert spec is not None
module = importlib.util.module_from_spec(spec)
# Inject the variable into the module's global namespace
# This allows the @_custom_mm.register_impl decorator to work
module._custom_mm = _custom_mm # type: ignore[attr-defined]
# Execute the file; this triggers the decorator
assert spec.loader is not None
spec.loader.exec_module(module)
return module
def test_uuid_and_oot(tmp_path: Path):
file_path = tmp_path / "_custom_mm_oot.py"
file_path.write_text(IMPL_OOT_SRC)
assert "impl_mm_oot" not in _custom_mm.impls
_ = load_custom_mm_module(file_path)
assert "impl_mm_oot" in _custom_mm.impls
uuid = _custom_mm.impls["impl_mm_oot"].uuid()
del _custom_mm.impls["impl_mm_oot"]
# Replace file source
file_path.write_text(IMPL_OOT_SRC + " # added file source")
assert "impl_mm_oot" not in _custom_mm.impls
_ = load_custom_mm_module(file_path)
assert "impl_mm_oot" in _custom_mm.impls
uuid1 = _custom_mm.impls["impl_mm_oot"].uuid()
assert uuid1 != uuid
del _custom_mm.impls["impl_mm_oot"]
# Back to original
file_path.write_text(IMPL_OOT_SRC)
assert "impl_mm_oot" not in _custom_mm.impls
_ = load_custom_mm_module(file_path)
assert "impl_mm_oot" in _custom_mm.impls
uuid2 = _custom_mm.impls["impl_mm_oot"].uuid()
assert uuid2 == uuid
assert uuid2 != uuid1
del _custom_mm.impls["impl_mm_oot"]

View File

@@ -0,0 +1,129 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
# This registers op implementations
import vllm.kernels # noqa: F401
from tests.kernels.allclose_default import get_default_rtol
from vllm import ir
from vllm.platforms import current_platform
def rms_norm_inputs(n_tokens: int, hidden_size: int, dtype: torch.dtype):
x = torch.randn(n_tokens, hidden_size, dtype=dtype)
weight = torch.rand(hidden_size, dtype=dtype)
return x, weight
rms_norm_native = ir.ops.rms_norm.impls["native"].impl_fn
@pytest.mark.skipif(
not current_platform.is_cuda_alike() and not current_platform.is_xpu(),
reason="Currently only kernels on CUDA, ROCm and XPU",
)
def test_rms_norm_registration():
expected = {
"native": True,
"vllm_c": current_platform.is_cuda_alike(),
"aiter": current_platform.is_rocm(),
"oink": False,
"xpu_kernels": current_platform.is_xpu(),
}
actual = {
provider: impl.supported for provider, impl in ir.ops.rms_norm.impls.items()
}
assert actual == expected
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("n_tokens", [1, 8, 17])
@pytest.mark.parametrize("hidden_size", [16, 4096, 8192])
@pytest.mark.parametrize("epsilon", [1e-6, 1e-5])
@pytest.mark.skipif(
not current_platform.is_cuda_alike() and not current_platform.is_xpu(),
reason="Currently only kernels on CUDA, ROCm and XPU",
)
class TestRMSNorm:
@classmethod
def setup_class(cls, **kwargs):
torch.set_default_device(current_platform.device_type)
def test_native_semantics(self, dtype, n_tokens, hidden_size, epsilon):
x, weight = rms_norm_inputs(4, 8, dtype)
out = rms_norm_native(x, weight, epsilon=epsilon)
# Check shape, dtype, device
assert out.shape == x.shape
assert out.dtype == x.dtype
assert out.device == x.device
# Check the scaling property of rms norm
out2 = rms_norm_native(x * 2.0, weight, epsilon=epsilon)
torch.testing.assert_close(out2, out, rtol=get_default_rtol(out), atol=1e-3)
# Check behavior with and without weight
weight1 = torch.ones_like(weight)
out3 = rms_norm_native(x, weight1, epsilon=epsilon)
out4 = rms_norm_native(x, None, epsilon=epsilon)
torch.testing.assert_close(out3, out4)
@pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels"])
def test_impls(self, dtype, n_tokens, hidden_size, epsilon, provider):
impl = ir.ops.rms_norm.impls[provider]
if not impl.supported:
pytest.skip(f"{provider} impl not supported on this platform")
x, weight = rms_norm_inputs(n_tokens, hidden_size, dtype)
args = (x, weight, epsilon, None)
assert impl.supported
if provider == "aiter" and dtype not in [torch.float16, torch.bfloat16]:
assert not impl.supports_args(*args)
return
assert impl.supports_args(*args)
out_impl = impl.impl_fn(*args)
out_native = rms_norm_native(*args)
torch.testing.assert_close(
out_impl, out_native, rtol=get_default_rtol(out_impl), atol=1e-3
)
# check that dispatched call matches direct call
with ir.ops.rms_norm.set_priority([provider, "native"]):
out_impl2 = ir.ops.rms_norm(*args)
# exact match
torch.testing.assert_close(out_impl2, out_impl, rtol=0.0, atol=0.0)
# none of these support variance_size override
assert not impl.supports_args(x, weight, epsilon, 4)
assert not impl.supports_args(x, weight, epsilon, variance_size=4)
# test weight=None behavior
out_impl_no_weight = impl.impl_fn(x, None, epsilon)
out_impl_unit_weight = impl.impl_fn(x, torch.ones_like(weight), epsilon)
torch.testing.assert_close(
out_impl_no_weight,
out_impl_unit_weight,
rtol=get_default_rtol(out_impl_no_weight),
atol=2e-4,
)
@pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels", "native"])
def test_torch_opcheck(self, dtype, n_tokens, hidden_size, epsilon, provider):
if not ir.ops.rms_norm.impls[provider].supported:
pytest.skip(f"{provider} impl not supported on this platform")
x, weight = rms_norm_inputs(n_tokens, hidden_size, dtype)
args = (x, weight, epsilon, None)
# When checking the torch op, we have to set priority and use dispatch
with ir.ops.rms_norm.set_priority([provider, "native"]):
torch.library.opcheck(torch.ops.vllm_ir.rms_norm, args)

View File

@@ -27,7 +27,6 @@ from vllm.model_executor.layers.layernorm import (
RMSNorm, RMSNorm,
dispatch_rocm_rmsnorm_func, dispatch_rocm_rmsnorm_func,
fused_add_rms_norm, fused_add_rms_norm,
rms_norm,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
@@ -156,7 +155,7 @@ def test_topk_sigmoid_dispatch(use_rocm_aiter: bool):
assert topk_func == vllm_topk_sigmoid assert topk_func == vllm_topk_sigmoid
@pytest.mark.parametrize("add_residual", [True, False]) @pytest.mark.parametrize("add_residual", [False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_rocm_aiter", [True, False]) @pytest.mark.parametrize("use_rocm_aiter", [True, False])
@pytest.mark.skipif( @pytest.mark.skipif(
@@ -165,7 +164,7 @@ def test_topk_sigmoid_dispatch(use_rocm_aiter: bool):
def test_rms_norm_dispatch( def test_rms_norm_dispatch(
add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool
): ):
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter) rms_norm_func = dispatch_rocm_rmsnorm_func(dtype, use_rocm_aiter)
should_use_rocm_aiter = ( should_use_rocm_aiter = (
current_platform.is_rocm() current_platform.is_rocm()
@@ -173,11 +172,7 @@ def test_rms_norm_dispatch(
and dtype in RMS_NORM_SUPPORTED_DTYPES and dtype in RMS_NORM_SUPPORTED_DTYPES
) )
if add_residual and should_use_rocm_aiter: if should_use_rocm_aiter:
assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add
elif should_use_rocm_aiter:
assert rms_norm_func == rocm_aiter_ops.rms_norm
elif add_residual:
assert rms_norm_func == fused_add_rms_norm
else: else:
assert rms_norm_func == rms_norm assert rms_norm_func == fused_add_rms_norm

View File

@@ -6,12 +6,14 @@ import os
from dataclasses import MISSING, Field, asdict, dataclass, field from dataclasses import MISSING, Field, asdict, dataclass, field
from unittest.mock import patch from unittest.mock import patch
import pydantic
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
from vllm.compilation.backends import VllmBackend from vllm.compilation.backends import VllmBackend
from vllm.config import ( from vllm.config import (
CompilationConfig, CompilationConfig,
KernelConfig,
ModelConfig, ModelConfig,
ParallelConfig, ParallelConfig,
PoolerConfig, PoolerConfig,
@@ -21,6 +23,7 @@ from vllm.config import (
update_config, update_config,
) )
from vllm.config.compilation import CompilationMode, CUDAGraphMode from vllm.config.compilation import CompilationMode, CUDAGraphMode
from vllm.config.kernel import IrOpPriorityConfig
from vllm.config.load import LoadConfig from vllm.config.load import LoadConfig
from vllm.config.utils import get_field from vllm.config.utils import get_field
from vllm.config.vllm import ( from vllm.config.vllm import (
@@ -1077,6 +1080,39 @@ def test_vllm_config_explicit_overrides():
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
def test_fusion_pass_op_priority():
"""This test checks that custom op enablement & IR op priority
correctly control default fusions"""
# Default config, O2, rms_norm+quant fusion disabled
cfg1 = VllmConfig()
assert not cfg1.compilation_config.pass_config.fuse_norm_quant
# rms_norm manually enabled, O1, rms_norm+quant fusion enabled
cfg2 = VllmConfig(
optimization_level=OptimizationLevel.O1,
compilation_config=CompilationConfig(
custom_ops=["+rms_norm"],
),
)
assert cfg2.compilation_config.pass_config.fuse_norm_quant
# using custom kernel for RMSNorm via IR:
# Note that vLLM IR only supports the non-residual rms_norm for now;
# soon this will be resolved.
cfg3 = VllmConfig(
kernel_config=KernelConfig(
ir_op_priority=IrOpPriorityConfig(rms_norm=["vllm_c"])
)
)
assert cfg3.compilation_config.pass_config.fuse_norm_quant
# block-fp8 model should enable quant_fp8 automatically
cfg4 = VllmConfig(model_config=ModelConfig("Qwen/Qwen3-4B-FP8"))
assert "+quant_fp8" in cfg4.compilation_config.custom_ops
assert cfg4.compilation_config.pass_config.fuse_norm_quant
def test_scheduler_config_init(): def test_scheduler_config_init():
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
# Positional InitVars missing # Positional InitVars missing
@@ -1171,3 +1207,35 @@ def test_eagle_draft_model_config():
assert draft_model_config.hf_text_config.model_type == "eagle" assert draft_model_config.hf_text_config.model_type == "eagle"
assert draft_model_config.architectures == ["EagleLlamaForCausalLM"] assert draft_model_config.architectures == ["EagleLlamaForCausalLM"]
assert draft_model_config.architecture == "EagleLlamaForCausalLM" assert draft_model_config.architecture == "EagleLlamaForCausalLM"
def test_ir_op_priority_default():
"""Test that IR op priority defaults are set correctly."""
from vllm.config.kernel import IrOpPriorityConfig
# Assert default is applied to ops
priority_config = IrOpPriorityConfig.with_default(["vllm_c", "native"])
assert priority_config.rms_norm == ["vllm_c", "native"]
# Assert single ops override the default
assert IrOpPriorityConfig.with_default(
["vllm_c", "native"], rms_norm=["oink", "native"]
) == IrOpPriorityConfig(rms_norm=["oink", "native"])
def test_ir_op_priority_str():
"""Test that passing a comma-delimited string works"""
from vllm.config.kernel import IrOpPriorityConfig
priority_config = IrOpPriorityConfig(rms_norm="vllm_c")
assert priority_config.rms_norm == ["vllm_c"]
priority_config = IrOpPriorityConfig(rms_norm="vllm_c,native")
assert priority_config.rms_norm == ["vllm_c", "native"]
priority_config = IrOpPriorityConfig(rms_norm=" native, vllm_c ")
assert priority_config.rms_norm == ["native", "vllm_c"]
with pytest.raises(pydantic.ValidationError):
# must be list of only strings
priority_config = IrOpPriorityConfig(rms_norm=["vllm_c", 4, "native"])

View File

@@ -3,6 +3,7 @@
import contextlib import contextlib
from importlib.util import find_spec from importlib.util import find_spec
from types import ModuleType from types import ModuleType
from typing import Any
import torch import torch
import torch._inductor.pattern_matcher as pm import torch._inductor.pattern_matcher as pm
@@ -10,6 +11,7 @@ import torch.fx as fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.pattern_matcher import PatternMatcherPass
import vllm.ir.ops
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.utils import Range from vllm.config.utils import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
@@ -28,7 +30,7 @@ from vllm.utils.torch_utils import (
from ..inductor_pass import enable_fake_mode from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
@@ -258,6 +260,12 @@ class BasePattern:
self.tp = get_tp_group() self.tp = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=self.dtype, device=self.device, **kwargs)
def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)
class AllReduceRMSNormPattern(BasePattern): class AllReduceRMSNormPattern(BasePattern):
""" """
@@ -276,20 +284,17 @@ class AllReduceRMSNormPattern(BasePattern):
super().__init__(dtype, device) super().__init__(dtype, device)
self.epsilon = epsilon self.epsilon = epsilon
self.allreduce_params = allreduce_params self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]: def get_inputs(self) -> list[torch.Tensor]:
input, weight = self.rmsnorm_matcher.inputs() # input, weight
return [self.empty(5, 16), self.empty(16)]
# input goes through allreduce first, always 16-bit
return [input.to(self.dtype), weight]
def register(self, pm_pass: PatternMatcherPass) -> None: def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
input: torch.Tensor, weight: torch.Tensor input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
allreduce_output = tensor_model_parallel_all_reduce(input) allreduce_output = tensor_model_parallel_all_reduce(input)
rms = self.rmsnorm_matcher(allreduce_output, weight) rms = vllm.ir.ops.rms_norm(allreduce_output, weight, self.epsilon)
return rms, allreduce_output return rms, allreduce_output
@@ -407,15 +412,13 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
self.epsilon = epsilon self.epsilon = epsilon
self.allreduce_params = allreduce_params self.allreduce_params = allreduce_params
self.quant_dtype = torch.float8_e4m3fn self.quant_dtype = torch.float8_e4m3fn
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self) -> list[torch.Tensor]: def get_inputs(self) -> list[torch.Tensor]:
input, weight = self.rmsnorm_matcher.inputs()
_, scale = self.quant_matcher.inputs() _, scale = self.quant_matcher.inputs()
# input goes through allreduce first, always 16-bit # input, weight
return [input.to(self.dtype), weight, scale] return [self.empty(5, 16), self.empty(16), scale]
def register(self, pm_pass: PatternMatcherPass) -> None: def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
@@ -424,7 +427,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
scale: torch.Tensor, scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(input) all_reduce = tensor_model_parallel_all_reduce(input)
rms = self.rmsnorm_matcher(all_reduce, weight) rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
quant, _ = self.quant_matcher(rms, scale) quant, _ = self.quant_matcher(rms, scale)
return quant, all_reduce return quant, all_reduce
@@ -553,7 +556,6 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
super().__init__(dtype, device) super().__init__(dtype, device)
self.epsilon = epsilon self.epsilon = epsilon
self.allreduce_params = allreduce_params self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]: def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype) input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
@@ -575,7 +577,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
output_scale: torch.Tensor, output_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(input) all_reduce = tensor_model_parallel_all_reduce(input)
rms = self.rmsnorm_matcher(all_reduce, weight) rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
quant_out_tuple = auto_functionalized( quant_out_tuple = auto_functionalized(
STATIC_FP4_QUANT_OP, STATIC_FP4_QUANT_OP,
input=rms, input=rms,

View File

@@ -26,7 +26,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform from vllm.platforms import current_platform
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
ROTARY_OP = torch.ops._C.rotary_embedding.default ROTARY_OP = torch.ops._C.rotary_embedding.default
FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default
@@ -160,69 +159,6 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
return result return result
class MatcherRMSNorm(MatcherCustomOp):
def __init__(
self,
epsilon: float,
enabled: bool | None = None,
match_rocm_aiter: bool = False,
) -> None:
if enabled is None:
enabled = RMSNorm.enabled()
super().__init__(enabled)
self.epsilon = epsilon
self._rmsnorm_op = RMS_OP
self.match_rocm_aiter = match_rocm_aiter
if match_rocm_aiter:
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op()
def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
return [input, weight]
def forward_rocm_aiter(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
return self._rmsnorm_op(
x=input,
weight=weight,
variance_epsilon=self.epsilon,
)
def forward_custom(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
if self.match_rocm_aiter:
return self.forward_rocm_aiter(input, weight)
result = torch.empty_like(input)
_, result = auto_functionalized(
self._rmsnorm_op,
result=result,
input=input,
weight=weight,
epsilon=self.epsilon,
)
return result
def forward_native(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
return RMSNorm.forward_static(
input, self.epsilon, input.size(-1), self.model_dtype, weight
)
class MatcherFusedAddRMSNorm(MatcherCustomOp): class MatcherFusedAddRMSNorm(MatcherCustomOp):
def __init__( def __init__(
self, self,

View File

@@ -10,6 +10,7 @@ from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.pattern_matcher import PatternMatcherPass
import vllm.ir.ops
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
@@ -17,7 +18,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from ..inductor_pass import enable_fake_mode from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding from .matcher_utils import MatcherRotaryEmbedding
from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64 from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -64,7 +65,6 @@ class QkNormRopePattern:
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.eps = eps self.eps = eps
self.rmsnorm_matcher = MatcherRMSNorm(eps)
self.is_neox = is_neox self.is_neox = is_neox
self.rope_flashinfer = rope_flashinfer self.rope_flashinfer = rope_flashinfer
self.rope_matcher = MatcherRotaryEmbedding( self.rope_matcher = MatcherRotaryEmbedding(
@@ -129,14 +129,14 @@ class QkNormRopePattern:
q_by_head = q.view( q_by_head = q.view(
*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim *q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
) )
q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight) q_normed_by_head = vllm.ir.ops.rms_norm(q_by_head, q_weight, self.eps)
q_flat = q_normed_by_head.view(q.shape) q_flat = q_normed_by_head.view(q.shape)
# K path: view -> RMS -> view back to k.shape # K path: view -> RMS -> view back to k.shape
k_by_head = k.view( k_by_head = k.view(
*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim *k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
) )
k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight) k_normed_by_head = vllm.ir.ops.rms_norm(k_by_head, k_weight, self.eps)
k_flat = k_normed_by_head.view(k.shape) k_flat = k_normed_by_head.view(k.shape)
# RoPE: apply to flattened q/k # RoPE: apply to flattened q/k

View File

@@ -9,6 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload from torch._ops import OpOverload
import vllm.ir.ops
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -30,7 +31,6 @@ from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import ( from .matcher_utils import (
MatcherFusedAddRMSNorm, MatcherFusedAddRMSNorm,
MatcherQuantFP8, MatcherQuantFP8,
MatcherRMSNorm,
) )
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -54,7 +54,6 @@ def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda") return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
QUANT_OPS: dict[QuantKey, OpOverload] = { QUANT_OPS: dict[QuantKey, OpOverload] = {
@@ -131,11 +130,9 @@ class RMSNormQuantPattern:
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
self.FUSED_OP = FUSED_OPS[key] self.FUSED_OP = FUSED_OPS[key]
self.rmsnorm_matcher = ( if key.fused_add:
MatcherRMSNorm(epsilon) self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
if not key.fused_add
else MatcherFusedAddRMSNorm(epsilon)
)
self.quant_matcher = MatcherQuantFP8( self.quant_matcher = MatcherQuantFP8(
key.quant, key.quant,
has_col_major_scales=has_col_major_scales, has_col_major_scales=has_col_major_scales,
@@ -161,16 +158,12 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
def pattern( def pattern(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
result_rms = self.rmsnorm_matcher(input, weight) result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon)
return self.quant_matcher(result_rms, scale)[0] return self.quant_matcher(result_rms, scale)[0]
def replacement( def replacement(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty( result = torch.empty(
input.shape, device=input.device, dtype=self.quant_dtype input.shape, device=input.device, dtype=self.quant_dtype
) )
@@ -187,8 +180,8 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
return at[1] return at[1]
inputs = [ inputs = [
# input, weight empty_bf16(5, 16), # input
*self.rmsnorm_matcher.inputs(), empty_bf16(16), # weight
self.quant_matcher.inputs()[1], # scale self.quant_matcher.inputs()[1], # scale
] ]
pattern(*inputs) pattern(*inputs)
@@ -391,7 +384,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
def pattern( def pattern(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight) result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon)
result = torch.empty( result = torch.empty(
result_rms.shape, result_rms.shape,
device=result_rms.device, device=result_rms.device,
@@ -442,12 +435,14 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
# result, scale # result, scale
return at[1], at[2] return at[1], at[2]
scale = self.quant_matcher.empty_f32(1, 1)
pm.register_replacement( pm.register_replacement(
pattern, pattern,
replacement, replacement,
self.rmsnorm_matcher.inputs() + [scale], [
empty_bf16(5, 16), # input
empty_bf16(16), # weight
self.quant_matcher.empty_f32(1, 1), # scale
],
pm.fwd_only, pm.fwd_only,
pm_pass, pm_pass,
) )
@@ -472,7 +467,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
def pattern( def pattern(
input: torch.Tensor, weight: torch.Tensor input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight) result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon)
# result, scale # result, scale
return self.quant_matcher(result_rms) # type: ignore[no-any-return] return self.quant_matcher(result_rms) # type: ignore[no-any-return]
@@ -502,7 +497,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
pm.register_replacement( pm.register_replacement(
pattern, pattern,
replacement, replacement,
self.rmsnorm_matcher.inputs(), [
empty_bf16(5, 16), # input
empty_bf16(16), # weight
],
pm.fwd_only, pm.fwd_only,
pm_pass, pm_pass,
) )

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch import torch
import torch._inductor.pattern_matcher as pm import torch._inductor.pattern_matcher as pm
@@ -24,7 +25,6 @@ from .act_quant_fusion import ActivationQuantPattern
from .matcher_utils import ( from .matcher_utils import (
MatcherFusedAddRMSNorm, MatcherFusedAddRMSNorm,
MatcherQuantFP8, MatcherQuantFP8,
MatcherRMSNorm,
MatcherSiluAndMul, MatcherSiluAndMul,
) )
from .rms_quant_fusion import ( from .rms_quant_fusion import (
@@ -41,17 +41,23 @@ class AiterRMSNormQuantPattern:
): ):
self.epsilon = epsilon self.epsilon = epsilon
self.quant_dtype = key.quant.dtype self.quant_dtype = key.quant.dtype
self.device = torch.device("cuda")
self.rmsnorm_matcher = ( if key.fused_add:
MatcherRMSNorm(epsilon, match_rocm_aiter=True) self.rmsnorm_matcher = MatcherFusedAddRMSNorm(
if not key.fused_add epsilon, match_rocm_aiter=True
else MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True) )
)
self.quant_matcher = MatcherQuantFP8( self.quant_matcher = MatcherQuantFP8(
key.quant, key.quant,
match_rocm_aiter=match_aiter_quant, match_rocm_aiter=match_aiter_quant,
) )
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=torch.bfloat16, device=self.device, **kwargs)
def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)
class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
"""AITER RMSNorm + Dynamic Quantization pattern.""" """AITER RMSNorm + Dynamic Quantization pattern."""
@@ -79,7 +85,7 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight) result_rms = torch.ops.vllm_ir.rms_norm(input, weight, self.epsilon)
result, scale = self.quant_matcher(result_rms) result, scale = self.quant_matcher(result_rms)
return result, scale return result, scale
@@ -99,7 +105,8 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
pm.register_replacement( pm.register_replacement(
pattern, pattern,
replacement, replacement,
self.rmsnorm_matcher.inputs(), # input, weight
[self.empty(5, 16), self.empty(16)],
pm.fwd_only, pm.fwd_only,
pm_pass, pm_pass,
) )
@@ -188,7 +195,7 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight) result_rms = torch.ops.vllm_ir.rms_norm(input, weight, self.epsilon)
result, scale = self.quant_matcher(result_rms) result, scale = self.quant_matcher(result_rms)
return result, scale return result, scale
@@ -206,7 +213,12 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
return at[0], at[1] return at[0], at[1]
pm.register_replacement( pm.register_replacement(
pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass pattern,
replacement,
# input, weight
[self.empty(5, 16), self.empty(16)],
pm.fwd_only,
pm_pass,
) )

View File

@@ -10,6 +10,7 @@ import torch._inductor.pattern_matcher as pm
import torch.fx as fx import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.pattern_matcher import PatternMatcherPass
import vllm.ir.ops
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.utils import Range from vllm.config.utils import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
@@ -22,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from ..inductor_pass import enable_fake_mode from ..inductor_pass import enable_fake_mode
from ..utility.noop_elimination import NoOpEliminationPass from ..utility.noop_elimination import NoOpEliminationPass
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -122,35 +123,38 @@ class _SequenceParallelPatternHelper:
x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
) )
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=self.dtype, device=self.device, **kwargs)
def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None: def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
super().__init__(epsilon, dtype, device) super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]: def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) # input, weight
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype) return [self.empty([1, 8, 4]), self.empty([4])]
return [input, arg3_1]
def register(self, pm_pass: PatternMatcherPass) -> None: def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
input: torch.Tensor, input: torch.Tensor,
arg3_1: torch.Tensor, weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(input) all_reduce = self._all_reduce(input)
rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1) rmsnorm = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
return rmsnorm, all_reduce return rmsnorm, all_reduce
def replacement( def replacement(
input: torch.Tensor, input: torch.Tensor,
arg3_1: torch.Tensor, weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(input) reduce_scatter = self._reduce_scatter(input)
rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1) rmsnorm = vllm.ir.ops.rms_norm(reduce_scatter, weight, self.epsilon)
all_gather = self._all_gather(rmsnorm) all_gather = self._all_gather(rmsnorm)
return all_gather, reduce_scatter return all_gather, reduce_scatter
@@ -222,14 +226,11 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
device: str | None, device: str | None,
) -> None: ) -> None:
super().__init__(epsilon, dtype, device) super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self) -> list[torch.Tensor]: def get_inputs(self) -> list[torch.Tensor]:
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) # input, weight, scale
weight = torch.empty([4], device=self.device, dtype=self.dtype) return [self.empty([1, 8, 4]), self.empty([4]), self.empty_f32([1, 1])]
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
return [input, weight, scale]
def register(self, pm_pass: PatternMatcherPass) -> None: def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
@@ -238,7 +239,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
scale: torch.Tensor, scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(input) all_reduce = self._all_reduce(input)
rms = self.rmsnorm_matcher(all_reduce, weight) rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
quant, _ = self.quant_matcher(rms, scale) quant, _ = self.quant_matcher(rms, scale)
return quant, all_reduce return quant, all_reduce
@@ -248,7 +249,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
scale: torch.Tensor, scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(input) reduce_scatter = self._reduce_scatter(input)
rms = self.rmsnorm_matcher(reduce_scatter, weight) rms = vllm.ir.ops.rms_norm(reduce_scatter, weight, self.epsilon)
quant, _ = self.quant_matcher(rms, scale) quant, _ = self.quant_matcher(rms, scale)
all_gather = self._all_gather(quant) all_gather = self._all_gather(quant)

View File

View File

@@ -0,0 +1,158 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Iterable
from torch import fx
from torch._inductor.pattern_matcher import (
CallFunctionVarArgs,
Match,
PatternMatcherPass,
register_graph_pattern,
)
from torch._ops import OpOverload, OpOverloadPacket
from vllm.config import VllmConfig
from vllm.ir.op import IrOp
from vllm.logger import init_logger
from vllm.logging_utils import lazy
from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
def get_default_overload(op: OpOverload | OpOverloadPacket) -> OpOverload:
if isinstance(op, OpOverloadPacket):
return op.default
assert isinstance(op, OpOverload), "Expected an OpOverload or OpOverloadPacket"
return op
def get_ir_op(node: fx.Node) -> IrOp | None:
if node.op != "call_function":
return None
if not isinstance(node.target, (OpOverload, OpOverloadPacket)):
return None
op_overload = get_default_overload(node.target)
if op_overload.namespace != "vllm_ir":
return None
op_name = op_overload._opname
if op_name not in IrOp.registry:
logger.warning(
"Unknown vLLM IR op %s, there's likely an issue with torch registration, "
"or a torch custom op was registered in the vllm_ir namespace by mistake.",
op_name,
)
return None
ir_op = IrOp.registry[op_name]
return ir_op
class VllmIRLoweringPass(VllmInductorPass):
"""
This pass lowers vLLM IR ops to their implementations the priority list.
"""
def __init__(self, vllm_config: VllmConfig) -> None:
super().__init__(vllm_config)
self.patterns = PatternMatcherPass(self.pass_name)
self.selected_impls: dict[str, dict[str, str]] = defaultdict(lambda: {})
self.ops = [ir_op.torch_op for ir_op in IrOp.registry.values()]
# Look for any call_function node where the target is a vLLM IR op.
# Then, lower_matched_op will select, trace, and insert the implementation.
register_graph_pattern(
CallFunctionVarArgs(self.ops),
pass_dict=self.patterns,
)(self.lower_matched_op)
def lower_matched_op(self, match: Match, *args, **kwargs):
# TODO(luka) I think args and kwargs are for the match, but just use the node?
assert len(match.nodes) == 1, "Expected single node match"
node = match.nodes[0]
ir_op = get_ir_op(node)
assert ir_op is not None, "Expected vLLM IR op"
assert not node.kwargs # I think there should never be kwargs here
# Select and record the implementation, using fake args
fake_args = fx.map_arg(node.args, lambda arg: arg.meta["val"])
ir_op_impl = ir_op.dispatch(*fake_args)
self.selected_impls[ir_op.name][node.name] = ir_op_impl.provider
# replace_by_example wants node args, not the fake tensors
# TODO(luka): Use aot_export_module to get functionalized graph
# TODO(luka): Cache the fx_replacement to avoid re-tracing the same impl
# Defaults not present on node.args but required for replacement tracing
bound_args = ir_op._py_signature.bind(*node.args)
bound_args.apply_defaults()
match.replace_by_example(ir_op_impl.impl_fn, bound_args.args)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
# clear at the beginning instead of end, so that tests can inspect
self.selected_impls.clear()
count = self.patterns.apply(graph)
logger.debug("VllmIRLoweringPass lowered %d vLLM IR nodes", count)
# TODO write self.selected_impls to depyf/tlparse dir
def count_items(impls: Iterable[str]) -> dict[str, int]:
counts: dict[str, int] = defaultdict(lambda: 0)
for impl in impls:
counts[impl] += 1
return counts
def print_count(counts: dict[str, int]) -> str:
# e.g., "impl1*3,impl2"
impl_count = lambda i, c: f"{i}" if c == 1 else f"{i}*{c}"
return ",".join(impl_count(impl, count) for impl, count in counts.items())
logger.debug(
"Selected implementations: %s",
lazy(
lambda: ", ".join(
f"{op}={print_count(count_items(impls_by_node.values()))}"
for op, impls_by_node in self.selected_impls.items()
)
),
)
failed_nodes: list[fx.Node] = []
failed_ops: set[str] = set()
# Check no vllm_ir nodes were left in the graph
for node in graph.nodes:
if (ir_op := get_ir_op(node)) is None:
continue
failed_nodes.append(node)
failed_ops.add(ir_op.name)
if failed_nodes or failed_ops:
logger.warning("Failed to lower vLLM IR ops: %s", ",".join(failed_ops))
logger.warning("Full node list: %s", failed_nodes)
def uuid(self) -> str:
"""
IR op priority & impl sources affect lowering pass output,
so we include them in the cache key.
"""
priorities = {name: op.get_priority() for name, op in IrOp.registry.items()}
priorities_str = ";".join(
f"{name}={','.join(p)}" for name, p in priorities.items()
)
impl_uuids_str = ";".join(
f"{name}={
','.join(IrOp.registry[name].impls[provider].uuid() for provider in p)
}"
for name, p in priorities.items()
)
return f"{super().uuid()}|{priorities_str}|{impl_uuids_str}"

View File

@@ -14,6 +14,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.system_utils import set_env_var from vllm.utils.system_utils import set_env_var
from .ir.lowering_pass import VllmIRLoweringPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
if rocm_aiter_ops.is_enabled(): if rocm_aiter_ops.is_enabled():
@@ -99,8 +100,17 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
else: else:
logger.debug("Skipping %s with compile range %s", pass_, compile_range) logger.debug("Skipping %s with compile range %s", pass_, compile_range)
# post-cleanup goes before fix_functionalization # perform the first post-cleanup before IR lowering to clean up fusion artifacts
# because it requires a functional graph # and make sure no dead IR ops are lowered.
self.post_cleanup(graph)
VllmInductorPass.dump_prefix += 1
# lowering before cleanup so DCE can clean up lowered ops.
# DCE handles mutating ops correctly as well.
self.ir_lowering(graph)
VllmInductorPass.dump_prefix += 1
# clean up after lowering again
self.post_cleanup(graph) self.post_cleanup(graph)
VllmInductorPass.dump_prefix += 1 VllmInductorPass.dump_prefix += 1
@@ -152,7 +162,7 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
self.passes += [SplitCoalescingPass(config)] self.passes += [SplitCoalescingPass(config)]
self.passes += [QKNormRoPEFusionPass(config)] self.passes += [QKNormRoPEFusionPass(config)]
# needs a functional graph self.ir_lowering = VllmIRLoweringPass(config)
self.post_cleanup = PostCleanupPass(config) self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config) self.fix_functionalization = FixFunctionalizationPass(config)
@@ -171,6 +181,10 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()} state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}
for pass_ in self.passes: for pass_ in self.passes:
passes.append(pass_.uuid()) passes.append(pass_.uuid())
passes.append(self.post_cleanup.uuid())
passes.append(self.ir_lowering.uuid())
passes.append(self.post_cleanup.uuid())
passes.append(self.fix_functionalization.uuid()) passes.append(self.fix_functionalization.uuid())
# Include the compile range in the uuid to ensure that inductor # Include the compile range in the uuid to ensure that inductor

View File

@@ -152,6 +152,7 @@ class VllmPatternMatcherPass(VllmInductorPass):
f"auto_functionalized as auto_functionalized\n" f"auto_functionalized as auto_functionalized\n"
f"from torch._inductor.pattern_matcher import *\n" f"from torch._inductor.pattern_matcher import *\n"
f"vllm = torch.ops.vllm", f"vllm = torch.ops.vllm",
"vllm_ir = torch.ops.vllm_ir",
file=f, file=f,
) )

View File

@@ -466,6 +466,15 @@ class CompilationConfig:
disabled when running with Inductor: mode>CompilationMode.NONE and disabled when running with Inductor: mode>CompilationMode.NONE and
backend="inductor". backend="inductor".
Inductor generates (fused) Triton kernels for disabled custom ops.""" Inductor generates (fused) Triton kernels for disabled custom ops."""
ir_enable_torch_wrap: bool = None # type: ignore[assignment]
"""If True, enable vllm_ir torch custom op wrapping during the forward pass.
When False, torch custom op wrapping is disabled, allowing Dynamo to trace the
selected implementation directly or avoiding torch custom op overhead in eager mode.
Defaults to True when using Inductor with vllm-compile
(backend=="inductor" and mode == VLLM_COMPILE), False otherwise.
"""
splitting_ops: list[str] | None = None splitting_ops: list[str] | None = None
"""A list of ops to exclude from cudagraphs, used in piecewise compilation. """A list of ops to exclude from cudagraphs, used in piecewise compilation.
@@ -830,6 +839,7 @@ class CompilationConfig:
"cudagraph_mode", "cudagraph_mode",
"max_cudagraph_capture_size", "max_cudagraph_capture_size",
"use_inductor_graph_partition", "use_inductor_graph_partition",
"ir_enable_torch_wrap",
mode="wrap", mode="wrap",
) )
@classmethod @classmethod

View File

@@ -1,13 +1,106 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from collections.abc import Callable from collections.abc import Callable
from typing import Any, Literal from dataclasses import asdict, fields
from typing import TYPE_CHECKING, Any, Literal
from pydantic import field_validator from pydantic import Field, field_validator
from vllm.config.utils import config, get_hash_factors, hash_factors
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
@config
class IrOpPriorityConfig:
"""
Configuration for vLLM IR op priority for dispatching/lowering during the
forward pass. Each member is a list of strings, which will be passed to
vllm.ir.ops.<op_name>.set_priority() for the duration of the forward pass.
A single comma-separated string is accepted as well,
If specified manually, platform defaults will be appended to the lists.
See KernelConfig.set_platform_defaults().
"""
rms_norm: list[str] = Field(default_factory=list)
"""Priority list for vllm.ir.ops.rms_norm"""
def compute_hash(self) -> str:
"""
Produces a hash unique to the pass configuration.
Any new fields that affect compilation should be added to the hash.
Any future fields that don't affect compilation should be excluded.
Also, manually add IR op impl UUIDs to make sure they affect the compile cache.
"""
factors = get_hash_factors(self, set())
# Implementations are hidden from Dynamo,
# so they don't show up in the traced files list.
from vllm.ir.op import IrOp
assert "_impls" not in factors
factors["_impls"] = {
name: {
provider: IrOp.registry[name].impls[provider].uuid() for provider in p
}
for name, p in asdict(self).items()
}
return hash_factors(factors)
@field_validator("*", mode="before")
@classmethod
def _to_list_str(cls, value: str | list[str]):
if isinstance(value, str):
value = value.replace(" ", "").split(",")
assert all(isinstance(v, str) for v in value)
return value
@contextlib.contextmanager
def set_priority(self):
"""
Context manager to set the IR op priority for all op members.
It also imports vllm.kernels to ensure all implementations are made available.
"""
import vllm.kernels # noqa: F401, registers IR op implementations
from vllm.ir.op import IrOp
with contextlib.ExitStack() as stack:
for field in fields(self):
op_priority = getattr(self, field.name)
assert op_priority is not None, (
f"IR op priority for {field.name} must be set"
)
logger.debug(
"Setting IR op priority for %s to %s", field.name, op_priority
)
ir_op = IrOp.registry[field.name]
stack.enter_context(ir_op.set_priority(op_priority))
yield
@classmethod
def with_default(
cls, default: list[str], /, **kwargs: list[str]
) -> "IrOpPriorityConfig":
"""
A helper to create an IrOpPriorityConfig where fields not specified in kwargs
use the given default list.
"""
for field in fields(cls):
if field.name not in kwargs:
kwargs[field.name] = list(default)
return cls(**kwargs)
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
MoEBackend = Literal[ MoEBackend = Literal[
"auto", "auto",
@@ -26,6 +119,12 @@ MoEBackend = Literal[
class KernelConfig: class KernelConfig:
"""Configuration for kernel selection and warmup behavior.""" """Configuration for kernel selection and warmup behavior."""
ir_op_priority: IrOpPriorityConfig = Field(default_factory=IrOpPriorityConfig)
"""
vLLM IR op priority for dispatching/lowering during the forward pass.
Platform defaults appended automatically during VllmConfig.__post_init__.
"""
enable_flashinfer_autotune: bool = None # type: ignore[assignment] enable_flashinfer_autotune: bool = None # type: ignore[assignment]
"""If True, run FlashInfer autotuning during kernel warmup.""" """If True, run FlashInfer autotuning during kernel warmup."""
@@ -51,21 +150,17 @@ class KernelConfig:
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, Produces a hash unique to the pass configuration.
ensure that it is included in the factors list if Any new fields that affect compilation should be added to the hash.
it affects the computation graph. Any future fields that don't affect compilation should be excluded.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
""" """
# no factors to consider. ignored_factors = {
# this config will not affect the computation graph. "enable_flashinfer_autotune",
factors: list[Any] = [] "ir_op_priority", # handled separately below
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() }
return hash_str factors = get_hash_factors(self, ignored_factors)
factors["ir_op_priority"] = self.ir_op_priority.compute_hash()
return hash_factors(factors)
@field_validator("enable_flashinfer_autotune", mode="wrap") @field_validator("enable_flashinfer_autotune", mode="wrap")
@classmethod @classmethod
@@ -74,3 +169,31 @@ class KernelConfig:
if value is None: if value is None:
return value return value
return handler(value) return handler(value)
def set_platform_defaults(self, vllm_config: "VllmConfig") -> None:
"""Set platform-specific defaults for the kernel config."""
from vllm.platforms import current_platform
platform_op_priority = current_platform.get_default_ir_op_priority(vllm_config)
logger.debug(
"Setting platform-specific IR op priority defaults: %s, user-defined: %s",
platform_op_priority,
self.ir_op_priority,
)
for op_name, op_priority in asdict(platform_op_priority).items():
current_op_priority: list[str] = getattr(self.ir_op_priority, op_name)
if current_op_priority is None:
setattr(self.ir_op_priority, op_name, op_priority)
else:
# Append platform-specific priorities
# Must be idempotent because vllm_config.set_platform_defaults() may be
# called multiple times (due to VllmConfig.__post_init__ manual call).
unique_op_priority = [
op for op in op_priority if op not in current_op_priority
]
current_op_priority.extend(unique_op_priority)
logger.info(
"Final IR op priority after setting platform defaults: %s",
self.ir_op_priority,
)

View File

@@ -95,9 +95,11 @@ def enable_norm_fusion(cfg: "VllmConfig") -> bool:
"""Enable if either RMS norm or quant FP8 custom op is active; """Enable if either RMS norm or quant FP8 custom op is active;
otherwise Inductor handles fusion.""" otherwise Inductor handles fusion."""
return cfg.compilation_config.is_custom_op_enabled( return (
"rms_norm" cfg.compilation_config.is_custom_op_enabled("rms_norm")
) or cfg.compilation_config.is_custom_op_enabled("quant_fp8") or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
or cfg.kernel_config.ir_op_priority.rms_norm[0] != "native"
)
def enable_act_fusion(cfg: "VllmConfig") -> bool: def enable_act_fusion(cfg: "VllmConfig") -> bool:
@@ -417,6 +419,10 @@ class VllmConfig:
vllm_factors.append(self.compilation_config.compute_hash()) vllm_factors.append(self.compilation_config.compute_hash())
else: else:
vllm_factors.append("None") vllm_factors.append("None")
if self.kernel_config:
vllm_factors.append(self.kernel_config.compute_hash())
else:
vllm_factors.append(None)
if self.kv_transfer_config: if self.kv_transfer_config:
vllm_factors.append(self.kv_transfer_config.compute_hash()) vllm_factors.append(self.kv_transfer_config.compute_hash())
else: else:
@@ -890,6 +896,13 @@ class VllmConfig:
else: else:
self.compilation_config.mode = CompilationMode.NONE self.compilation_config.mode = CompilationMode.NONE
# By default, enable torch wrapping only when using custom Inductor lowering
if self.compilation_config.ir_enable_torch_wrap is None:
self.compilation_config.ir_enable_torch_wrap = (
self.compilation_config.mode == CompilationMode.VLLM_COMPILE
and self.compilation_config.backend == "inductor"
)
if all(s not in self.compilation_config.custom_ops for s in ("all", "none")): if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
if ( if (
self.compilation_config.backend == "inductor" self.compilation_config.backend == "inductor"
@@ -899,6 +912,11 @@ class VllmConfig:
else: else:
self.compilation_config.custom_ops.append("all") self.compilation_config.custom_ops.append("all")
# This populates IR op priorities,
# must happen after compilation mode and backend are decided,
# but before fusion defaults are applied as those may depend on op priority.
self.kernel_config.set_platform_defaults(self)
default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level] default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
self._apply_optimization_level_defaults(default_config) self._apply_optimization_level_defaults(default_config)
if self.kernel_config.enable_flashinfer_autotune is None: if self.kernel_config.enable_flashinfer_autotune is None:
@@ -1706,7 +1724,8 @@ class VllmConfig:
f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
f"enable_chunked_prefill={self.scheduler_config.enable_chunked_prefill}, " # noqa f"enable_chunked_prefill={self.scheduler_config.enable_chunked_prefill}, " # noqa
f"pooler_config={self.model_config.pooler_config!r}, " f"pooler_config={self.model_config.pooler_config!r}, "
f"compilation_config={self.compilation_config!r}" f"compilation_config={self.compilation_config!r}, "
f"kernel_config={self.kernel_config!r}"
) )
def validate_block_size(self) -> None: def validate_block_size(self) -> None:

View File

@@ -8,7 +8,7 @@ import functools
import json import json
import sys import sys
from collections.abc import Callable from collections.abc import Callable
from dataclasses import MISSING, dataclass, fields, is_dataclass from dataclasses import MISSING, asdict, dataclass, fields, is_dataclass
from itertools import permutations from itertools import permutations
from types import UnionType from types import UnionType
from typing import ( from typing import (
@@ -70,7 +70,7 @@ from vllm.config.cache import (
PrefixCachingHashAlgo, PrefixCachingHashAlgo,
) )
from vllm.config.device import Device from vllm.config.device import Device
from vllm.config.kernel import MoEBackend from vllm.config.kernel import IrOpPriorityConfig, MoEBackend
from vllm.config.lora import MaxLoRARanks from vllm.config.lora import MaxLoRARanks
from vllm.config.model import ( from vllm.config.model import (
ConvertOption, ConvertOption,
@@ -401,6 +401,7 @@ class EngineArgs:
max_cudagraph_capture_size: int | None = get_field( max_cudagraph_capture_size: int | None = get_field(
CompilationConfig, "max_cudagraph_capture_size" CompilationConfig, "max_cudagraph_capture_size"
) )
ir_op_priority: IrOpPriorityConfig = get_field(KernelConfig, "ir_op_priority")
# Note: Specifying a custom executor backend by passing a class # Note: Specifying a custom executor backend by passing a class
# is intended for expert use only. The API may change without # is intended for expert use only. The API may change without
# notice. # notice.
@@ -657,6 +658,9 @@ class EngineArgs:
self.weight_transfer_config = WeightTransferConfig( self.weight_transfer_config = WeightTransferConfig(
**self.weight_transfer_config **self.weight_transfer_config
) )
if isinstance(self.ir_op_priority, dict):
self.ir_op_priority = IrOpPriorityConfig(**self.ir_op_priority)
# Setup plugins # Setup plugins
from vllm.plugins import load_general_plugins from vllm.plugins import load_general_plugins
@@ -1293,6 +1297,7 @@ class EngineArgs:
title="KernelConfig", title="KernelConfig",
description=KernelConfig.__doc__, description=KernelConfig.__doc__,
) )
kernel_group.add_argument("--ir-op-priority", **kernel_kwargs["ir_op_priority"])
kernel_group.add_argument( kernel_group.add_argument(
"--enable-flashinfer-autotune", "--enable-flashinfer-autotune",
**kernel_kwargs["enable_flashinfer_autotune"], **kernel_kwargs["enable_flashinfer_autotune"],
@@ -1917,6 +1922,22 @@ class EngineArgs:
if self.moe_backend != "auto": if self.moe_backend != "auto":
kernel_config.moe_backend = self.moe_backend kernel_config.moe_backend = self.moe_backend
# Transfer top-level ir_op_priority into KernelConfig.ir_op_priority
for op_name, op_priority in asdict(self.ir_op_priority).items():
# Empty means unset
if not op_priority:
continue
# Priority cannot be set 2x for the same op
if getattr(kernel_config.ir_op_priority, op_name):
raise ValueError(
f"Op priority for {op_name} specified via both ir_op_priority "
f"and KernelConfig.ir_op_priority, only one allowed at a time."
)
# Set the attribute
setattr(kernel_config.ir_op_priority, op_name, op_priority)
load_config = self.create_load_config() load_config = self.create_load_config()
# Pass reasoning_parser into StructuredOutputsConfig # Pass reasoning_parser into StructuredOutputsConfig

View File

@@ -10,6 +10,7 @@ from typing import Any
import torch import torch
import vllm.envs as envs import vllm.envs as envs
import vllm.ir
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
@@ -378,7 +379,13 @@ def set_forward_context(
) )
try: try:
with override_forward_context(forward_context): with (
override_forward_context(forward_context),
vllm_config.kernel_config.ir_op_priority.set_priority(),
vllm.ir.enable_torch_wrap(
vllm_config.compilation_config.ir_enable_torch_wrap
),
):
yield yield
finally: finally:
global last_logging_time, batchsize_logging_interval global last_logging_time, batchsize_logging_interval

6
vllm/ir/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from . import ops
from .op import enable_torch_wrap, register_op
__all__ = ["enable_torch_wrap", "register_op", "ops"]

414
vllm/ir/op.py Normal file
View File

@@ -0,0 +1,414 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import inspect
from collections.abc import Callable
from pathlib import Path
from typing import Any, ClassVar, overload
import torch
from torch.library import Library, infer_schema
from vllm.ir.util import hash_source, weak_cache
from vllm.logger import init_logger
from vllm.logging_utils import lazy, tensors_str_no_data
vllm_ir_lib = Library("vllm_ir", "FRAGMENT")
logger = init_logger(__name__)
RESERVED_PROVIDERS = ["native", "unfused"]
"""Providers that are reserved and cannot be used for custom implementations."""
_ENABLE_TORCH_WRAP: bool = True
"""Global override flag to control torch op layer wrapping."""
@contextlib.contextmanager
def enable_torch_wrap(enable: bool = True):
"""
Context manager to enable/disable torch custom op wrapping for vLLM IR ops.
When torch wrapping is disabled, the torch custom op layer is skipped
and IR ops dispatch directly to the implementation.
Helpful for avoiding torch dispatch overhead in eager mode
and avoiding the need for lowering for platforms not using Inductor.
"""
global _ENABLE_TORCH_WRAP
old = _ENABLE_TORCH_WRAP
try:
_ENABLE_TORCH_WRAP = enable
yield
finally:
_ENABLE_TORCH_WRAP = old
# 0-param decorator overload
@overload
def register_op(f: Callable[..., Any]) -> "IrOp": ...
# parametrized decorator overload
@overload
def register_op(
*,
name: str | None = None,
) -> Callable[[Callable[..., Any]], "IrOp"]: ...
def register_op(
f: Callable | None = None,
*,
name: str | None = None,
) -> "IrOp | Callable[[Callable], IrOp]":
"""
Register a new vLLM IR op.
:param f: the native implementation of the op
:param name: the name of the op, defaults to the function name
:return: the IrOp object if f is provided, otherwise a decorator
Example usage:
```python
@vllm.ir.register_op
def my_op(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
@vllm.ir.register_op(name="custom_mul")
def multiply(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x * y"""
def decorator(_f: Callable):
op_name: str = _f.__name__ if name is None else name
assert op_name not in IrOp.registry
op = IrOp(op_name, _f)
IrOp.registry[op_name] = op
return op
if f is not None:
return decorator(f)
return decorator
class IrOp:
registry: ClassVar[dict[str, "IrOp"]] = {}
name: str
impls: dict[str, "IrOpImpl"]
def __init__(self, name: str, native_impl: Callable):
self._py_signature = inspect.signature(native_impl)
if any(
p.kind == inspect.Parameter.KEYWORD_ONLY
for p in self._py_signature.parameters.values()
):
raise ValueError(
f"Op {name} has keyword-only arguments which are not currently "
f"supported. That's because kwargs are not allowed during lowering."
)
self.name = name
self.impls: dict[str, IrOpImpl] = {}
self._priority_impls: list[IrOpImpl] = []
self._schema_str = infer_schema(native_impl, mutates_args=[])
# native implementation
self.impls["native"] = IrOpImpl(
self, "native", native_impl, supported=True, supports_args=None
)
# By default, fake routes directly to native,
# can be overridden by register_fake
self._fake_fn = native_impl
# torch registration
vllm_ir_lib.define(self.name + self._schema_str)
# CompositeExplicitAutograd is not decomposed
# by ATen IR normalization in AOTAutograd
vllm_ir_lib.impl(
self.name, self._inner_call, dispatch_key="CompositeExplicitAutograd"
)
vllm_ir_lib._register_fake(self.name, self._fake_call)
assert hasattr(torch.ops.vllm_ir, name)
self.torch_op: torch._ops.OpOverload = getattr(torch.ops.vllm_ir, name).default
def register_fake(self, fn: Callable) -> Callable:
"""
Register a fake impl for the torch custom op. If this method is not called,
the native implementation is used directly for the fake implementation.
"""
self._fake_fn = fn
return fn
def _fake_call(self, *args, **kwargs) -> Any:
"""
Call to the fake implementation of the op. We use indirection because we want
users to be able to register fake later but also want it to fall back to native
directly by default, instead of going through the dispatching mechanism.
"""
return self._fake_fn(*args, **kwargs)
def register_impl(
self,
provider: str,
*,
supported: bool = True,
supports_args: Callable[..., bool] | None = None,
):
"""
Register an implementation for this custom op.
:param provider: The name of the provider, must be unique.
:param supported: Static support check, use this to check platform support.
:param supports_args: Dynamic arg support check, used for types and shapes.
:return: A decorator that registers the implementation.
The decorated function must have the same semantics and signature as
the native implementation.
The provider name must be unique and not one of the RESERVED_PROVIDERS.
The supported and supports_args parameters should not be used to implement
custom enablement logic based on global state (e.g. environment variables).
Instead, supported param should only be used to check for platform support
(e.g. whether a specific hardware or library is available).
supports_args should be used to check whether the provided arguments are
compatible with the implementation.
For custom enablement logic, set op impl priority.
Example:
```python
@my_op.register_impl("my_provider", supported=torch.cuda.is_available())
def my_provider_impl(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
```
"""
assert provider not in RESERVED_PROVIDERS, (
f"Provider name {provider} is reserved."
)
def _register_impl(f: Callable):
impl = IrOpImpl(self, provider, f, supported, supports_args)
self.impls[provider] = impl
if self.get_priority():
logger.warning(
"Warning: registering new impl %s for op %s while priority is set.",
provider,
self.name,
)
return impl
return _register_impl
def _inner_call(self, *args, **kwargs) -> Any:
"""
Eager call to torch op lands here. When torch wrapping is disabled,
__call__ routes straight here instead of going through torch op dispatching.
"""
impl = self.dispatch(*args, **kwargs)
return impl.impl_fn(*args, **kwargs)
def apply_arg_defaults(self, args) -> tuple:
"""
Return args with default values applied.
Defaults are taken from the native implementation signature.
SHOULD NOT BE USED IN THE DISPATCH PATH (SLOW).
Only for Inductor lowering.
"""
bound_args = self._py_signature.bind(*args)
bound_args.apply_defaults()
return bound_args.args
def dispatch(self, *args, **kwargs) -> "IrOpImpl":
"""
Dispatch to the appropriate implementation based on current priority
and argument support checks. Returns the selected IrOpImpl.
THIS FUNCTION IS ON THE HOT PATH (OP DISPATCH), MUST BE FAST.
"""
if not self._priority_impls:
if not torch.compiler.is_compiling():
# Logging not compatible with Dynamo tracing
# (this code is exposed when torch wrapping is disabled)
logger.warning_once(
"Priority not set for op %s, using native implementation.",
self.name,
)
return self.impls["native"]
for impl in self._priority_impls:
if not impl.supported:
raise ValueError(
f"Implementation {impl.provider} for op {self.name} not supported. "
f"All implementations in priority list must be supported."
)
if impl.supports_args(*args, **kwargs):
return impl
if not torch.compiler.is_compiling():
logger.debug(
"Skipping provider %s because it does not support "
"%s with args=%s kwargs=%s",
impl.provider,
self.name,
lazy(lambda: tensors_str_no_data(args)),
lazy(lambda: tensors_str_no_data(kwargs)),
)
raise RuntimeError(
"Priority set incorrectly: the last implementation must "
"support all args (can be native). This is likely an internal bug"
)
def __call__(self, *args, **kwargs) -> Any:
if not _ENABLE_TORCH_WRAP:
return self._inner_call(*args, **kwargs)
return self.torch_op(*args, **kwargs)
def get_priority(self) -> list[str]:
"""Get the current dispatch priority for implementations for this op."""
return [p.provider for p in self._priority_impls]
@contextlib.contextmanager
def set_priority(self, priority: list[str]):
"""
Context manager to set the dispatch priority for implementations for this op.
"""
assert all(p in self.impls for p in priority), (
"All providers in priority must be registered implementations."
)
def filter_priority_impls(p_list: list[str]) -> list[IrOpImpl]:
filtered_impls = []
for p in p_list:
impl = self.impls[p]
if not impl.supported:
# Skip unsupported implementations
continue
filtered_impls.append(impl)
# If all args are supported, skip other implementations
if impl.supports_all_args:
return filtered_impls
logger.warning_once(
"Op %s: No implementation in priority list supports all args, "
"execution fallback to native is possible. To silence this warning, "
"explicitly add 'native' to the end of the priority list",
self.name,
)
filtered_impls.append(self.impls["native"])
return filtered_impls
# Temporarily set priority
old_priority_impls = self._priority_impls
try:
self._priority_impls = filter_priority_impls(priority)
yield
finally:
self._priority_impls = old_priority_impls
def supported_providers(self) -> list[str]:
return [p.provider for p in self.impls.values() if p.supported]
class IrOpImpl:
def __init__(
self,
op: IrOp,
provider: str,
impl_fn: Callable,
supported: bool,
supports_args: Callable[..., bool] | None,
):
assert provider not in op.impls, (
f"Implementation for provider {provider} already registered."
)
# Native also uses this path, so we allow it here.
assert provider == "native" or provider not in RESERVED_PROVIDERS
# Enforce the exact same schema as the native implementation.
# This takes care of names, types, and defaults.
schema = infer_schema(impl_fn, mutates_args=[])
if schema != op._schema_str:
raise ValueError(
f"Implementation for provider {provider} has schema '{schema}' which "
f"does not match native schema '{op._schema_str}' for op {op.name}."
)
if supports_args is not None:
if not callable(supports_args):
raise ValueError(
f"supports_args for provider {provider} must be a callable"
)
# We also manually validate the supports_args signature.
# Matching signatures allow faster dispatch on the hotpath.
# Check that supports_args does not have keyword-only parameters
supports_args_signature = inspect.signature(supports_args)
params = supports_args_signature.parameters
if any(p.kind == inspect.Parameter.KEYWORD_ONLY for p in params.values()):
raise ValueError(
f"supports_args for provider {provider} "
f"cannot have keyword-only parameters"
)
# Check that supports_args has the same total number of parameters
op_params = op._py_signature.parameters
if len(params) != len(op_params):
raise ValueError(
f"supports_args for provider {provider} must have the same number "
f"of parameters ({len(params)}) as the native implementation "
f"({len(op_params)})"
)
# Check that names and defaults match for supports_args
for p, op_p in zip(params.values(), op_params.values()):
if p.name != op_p.name:
raise ValueError(
f"supports_args for provider {provider} has parameter "
f"'{p.name}' which does not match native parameter "
f"'{op_p.name}'"
)
if p.default != op_p.default:
raise ValueError(
f"supports_args for provider {provider} has parameter "
f"'{p.name}' with default {p.default} which does not match "
f"native default {op_p.default}'"
)
self.op = op
self.provider = provider
self.impl_fn = impl_fn
self.supported = supported
self._supports_args = supports_args
@property
def supports_all_args(self) -> bool:
"""Check if this implementation supports all args unconditionally."""
return self._supports_args is None
def supports_args(self, *args, **kwargs) -> bool:
if self._supports_args is None:
return True
return self._supports_args(*args, **kwargs)
@weak_cache
def uuid(self):
"""
Compile-time hash to uniquely determine whether the implementation has changed.
Used by vllm-compile hash mechanism and torch.compile lowering pass uuid to
control the vLLM compile cache and AOTAutograd/Inductor caches respectively.
Source file contents do not change so we cache uuid.
TODO(luka): Cache the file hash as multiple impls are likely in the same file.
"""
sources = [Path(inspect.getfile(self.impl_fn))]
return hash_source(*sources)

5
vllm/ir/ops/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .layernorm import rms_norm
__all__ = ["rms_norm"]

22
vllm/ir/ops/layernorm.py Normal file
View File

@@ -0,0 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import Tensor
from ..op import register_op
@register_op
def rms_norm(
x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
"""Weighted root-mean-square layer normalization"""
orig_dtype = x.dtype
x = x.to(torch.float32)
x_var = x if variance_size is None else x[..., :variance_size]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + epsilon)
x = x.to(orig_dtype)
if weight is not None:
x = x * weight
return x

61
vllm/ir/util.py Normal file
View File

@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import hashlib
import inspect
import types
import weakref
from pathlib import Path
from typing import Any
def hash_source(*srcs: str | Any) -> str:
"""
Utility method to hash the sources of functions or objects.
:param srcs: strings or objects to add to the hash.
Objects and functions have their source inspected.
:return:
"""
hasher = hashlib.sha256()
for src in srcs:
if src is None:
src_str = "None"
elif isinstance(src, str):
src_str = src
elif isinstance(src, Path):
src_str = src.read_text()
elif isinstance(src, (types.FunctionType, type)):
src_str = inspect.getsource(src)
else:
# object instance
src_str = inspect.getsource(src.__class__)
hasher.update(src_str.encode("utf-8"))
return hasher.hexdigest()
def weak_lru_cache(maxsize: int | None = 128, typed: bool = False):
"""
LRU Cache decorator that keeps a weak reference to 'self'.
This avoids memory leakage, which happens when functools.lru_cache
stores a reference to self in the global cache.
Taken from: https://stackoverflow.com/a/68052994/5082708
"""
def wrapper(func):
@functools.lru_cache(maxsize, typed)
def _func(_self, *args, **kwargs):
return func(_self(), *args, **kwargs)
@functools.wraps(func)
def inner(self, *args, **kwargs):
return _func(weakref.ref(self), *args, **kwargs)
return inner
return wrapper
def weak_cache(user_function, /):
"""Simple weak equivalent to functools.cache"""
return weak_lru_cache(maxsize=None)(user_function)

View File

@@ -1,3 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Kernel implementations for vLLM.""" """Kernel implementations for vLLM."""
from . import aiter_ops, oink_ops, vllm_c, xpu_ops
__all__ = ["vllm_c", "aiter_ops", "oink_ops", "xpu_ops"]

79
vllm/kernels/aiter_ops.py Normal file
View File

@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import torch
from torch import Tensor
from torch.library import Library
from vllm import ir
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
current_platform.import_kernels()
def is_aiter_found() -> bool:
from importlib.util import find_spec
return find_spec("aiter") is not None
aiter_lib = Library("vllm_aiter", "FRAGMENT")
"""
This library holds torch custom ops for wrapped AITER ops.
Many AITER ops want to remain invisible to torch.compile even after lowering.
They are thus wrapped into torch custom ops inside the IR op implementations.
"""
direct_register_aiter_op = functools.partial(
direct_register_custom_op, target_lib=aiter_lib
)
"""Syntactic sugar for registering AITER custom ops."""
AITER_SUPPORTED = is_aiter_found()
"""Most kernels in this file are supported if AITER is installed."""
rms_no_var_16bit_only = (
lambda x, weight, epsilon, variance_size=None: variance_size is None
and x.dtype
in (
torch.float16,
torch.bfloat16,
)
)
"""AITER rms_norm only supports float16 and bfloat16 acts and no var_size override."""
@ir.ops.rms_norm.register_impl(
"aiter", supports_args=rms_no_var_16bit_only, supported=AITER_SUPPORTED
)
def rms_norm(
x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
assert variance_size is None
assert x.dtype in (torch.float16, torch.bfloat16)
if weight is None:
weight = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
return torch.ops.vllm_aiter.rms_norm(x, weight, epsilon)
def _rms_norm_impl(x: Tensor, weight: Tensor, variance_epsilon: float) -> Tensor:
from aiter import rms_norm
if x.dim() > 2:
x_original_shape = x.shape
x = x.reshape(-1, x_original_shape[-1])
x = rms_norm(x, weight, variance_epsilon)
return x.reshape(x_original_shape)
return rms_norm(x, weight, variance_epsilon)
def _rms_norm_fake(x: Tensor, weight: Tensor, variance_epsilon: float) -> Tensor:
return torch.empty_like(x)
direct_register_aiter_op(
op_name="rms_norm", op_func=_rms_norm_impl, fake_impl=_rms_norm_fake
)

77
vllm/kernels/oink_ops.py Normal file
View File

@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import ir
from vllm.platforms import current_platform
OINK_AVAILABLE = current_platform.has_device_capability(100) and hasattr(
torch.ops, "oink"
)
def has_oink_op(name: str) -> bool:
"""Check if a specific oink op is registered."""
return OINK_AVAILABLE and hasattr(torch.ops.oink, name)
def _can_view_as_2d(x: torch.Tensor) -> bool:
"""Return True if x.view(-1, x.shape[-1]) is viewable (no copy)."""
if x.dim() < 2:
return False
if x.dim() == 2:
return True
# For a view(-1, N) to be valid, all leading dims must be contiguous with
# respect to each other (size-1 dims are ignored).
for dim in range(x.dim() - 1):
# Strides for size-1 dims are irrelevant and can be arbitrary.
if x.size(dim + 1) != 1 and x.stride(dim) != x.stride(dim + 1) * x.size(
dim + 1
):
return False
return True
def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool:
"""Return True if x_2d meets Oink's pointer-path stride constraints."""
if x_2d.dim() != 2:
return False
if x_2d.stride(1) != 1:
return False
# Match Oink's vectorization constraint: stride(0) divisible by 256b.
if x_2d.dtype in (torch.float16, torch.bfloat16):
divby = 16
elif x_2d.dtype == torch.float32:
divby = 8
else:
return False
return (x_2d.stride(0) % divby) == 0
oink_rms_supported = (
lambda x, weight, epsilon, variance_size=None: variance_size is None
and weight is not None
and x.dim() >= 2
and x.dtype == weight.dtype
and weight.is_contiguous()
and _can_view_as_2d(x)
and _is_oink_stride_compatible_2d(x.view(-1, x.shape[-1]))
)
"""
Oink rms only supports 2d-like inputs with contiguous weight
and no variance_size override.
"""
@ir.ops.rms_norm.register_impl(
"oink", supports_args=oink_rms_supported, supported=has_oink_op("rmsnorm")
)
def rms_norm(
x: torch.Tensor,
weight: torch.Tensor | None,
epsilon: float,
variance_size: int | None = None,
) -> torch.Tensor:
assert variance_size is None
x_2d = x.view(-1, x.shape[-1])
return torch.ops.oink.rmsnorm(x_2d, weight, epsilon).view_as(x)

30
vllm/kernels/vllm_c.py Normal file
View File

@@ -0,0 +1,30 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import Tensor
from vllm import ir
from vllm.platforms import current_platform
current_platform.import_kernels()
CUDA_ALIKE = current_platform.is_cuda_alike()
"""Most kernels in this file are supported on all CUDA-alike platforms."""
rms_no_var_size = lambda x, weight, epsilon, variance_size=None: variance_size is None
"""vLLM kernel does not support variance_size parameter."""
@ir.ops.rms_norm.register_impl(
"vllm_c", supports_args=rms_no_var_size, supported=CUDA_ALIKE
)
def rms_norm(
x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
if weight is None:
# Kernel requires weight tensor, pass ones
weight = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
assert variance_size is None
output = torch.empty(x.shape, device=x.device, dtype=x.dtype)
torch.ops._C.rms_norm(output, x, weight, epsilon)
return output

36
vllm/kernels/xpu_ops.py Normal file
View File

@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import Tensor
from vllm import ir
from vllm.platforms import current_platform
current_platform.import_kernels()
def is_xpu_kernels_found() -> bool:
from importlib.util import find_spec
return find_spec("vllm_xpu_kernels") is not None
XPU_KERNELS_SUPPORTED = is_xpu_kernels_found()
"""Kernels in this file are supported if vLLM XPU kernels are installed."""
rms_no_var = lambda x, weight, epsilon, variance_size=None: variance_size is None
@ir.ops.rms_norm.register_impl(
"xpu_kernels", supports_args=rms_no_var, supported=XPU_KERNELS_SUPPORTED
)
def rms_norm(
x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
if weight is None:
# Kernel requires weight tensor, pass ones
weight = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
assert variance_size is None
output = torch.empty(x.shape, device=x.device, dtype=x.dtype)
torch.ops._C.rms_norm(output, x, weight, epsilon)
return output

View File

@@ -8,6 +8,7 @@ from vllm.logging_utils.access_log_filter import (
from vllm.logging_utils.formatter import ColoredFormatter, NewLineFormatter from vllm.logging_utils.formatter import ColoredFormatter, NewLineFormatter
from vllm.logging_utils.lazy import lazy from vllm.logging_utils.lazy import lazy
from vllm.logging_utils.log_time import logtime from vllm.logging_utils.log_time import logtime
from vllm.logging_utils.torch_tensor import tensors_str_no_data
__all__ = [ __all__ = [
"NewLineFormatter", "NewLineFormatter",
@@ -16,4 +17,5 @@ __all__ = [
"create_uvicorn_log_config", "create_uvicorn_log_config",
"lazy", "lazy",
"logtime", "logtime",
"tensors_str_no_data",
] ]

View File

@@ -0,0 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
def tensors_str_no_data(arg: Any):
from torch._tensor_str import printoptions
with printoptions(threshold=1, edgeitems=0):
return str(arg)

View File

@@ -6,7 +6,9 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from vllm import _oink_ops, envs # Import kernels
import vllm.kernels # noqa: F401
from vllm import _oink_ops, envs, ir
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
@@ -51,23 +53,6 @@ def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool:
return (x_2d.stride(0) % divby) == 0 return (x_2d.stride(0) % divby) == 0
def rms_norm(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
from vllm import _custom_ops as ops
if envs.VLLM_BATCH_INVARIANT:
return rms_norm_batch_invariant(x, weight, variance_epsilon)
out = torch.empty_like(x)
ops.rms_norm(
out,
x,
weight,
variance_epsilon,
)
return out
def fused_add_rms_norm( def fused_add_rms_norm(
x: torch.Tensor, x: torch.Tensor,
residual: torch.Tensor, residual: torch.Tensor,
@@ -105,23 +90,16 @@ def poly_norm(
return out return out
def dispatch_rocm_rmsnorm_func( def dispatch_rocm_rmsnorm_func(dtype: torch.dtype, use_aiter: bool = False):
with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
use_aiter = use_aiter and dtype in [ use_aiter = use_aiter and dtype in [
torch.float16, torch.float16,
torch.bfloat16, torch.bfloat16,
] ]
if use_aiter and with_fused_add:
return rocm_aiter_ops.rms_norm2d_with_add
if use_aiter: if use_aiter:
return rocm_aiter_ops.rms_norm return rocm_aiter_ops.rms_norm2d_with_add
else:
# fall back to CUDA implementation
if with_fused_add:
return fused_add_rms_norm return fused_add_rms_norm
return rms_norm
# --8<-- [start:rms_norm] # --8<-- [start:rms_norm]
@@ -158,20 +136,14 @@ class RMSNorm(CustomOp):
if current_platform.is_rocm(): if current_platform.is_rocm():
aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled() aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
with_fused_add=False,
dtype=weight_dtype,
use_aiter=aiter_rmsnorm_enabled,
)
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func( self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
) )
# Optional: enable Oink Blackwell RMSNorm custom-op fast path on # Optional: enable Oink Blackwell RMSNorm custom-op fast path on
# compatible CUDA devices (e.g., SM100) when the external Oink # compatible CUDA devices (e.g., SM100) when the external Oink
# package is available. This is detected once at construction time # package is available. This is detected once at construction time
# to avoid per-call device queries in the hot path. # to avoid per-call device queries in the hot path.
self._use_oink_rmsnorm = False
self._use_oink_fused_add_rmsnorm = False self._use_oink_fused_add_rmsnorm = False
if ( if (
not current_platform.is_rocm() not current_platform.is_rocm()
@@ -203,7 +175,6 @@ class RMSNorm(CustomOp):
try: try:
device_index = torch.accelerator.current_device_index() device_index = torch.accelerator.current_device_index()
if _oink_ops.is_oink_available_for_device(device_index): if _oink_ops.is_oink_available_for_device(device_index):
self._use_oink_rmsnorm = True
self._use_oink_fused_add_rmsnorm = ( self._use_oink_fused_add_rmsnorm = (
_oink_ops.has_fused_add_rms_norm() _oink_ops.has_fused_add_rms_norm()
) )
@@ -215,7 +186,6 @@ class RMSNorm(CustomOp):
"RMSNorm; falling back to vLLM RMSNorm. Error: %s", "RMSNorm; falling back to vLLM RMSNorm. Error: %s",
e, e,
) )
self._use_oink_rmsnorm = False
self._use_oink_fused_add_rmsnorm = False self._use_oink_fused_add_rmsnorm = False
@staticmethod @staticmethod
@@ -270,6 +240,10 @@ class RMSNorm(CustomOp):
residual: torch.Tensor | None = None, residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
if residual is None:
return ir.ops.rms_norm(
x, self.weight.data, self.variance_epsilon, self.variance_size_override
)
return self.forward_static( return self.forward_static(
x, x,
@@ -286,35 +260,14 @@ class RMSNorm(CustomOp):
x: torch.Tensor, x: torch.Tensor,
residual: torch.Tensor | None = None, residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if residual is None and not envs.VLLM_BATCH_INVARIANT:
return ir.ops.rms_norm(
x, self.weight.data, self.variance_epsilon, self.variance_size_override
)
if self.variance_size_override is not None: if self.variance_size_override is not None:
return self.forward_native(x, residual) return self.forward_native(x, residual)
# Optional Oink SM100 fast path (no residual). This path is
# torch.compile-friendly via torch.ops.oink.rmsnorm and preserves
# 2D layouts (including padded rows) when using the Oink
# pointer-based kernel.
if (
residual is None
and getattr(self, "_use_oink_rmsnorm", False)
and x.is_cuda
and x.dim() >= 2
and self.has_weight
and not envs.VLLM_BATCH_INVARIANT
and self.weight.data.dtype == x.dtype
and self.weight.data.is_contiguous()
):
orig_shape = x.shape
hidden_size = orig_shape[-1]
if _can_view_as_2d(x):
x_2d = x.view(-1, hidden_size)
if _is_oink_stride_compatible_2d(x_2d):
y_2d = _oink_ops.rmsnorm(
x_2d,
self.weight.data,
self.variance_epsilon,
)
return y_2d.view(orig_shape)
# Optional Oink SM100 fast path (fused residual-add + RMSNorm, in-place). # Optional Oink SM100 fast path (fused residual-add + RMSNorm, in-place).
# This mirrors vLLM's fused_add_rms_norm semantics by mutating both # This mirrors vLLM's fused_add_rms_norm semantics by mutating both
# `x` (normalized output) and `residual` (residual-out buffer). # `x` (normalized output) and `residual` (residual-out buffer).
@@ -356,29 +309,34 @@ class RMSNorm(CustomOp):
) )
return x, residual return x, residual
add_residual = residual is not None if residual is not None:
if add_residual:
return fused_add_rms_norm( return fused_add_rms_norm(
x, residual, self.weight.data, self.variance_epsilon x, residual, self.weight.data, self.variance_epsilon
) )
else: else:
return rms_norm(x, self.weight.data, self.variance_epsilon) assert envs.VLLM_BATCH_INVARIANT
return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)
def forward_hip( def forward_hip(
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: torch.Tensor | None = None, residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if residual is None and not envs.VLLM_BATCH_INVARIANT:
return ir.ops.rms_norm(
x, self.weight.data, self.variance_epsilon, self.variance_size_override
)
if self.variance_size_override is not None: if self.variance_size_override is not None:
return self.forward_native(x, residual) return self.forward_native(x, residual)
add_residual = residual is not None if residual is not None:
if add_residual:
return self.rocm_norm_func_with_add( return self.rocm_norm_func_with_add(
x, residual, self.weight.data, self.variance_epsilon x, residual, self.weight.data, self.variance_epsilon
) )
else: else:
return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon) assert envs.VLLM_BATCH_INVARIANT
return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)
def forward_xpu( def forward_xpu(
self, self,

View File

@@ -30,6 +30,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.config.kernel import IrOpPriorityConfig
from vllm.v1.attention.selector import AttentionSelectorConfig from vllm.v1.attention.selector import AttentionSelectorConfig
else: else:
VllmConfig = None VllmConfig = None
@@ -550,6 +551,26 @@ class CudaPlatformBase(Platform):
def use_custom_op_collectives(cls) -> bool: def use_custom_op_collectives(cls) -> bool:
return True return True
@classmethod
def get_default_ir_op_priority(cls, vllm_config: VllmConfig) -> IrOpPriorityConfig:
from vllm.config.compilation import CompilationMode
from vllm.config.kernel import IrOpPriorityConfig
# Native used by default when compiling,
# use vllm_c kernels where available when no codegen
cc = vllm_config.compilation_config
using_inductor = cc.backend == "inductor" and cc.mode != CompilationMode.NONE
default = ["native"] if using_inductor else ["vllm_c", "native"]
# Use oink if enabled for rms_norm
# TODO(Laurawly/luka): remove this env var,
# users can just use IR op priority directly
rms_norm = default
if envs.VLLM_USE_OINK_OPS:
rms_norm = ["oink"] + default
return IrOpPriorityConfig.with_default(default, rms_norm=rms_norm)
# NVML utils # NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

View File

@@ -17,6 +17,7 @@ if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup from torch.distributed import PrefixStore, ProcessGroup
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.kernel import IrOpPriorityConfig
from vllm.inputs import EngineInput from vllm.inputs import EngineInput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
@@ -931,6 +932,16 @@ class Platform:
"num_compute_units is not implemented for the current platform." "num_compute_units is not implemented for the current platform."
) )
@classmethod
def get_default_ir_op_priority(
cls, vllm_config: "VllmConfig"
) -> "IrOpPriorityConfig":
"""Get the default IR op priority for the current platform."""
from vllm.config.kernel import IrOpPriorityConfig
# Native always used by default. Platforms can override this behavior.
return IrOpPriorityConfig.with_default(["native"])
class UnspecifiedPlatform(Platform): class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED _enum = PlatformEnum.UNSPECIFIED

View File

@@ -19,6 +19,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.kernel import IrOpPriorityConfig
from vllm.v1.attention.selector import AttentionSelectorConfig from vllm.v1.attention.selector import AttentionSelectorConfig
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -903,3 +904,32 @@ class RocmPlatform(Platform):
@classmethod @classmethod
def use_custom_op_collectives(cls) -> bool: def use_custom_op_collectives(cls) -> bool:
return True return True
@classmethod
def get_default_ir_op_priority(
cls, vllm_config: "VllmConfig"
) -> "IrOpPriorityConfig":
from vllm.config.compilation import CompilationMode
from vllm.config.kernel import IrOpPriorityConfig
# Native used by default when compiling,
# use vllm_c kernels where available when no codegen
# TODO(luka/TJ) use aiter, vllm_c, native by default on ROCm
cc = vllm_config.compilation_config
using_inductor = cc.backend == "inductor" and cc.mode != CompilationMode.NONE
default = ["native"] if using_inductor else ["vllm_c", "native"]
# This (mostly) preserves previous CustomOp behavior
# Necessary on ROCm because it's common that users
# enable rms_norm to use the aiter kernel.
# TODO(luka/TJ) remove env vars completely
if (
cc.is_custom_op_enabled("rms_norm")
and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_RMSNORM
):
rms_norm = ["aiter"] + default
else:
rms_norm = default
return IrOpPriorityConfig.with_default(default, rms_norm=rms_norm)

View File

@@ -21,6 +21,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.kernel import IrOpPriorityConfig
from vllm.v1.attention.selector import AttentionSelectorConfig from vllm.v1.attention.selector import AttentionSelectorConfig
else: else:
VllmConfig = None VllmConfig = None
@@ -257,6 +258,21 @@ class XPUPlatform(Platform):
) )
return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa
@classmethod
def get_default_ir_op_priority(
cls, vllm_config: "VllmConfig"
) -> "IrOpPriorityConfig":
from vllm.config.compilation import CompilationMode
from vllm.config.kernel import IrOpPriorityConfig
# Native used by default when compiling,
# use fused kernels where available when no codegen
cc = vllm_config.compilation_config
using_inductor = cc.backend == "inductor" and cc.mode != CompilationMode.NONE
default = ["native"] if using_inductor else ["xpu_kernels", "native"]
return IrOpPriorityConfig.with_default(default)
@classmethod @classmethod
def device_count(cls) -> int: def device_count(cls) -> int:
return torch.xpu.device_count() return torch.xpu.device_count()