[vLLM IR] 1/N Implement IR skeleton and rms_norm op (#33825)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com> Signed-off-by: chzhang <chaojun.zhang@intel.com> Signed-off-by: Luka Govedic <luka.govedic@gmail.com> Co-authored-by: Xinyu Chen <xinyu1.chen@intel.com> Co-authored-by: Chaojun Zhang <chaojun.zhang@intel.com> Co-authored-by: Luka Govedič <ProExpertProg@h100-01.nemg-001.lab.rdu2.dc.redhat.com>
This commit is contained in:
@@ -2,6 +2,16 @@ group: Kernels
|
||||
depends_on:
|
||||
- image-build
|
||||
steps:
|
||||
- label: vLLM IR Tests
|
||||
timeout_in_minutes: 10
|
||||
working_dir: "/vllm-workspace/"
|
||||
source_file_dependencies:
|
||||
- vllm/ir
|
||||
- vllm/kernels
|
||||
commands:
|
||||
- pytest -v -s tests/ir
|
||||
- pytest -v -s tests/kernels/ir
|
||||
|
||||
- label: Kernels Core Operation Test
|
||||
timeout_in_minutes: 75
|
||||
source_file_dependencies:
|
||||
|
||||
4
.github/CODEOWNERS
vendored
4
.github/CODEOWNERS
vendored
@@ -13,6 +13,9 @@
|
||||
/vllm/model_executor/layers/rotary_embedding.py @vadiklyutiy
|
||||
/vllm/model_executor/model_loader @22quinn
|
||||
/vllm/model_executor/layers/batch_invariant.py @yewentao256
|
||||
/vllm/ir @ProExpertProg
|
||||
/vllm/kernels/ @ProExpertProg @tjtanaa
|
||||
/vllm/kernels/helion @ProExpertProg @zou3519
|
||||
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa
|
||||
/vllm/vllm_flash_attn @LucasWilkinson @MatthewBonanni
|
||||
CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
@@ -74,6 +77,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche
|
||||
/tests/evals @mgoin @vadiklyutiy
|
||||
/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256
|
||||
/tests/kernels/ir @ProExpertProg @tjtanaa
|
||||
/tests/models @DarkLight1337 @ywang96
|
||||
/tests/multimodal @DarkLight1337 @ywang96 @NickLucche
|
||||
/tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 @pavanimajety
|
||||
|
||||
@@ -8,7 +8,7 @@ from copy import deepcopy
|
||||
|
||||
import depyf
|
||||
from torch import fx
|
||||
from torch._ops import OpOverload
|
||||
from torch._ops import OpOverload, OpOverloadPacket
|
||||
from torch.fx._utils import lazy_format_graph_code
|
||||
|
||||
from vllm.compilation.passes.fx_utils import find_op_nodes
|
||||
@@ -90,7 +90,9 @@ class TestBackend:
|
||||
# assign by reference, will reflect the final state of the graph
|
||||
self.final_graph = graph
|
||||
|
||||
def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True):
|
||||
def check_before_ops(
|
||||
self, ops: Sequence[OpOverload | OpOverloadPacket], fully_replaced=True
|
||||
):
|
||||
for op in ops:
|
||||
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
|
||||
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
|
||||
@@ -99,13 +101,19 @@ class TestBackend:
|
||||
if fully_replaced:
|
||||
assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"
|
||||
|
||||
def check_after_ops(self, ops: Sequence[OpOverload]):
|
||||
def check_after_ops(self, ops: Sequence[OpOverload | OpOverloadPacket]):
|
||||
for op in ops:
|
||||
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
|
||||
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
|
||||
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
|
||||
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
|
||||
|
||||
def op_count(self, op: OpOverload, before=False) -> int:
|
||||
def op_count(self, op: OpOverload | OpOverloadPacket, before=False) -> int:
|
||||
graph = self.graph_pre_pass if before else self.graph_post_pass
|
||||
return len(list(find_op_nodes(op, graph)))
|
||||
|
||||
def print_graphs(self):
|
||||
print("=== Graph before custom passes ===")
|
||||
print(self.graph_pre_pass.python_code(root_module="self", verbose=True).src)
|
||||
print("=== Graph after custom passes ===")
|
||||
print(self.graph_post_pass.python_code(root_module="self", verbose=True).src)
|
||||
|
||||
@@ -99,6 +99,8 @@ def test_tp1_fp8_fusions(
|
||||
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
||||
model_kwargs["load_format"] = "dummy"
|
||||
model_kwargs["max_model_len"] = 1024
|
||||
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
|
||||
|
||||
compilation_config = dict(
|
||||
use_inductor_graph_partition=inductor_graph_partition,
|
||||
custom_ops=custom_ops.split(","),
|
||||
@@ -166,6 +168,7 @@ def test_tp1_fp4_fusions(
|
||||
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
||||
model_kwargs["load_format"] = "dummy"
|
||||
model_kwargs["max_model_len"] = 1024
|
||||
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
|
||||
|
||||
compilation_config = dict(
|
||||
use_inductor_graph_partition=inductor_graph_partition,
|
||||
|
||||
@@ -68,6 +68,7 @@ def test_tp2_ar_rms_fp8_fusions(
|
||||
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
||||
model_kwargs["load_format"] = "dummy"
|
||||
model_kwargs["max_model_len"] = 1024
|
||||
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
|
||||
|
||||
compilation_config = dict(
|
||||
use_inductor_graph_partition=inductor_graph_partition,
|
||||
@@ -128,6 +129,7 @@ def test_tp2_ar_rms_fp4_fusions(
|
||||
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
||||
model_kwargs["load_format"] = "dummy"
|
||||
model_kwargs["max_model_len"] = 1024
|
||||
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
|
||||
|
||||
compilation_config = dict(
|
||||
use_inductor_graph_partition=inductor_graph_partition,
|
||||
@@ -182,6 +184,7 @@ def test_tp2_ar_rms_fusions(
|
||||
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
||||
model_kwargs["load_format"] = "dummy"
|
||||
model_kwargs["max_model_len"] = 1024
|
||||
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
|
||||
|
||||
compilation_config = dict(
|
||||
use_inductor_graph_partition=inductor_graph_partition,
|
||||
|
||||
@@ -58,6 +58,7 @@ def test_tp2_async_tp_fp8_fusions(
|
||||
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
||||
model_kwargs["load_format"] = "dummy"
|
||||
model_kwargs["max_model_len"] = 1024
|
||||
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
|
||||
|
||||
compilation_config = dict(
|
||||
use_inductor_graph_partition=inductor_graph_partition,
|
||||
@@ -121,6 +122,7 @@ def test_tp2_async_tp_fusions(
|
||||
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
||||
model_kwargs["load_format"] = "dummy"
|
||||
model_kwargs["max_model_len"] = 1024
|
||||
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
|
||||
|
||||
compilation_config = dict(
|
||||
use_inductor_graph_partition=inductor_graph_partition,
|
||||
|
||||
@@ -9,7 +9,6 @@ from tests.compile.backend import TestBackend
|
||||
from tests.utils import TestFP8Layer, multi_gpu_test
|
||||
from vllm.compilation.passes.fusion.rms_quant_fusion import RMSNormQuantFusionPass
|
||||
from vllm.compilation.passes.fusion.sequence_parallelism import SequenceParallelismPass
|
||||
from vllm.compilation.passes.fx_utils import find_auto_fn
|
||||
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
||||
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
|
||||
@@ -86,13 +85,14 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
]
|
||||
|
||||
def ops_in_model(self):
|
||||
if RMSNorm.enabled():
|
||||
return [
|
||||
torch.ops._C.rms_norm.default,
|
||||
return (
|
||||
[torch.ops.vllm_ir.rms_norm]
|
||||
+ [
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
]
|
||||
else:
|
||||
return []
|
||||
if RMSNorm.enabled()
|
||||
else []
|
||||
)
|
||||
|
||||
|
||||
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
@@ -321,4 +321,4 @@ def sequence_parallelism_pass_on_test_model(
|
||||
assert backend.op_count(op, before=False) == 4
|
||||
|
||||
for op in model.ops_in_model():
|
||||
find_auto_fn(backend.graph_post_pass.nodes, op)
|
||||
assert backend.op_count(op, before=False) > 0
|
||||
|
||||
0
tests/compile/passes/ir/__init__.py
Normal file
0
tests/compile/passes/ir/__init__.py
Normal file
69
tests/compile/passes/ir/test_lowering.py
Normal file
69
tests/compile/passes/ir/test_lowering.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import vllm.kernels # noqa: F401 to register kernels
|
||||
from vllm import ir
|
||||
from vllm.compilation.passes.ir.lowering_pass import (
|
||||
VllmIRLoweringPass,
|
||||
)
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.ir import ops
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...backend import TestBackend
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, hidden_size=16, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.hidden_size = hidden_size
|
||||
self.weight = torch.ones(hidden_size, dtype=torch.bfloat16)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = x + 4.0
|
||||
x2 = ops.rms_norm(x1, self.weight, 1e-5)
|
||||
x3 = x2 * 5.0
|
||||
# no weight
|
||||
x4 = ops.rms_norm(x3, None, 1e-5)
|
||||
x5 = x4 / 2.0
|
||||
# dispatch to native due to variance_size parameter
|
||||
x6 = ops.rms_norm(x5, self.weight, 1e-5, self.hidden_size // 2)
|
||||
return x6 + 3.0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("rms_provider", ops.rms_norm.supported_providers())
|
||||
def test_lowering_rms_norm(rms_provider, default_vllm_config):
|
||||
torch.set_default_device(current_platform.device_type)
|
||||
|
||||
lowering_pass = VllmIRLoweringPass(get_current_vllm_config())
|
||||
backend = TestBackend(lowering_pass)
|
||||
backend_unlowered = TestBackend()
|
||||
|
||||
model = Model()
|
||||
x = torch.randn(8, 16, dtype=torch.bfloat16)
|
||||
with (
|
||||
ops.rms_norm.set_priority([rms_provider, "native"]),
|
||||
ir.enable_torch_wrap(True),
|
||||
):
|
||||
compiled_model = torch.compile(model, backend=backend, fullgraph=True)
|
||||
compiled_unlowered_model = torch.compile(
|
||||
model, backend=backend_unlowered, fullgraph=True
|
||||
)
|
||||
output = compiled_model(x)
|
||||
output_unlowered = compiled_unlowered_model(x)
|
||||
|
||||
selected = lowering_pass.selected_impls["rms_norm"]
|
||||
assert len(selected) == 3
|
||||
assert selected["rms_norm"] == rms_provider
|
||||
assert selected["rms_norm_1"] == rms_provider
|
||||
assert selected["rms_norm_2"] == "native"
|
||||
|
||||
# Compiled function guards on global value, avoid recompilation
|
||||
with ir.enable_torch_wrap(True):
|
||||
output2 = compiled_model(x)
|
||||
|
||||
torch.testing.assert_close(output_unlowered, output)
|
||||
torch.testing.assert_close(output_unlowered, output2)
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm.config
|
||||
import vllm.ir.ops
|
||||
import vllm.plugins
|
||||
from tests.compile.backend import TestBackend
|
||||
from tests.utils import TestBlockFP8Layer, TestFP8Layer
|
||||
@@ -51,7 +52,6 @@ from vllm.utils.deep_gemm import (
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
|
||||
# Kernel and group_shape combinations: (kernel, group_shape)
|
||||
@@ -246,10 +246,8 @@ class TestModel(torch.nn.Module):
|
||||
]
|
||||
|
||||
def ops_in_model_before_partial(self):
|
||||
return (
|
||||
[RMS_OP, RMS_ADD_OP]
|
||||
if self.enable_rms_norm_custom_op
|
||||
else [torch.ops.aten.rsqrt]
|
||||
return [torch.ops.vllm_ir.rms_norm] + (
|
||||
[RMS_ADD_OP] if self.enable_rms_norm_custom_op else [torch.ops.aten.rsqrt]
|
||||
)
|
||||
|
||||
|
||||
@@ -340,7 +338,10 @@ def test_fusion_rmsnorm_quant(
|
||||
),
|
||||
)
|
||||
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
with (
|
||||
vllm.config.set_current_vllm_config(vllm_config),
|
||||
vllm_config.kernel_config.ir_op_priority.set_priority(),
|
||||
):
|
||||
# Setup device before model creation
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
@@ -370,8 +371,9 @@ def test_fusion_rmsnorm_quant(
|
||||
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
|
||||
if not enable_rms_norm_custom_op:
|
||||
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
|
||||
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
|
||||
assert n_add_nodes(backend.graph_pre_pass) == 7
|
||||
# rms_norm is IR, not included
|
||||
# 6 = 3x2 (3xRMS_ADD, 2 each)
|
||||
assert n_add_nodes(backend.graph_pre_pass) == 6
|
||||
assert n_add_nodes(backend.graph_post_pass) == 2
|
||||
|
||||
|
||||
|
||||
@@ -3,11 +3,11 @@
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch._ops import OpOverload, OpOverloadPacket
|
||||
|
||||
from tests.compile.backend import TestBackend
|
||||
from vllm.compilation.passes.fusion.matcher_utils import (
|
||||
FLASHINFER_ROTARY_OP,
|
||||
RMS_OP,
|
||||
ROTARY_OP,
|
||||
)
|
||||
from vllm.compilation.passes.fusion.qk_norm_rope_fusion import (
|
||||
@@ -100,13 +100,8 @@ class QKNormRoPETestModel(torch.nn.Module):
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
return q, k, v
|
||||
|
||||
def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
|
||||
ops = []
|
||||
if self.enable_rms_norm_custom_op:
|
||||
ops.append(RMS_OP)
|
||||
else:
|
||||
ops.append(RSQRT_OP)
|
||||
|
||||
def ops_in_model_before(self) -> list[OpOverload | OpOverloadPacket]:
|
||||
ops: list[OpOverload | OpOverloadPacket] = [torch.ops.vllm_ir.rms_norm]
|
||||
if self.enable_rope_custom_op:
|
||||
if self.rotary_emb.use_flashinfer:
|
||||
ops.append(FLASHINFER_ROTARY_OP)
|
||||
@@ -116,7 +111,7 @@ class QKNormRoPETestModel(torch.nn.Module):
|
||||
ops.append(INDEX_SELECT_OP)
|
||||
return ops
|
||||
|
||||
def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
|
||||
def ops_in_model_after(self) -> list[OpOverload | OpOverloadPacket]:
|
||||
return [FUSED_QK_ROPE_OP]
|
||||
|
||||
|
||||
@@ -166,7 +161,10 @@ def test_qk_norm_rope_fusion(
|
||||
num_heads, num_kv_heads, head_dim = 16, 4, 128
|
||||
T = 5
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
with (
|
||||
set_current_vllm_config(vllm_config),
|
||||
vllm_config.kernel_config.ir_op_priority.set_priority(),
|
||||
):
|
||||
model = QKNormRoPETestModel(
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
|
||||
@@ -1622,3 +1622,26 @@ def fresh_vllm_cache(monkeypatch, use_fresh_inductor_cache):
|
||||
def enable_pickle(monkeypatch):
|
||||
"""`LLM.apply_model` requires pickling a function."""
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def disable_log_dedup(monkeypatch):
|
||||
"""
|
||||
Disable log deduplication such that warning_once and info_once always print.
|
||||
"""
|
||||
|
||||
# Patch logger._print_warning_once to remove the lru_cache decorator
|
||||
from vllm import logger
|
||||
|
||||
original_print_warning_once = logger._print_warning_once
|
||||
original_print_info_once = logger._print_info_once
|
||||
original_print_debug_once = logger._print_debug_once
|
||||
|
||||
logger._print_warning_once = original_print_warning_once.__wrapped__
|
||||
logger._print_info_once = original_print_info_once.__wrapped__
|
||||
logger._print_debug_once = original_print_debug_once.__wrapped__
|
||||
|
||||
yield
|
||||
logger._print_warning_once = original_print_warning_once
|
||||
logger._print_info_once = original_print_info_once
|
||||
logger._print_debug_once = original_print_debug_once
|
||||
|
||||
@@ -523,3 +523,20 @@ def test_human_readable_model_len():
|
||||
for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]:
|
||||
with pytest.raises(ArgumentError):
|
||||
parser.parse_args(["--max-model-len", invalid])
|
||||
|
||||
|
||||
def test_ir_op_priority():
|
||||
from vllm.config.kernel import IrOpPriorityConfig, KernelConfig
|
||||
|
||||
ir_op_priority = IrOpPriorityConfig(rms_norm=["vllm_c"])
|
||||
cfg1 = EngineArgs(ir_op_priority=ir_op_priority).create_engine_config()
|
||||
cfg2 = EngineArgs(
|
||||
kernel_config=KernelConfig(ir_op_priority=ir_op_priority)
|
||||
).create_engine_config()
|
||||
assert cfg1.kernel_config.ir_op_priority == cfg2.kernel_config.ir_op_priority
|
||||
|
||||
with pytest.raises(ValueError, match="rms_norm"):
|
||||
_ = EngineArgs(
|
||||
ir_op_priority=ir_op_priority,
|
||||
kernel_config=KernelConfig(ir_op_priority=ir_op_priority),
|
||||
).create_engine_config()
|
||||
|
||||
497
tests/ir/test_op.py
Normal file
497
tests/ir/test_op.py
Normal file
@@ -0,0 +1,497 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib.util
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import fx
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
import vllm.ir.op
|
||||
from vllm.ir.op import RESERVED_PROVIDERS, IrOp, IrOpImpl
|
||||
|
||||
# This should not exist
|
||||
assert "_custom_add" not in IrOp.registry
|
||||
|
||||
|
||||
class CustomError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@vllm.ir.register_op
|
||||
def _custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x + y
|
||||
|
||||
|
||||
def test_registration_overloads():
|
||||
assert all(
|
||||
n not in IrOp.registry for n in ["_custom_sub", "_custom_mul", "_custom_div"]
|
||||
)
|
||||
|
||||
# Calling with decorator
|
||||
@vllm.ir.register_op()
|
||||
def _custom_sub(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x - y
|
||||
|
||||
assert _custom_sub.name == "_custom_sub"
|
||||
assert _custom_sub is IrOp.registry["_custom_sub"]
|
||||
|
||||
# Custom name
|
||||
@vllm.ir.register_op(name="_custom_mul")
|
||||
def custom_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x * y
|
||||
|
||||
assert custom_mul.name == "_custom_mul"
|
||||
assert custom_mul is IrOp.registry["_custom_mul"]
|
||||
|
||||
# Direct construction does not register directly
|
||||
def _custom_div(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x / y
|
||||
|
||||
custom_div = IrOp("_custom_div", _custom_div)
|
||||
assert custom_div.name == "_custom_div"
|
||||
assert "_custom_div" not in IrOp.registry
|
||||
|
||||
# Duplicate op registration not allowed
|
||||
with pytest.raises(AssertionError):
|
||||
|
||||
@vllm.ir.register_op
|
||||
def _custom_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x * y - 100
|
||||
|
||||
|
||||
def test_no_kw_only_args():
|
||||
# kw-only args not supported
|
||||
with pytest.raises(ValueError, match="keyword-only arguments"):
|
||||
|
||||
@vllm.ir.register_op
|
||||
def _custom_kwarg_op(
|
||||
x: torch.Tensor, y: torch.Tensor, *, kwarg: int = 0
|
||||
) -> torch.Tensor:
|
||||
return x + y + kwarg
|
||||
|
||||
assert "_custom_kwarg_op" not in IrOp.registry
|
||||
|
||||
|
||||
class TestIrOpCustomAdd:
|
||||
# Registration invariants
|
||||
def test_decorated_object(self):
|
||||
"""Make sure that referring directly to an op is correct"""
|
||||
assert isinstance(_custom_add, IrOp)
|
||||
assert "_custom_add" in IrOp.registry
|
||||
assert _custom_add is IrOp.registry["_custom_add"]
|
||||
|
||||
def test_torch_op_is_registered(self):
|
||||
assert hasattr(torch.ops.vllm_ir, "_custom_add")
|
||||
assert callable(torch.ops.vllm_ir._custom_add.default)
|
||||
|
||||
# Semantic correctness
|
||||
def test_semantics_match_native(self):
|
||||
x = torch.randn(4, 5)
|
||||
y = torch.randn(4, 5)
|
||||
|
||||
# Calls native by default
|
||||
out = _custom_add(x, y)
|
||||
ref = x + y
|
||||
|
||||
torch.testing.assert_close(out, ref)
|
||||
|
||||
# -------------------------
|
||||
# Implementation registration
|
||||
# -------------------------
|
||||
|
||||
def test_register_impl_is_non_intrusive(self):
|
||||
@_custom_add.register_impl("dummy_provider")
|
||||
def dummy_impl(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x + y + 123
|
||||
|
||||
assert "dummy_provider" in _custom_add.impls
|
||||
assert isinstance(_custom_add.impls["dummy_provider"], IrOpImpl)
|
||||
|
||||
x = torch.ones(2, 2)
|
||||
y = torch.ones(2, 2)
|
||||
|
||||
# Native semantics must still hold
|
||||
torch.testing.assert_close(_custom_add(x, y), x + y)
|
||||
|
||||
def test_schema_contains_tensor_signature(self):
|
||||
schema = _custom_add._schema_str
|
||||
|
||||
assert "Tensor" in schema
|
||||
assert "-> Tensor" in schema
|
||||
|
||||
# -------------------------
|
||||
# FX visibility
|
||||
# -------------------------
|
||||
|
||||
@pytest.mark.parametrize("enable_torch_wrap", [True, False])
|
||||
@pytest.mark.parametrize("symbolic_trace", [True, False])
|
||||
def test_trace_sees_single_custom_op(
|
||||
self, symbolic_trace: bool, enable_torch_wrap: bool
|
||||
):
|
||||
def fn(x, y):
|
||||
return _custom_add(x, y)
|
||||
|
||||
def find_fn(target: Any, gm: fx.GraphModule):
|
||||
return gm.graph.find_nodes(op="call_function", target=target)
|
||||
|
||||
with pytest.raises(CustomError), vllm.ir.enable_torch_wrap(enable_torch_wrap):
|
||||
if symbolic_trace:
|
||||
gm = torch.fx.symbolic_trace(fn)
|
||||
else:
|
||||
gm = make_fx(fn)(torch.randn(2, 2), torch.randn(2, 2))
|
||||
|
||||
x1, y1 = torch.rand(5, 4), torch.rand(5, 4)
|
||||
out_fx = gm(x1, y1)
|
||||
out_eager = fn(x1, y1)
|
||||
|
||||
# raise error to check enable_torch_wrap context restored correctly
|
||||
raise CustomError
|
||||
|
||||
# check behavior matches eager in all cases
|
||||
torch.testing.assert_close(out_fx, out_eager)
|
||||
|
||||
# check that IR nodes only appear if enable_torch_wrap=True
|
||||
ir_nodes = find_fn(torch.ops.vllm_ir._custom_add.default, gm)
|
||||
if enable_torch_wrap:
|
||||
assert len(ir_nodes) == 1, gm.code
|
||||
else:
|
||||
assert len(ir_nodes) == 0, gm.code
|
||||
|
||||
# with torch wrapping enabled (default), IR nodes appear
|
||||
if symbolic_trace:
|
||||
gm = torch.fx.symbolic_trace(fn)
|
||||
else:
|
||||
gm = make_fx(fn)(torch.randn(2, 2), torch.randn(2, 2))
|
||||
|
||||
ir_nodes = find_fn(torch.ops.vllm_ir._custom_add.default, gm)
|
||||
assert len(ir_nodes) == 1, gm.code
|
||||
|
||||
|
||||
@_custom_add.register_impl("impl_a")
|
||||
def impl_a(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x + y + 10
|
||||
|
||||
|
||||
@_custom_add.register_impl("impl_b")
|
||||
def impl_b(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x + y + 20
|
||||
|
||||
|
||||
@_custom_add.register_impl("impl_even", supports_args=lambda x, y: x.size(1) % 2 == 0)
|
||||
def impl_even(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x + y + 50
|
||||
|
||||
|
||||
class TestIrOpImplDispatch:
|
||||
def test_register_impl(self):
|
||||
assert "impl_a" in _custom_add.impls
|
||||
impl = _custom_add.impls["impl_a"]
|
||||
|
||||
assert impl is impl_a
|
||||
assert impl.op is _custom_add
|
||||
assert impl.provider == "impl_a"
|
||||
assert callable(impl.impl_fn)
|
||||
|
||||
# Test duplicate registration rejected
|
||||
with pytest.raises(AssertionError):
|
||||
|
||||
@_custom_add.register_impl("impl_a")
|
||||
def impl_a_dup(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x + y + 30
|
||||
|
||||
# Check the original impl is still intact
|
||||
assert _custom_add.impls["impl_a"] is impl_a
|
||||
|
||||
# Check support all args
|
||||
assert impl_a.supports_all_args
|
||||
assert impl_b.supports_all_args
|
||||
assert not impl_even.supports_all_args
|
||||
|
||||
def test_reserved_provider_rejected(self):
|
||||
for provider in RESERVED_PROVIDERS:
|
||||
with pytest.raises(AssertionError):
|
||||
|
||||
@_custom_add.register_impl(provider)
|
||||
def bad_impl(x, y):
|
||||
return x + y
|
||||
|
||||
def test_set_priority_scoped(self):
|
||||
assert _custom_add.get_priority() == []
|
||||
|
||||
with _custom_add.set_priority(["impl_even", "impl_b"]):
|
||||
assert _custom_add.get_priority() == ["impl_even", "impl_b"]
|
||||
|
||||
# Check nesting
|
||||
with _custom_add.set_priority(["impl_b"]):
|
||||
assert _custom_add.get_priority() == ["impl_b"]
|
||||
|
||||
# Restored
|
||||
assert _custom_add.get_priority() == ["impl_even", "impl_b"]
|
||||
|
||||
# Check that exception restores priority
|
||||
with pytest.raises(CustomError), _custom_add.set_priority(["impl_a"]):
|
||||
assert _custom_add.get_priority() == ["impl_a"]
|
||||
raise CustomError
|
||||
|
||||
# Restored again
|
||||
assert _custom_add.get_priority() == ["impl_even", "impl_b"]
|
||||
|
||||
# Restored to empty
|
||||
assert _custom_add.get_priority() == []
|
||||
|
||||
def test_dispatch_priority_order(self):
|
||||
x = torch.tensor(1, dtype=torch.int32)
|
||||
y = torch.tensor(2, dtype=torch.int32)
|
||||
|
||||
with _custom_add.set_priority(["impl_b", "impl_a"]):
|
||||
assert _custom_add.dispatch(x, y) is impl_b
|
||||
out1 = _custom_add(x, y)
|
||||
out2 = torch.ops.vllm_ir._custom_add(x, y)
|
||||
|
||||
with _custom_add.set_priority(["impl_a"]):
|
||||
assert _custom_add.dispatch(x, y) is impl_a
|
||||
out3 = _custom_add(x, y)
|
||||
out4 = torch.ops.vllm_ir._custom_add(x, y)
|
||||
|
||||
# impl_b
|
||||
assert out1.item() == 1 + 2 + 20
|
||||
assert out2.item() == 1 + 2 + 20
|
||||
# impl_a
|
||||
assert out3.item() == 1 + 2 + 10
|
||||
assert out4.item() == 1 + 2 + 10
|
||||
|
||||
def test_unsupported_impl_filtered(self):
|
||||
@_custom_add.register_impl("unsupported", supported=False)
|
||||
def impl_bad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x + y + 999
|
||||
|
||||
x = torch.tensor(1, dtype=torch.int32)
|
||||
y = torch.tensor(2, dtype=torch.int32)
|
||||
|
||||
with _custom_add.set_priority(["unsupported", "impl_a"]):
|
||||
assert _custom_add.get_priority() == ["impl_a"]
|
||||
out = _custom_add(x, y)
|
||||
|
||||
# impl_bad skipped → impl_a
|
||||
assert out.item() == 1 + 2 + 10
|
||||
|
||||
def test_supports_args_runtime_dispatch_and_warning(
|
||||
self, caplog_vllm: pytest.LogCaptureFixture
|
||||
):
|
||||
x1 = torch.ones((2, 2), dtype=torch.int32)
|
||||
y1 = torch.full((2, 2), 2, dtype=torch.int32)
|
||||
|
||||
x2 = torch.ones((2, 3), dtype=torch.int32)
|
||||
y2 = torch.full((2, 3), 2, dtype=torch.int32)
|
||||
|
||||
with (
|
||||
caplog_vllm.at_level(logging.WARNING),
|
||||
_custom_add.set_priority(["impl_even"]),
|
||||
):
|
||||
# Test the warning about native fallback is logged (before even dispatching)
|
||||
assert len(caplog_vllm.records) == 1
|
||||
message = caplog_vllm.records[0].message
|
||||
assert "_custom_add" in message
|
||||
assert "fallback to native" in message
|
||||
assert "priority" in message
|
||||
|
||||
# Check dispatching
|
||||
assert _custom_add.get_priority() == ["impl_even", "native"]
|
||||
assert _custom_add.dispatch(x1, y1) is impl_even
|
||||
assert _custom_add.dispatch(x2, y2) is _custom_add.impls["native"]
|
||||
|
||||
out1 = _custom_add(x1, y1) # size(1) == 2 → impl_even
|
||||
out2 = _custom_add(x2, y2) # size(1) == 3 → native fallback
|
||||
|
||||
# no other warnings
|
||||
assert len(caplog_vllm.records) == 1
|
||||
assert torch.all(out1 == 1 + 2 + 50)
|
||||
assert torch.all(out2 == 1 + 2)
|
||||
|
||||
def test_default_priority(
|
||||
self, caplog_vllm: pytest.LogCaptureFixture, disable_log_dedup
|
||||
):
|
||||
# Make sure logs are not deduplicated to properly test the warning
|
||||
x = torch.tensor([3], dtype=torch.int32)
|
||||
y = torch.tensor([4], dtype=torch.int32)
|
||||
|
||||
# No priority set → falls back to native
|
||||
assert _custom_add.get_priority() == []
|
||||
with caplog_vllm.at_level(logging.WARNING):
|
||||
# Native by default
|
||||
assert _custom_add.dispatch(x, y) is _custom_add.impls["native"]
|
||||
out = _custom_add(x, y)
|
||||
|
||||
# Check dispatching to native by default
|
||||
assert out.item() == 3 + 4
|
||||
|
||||
# Check warning
|
||||
assert len(caplog_vllm.records) == 2
|
||||
message = caplog_vllm.records[0].message.lower()
|
||||
assert "_custom_add" in message
|
||||
assert "priority not set" in message
|
||||
|
||||
|
||||
@vllm.ir.register_op
|
||||
def _custom_mm(
|
||||
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
tmp = x @ y
|
||||
return tmp if bias is None else tmp + bias
|
||||
|
||||
|
||||
def test_default_args():
|
||||
# Test that default args are properly applied when dispatching and calling
|
||||
@_custom_mm.register_impl("impl_mm", supports_args=lambda x, y, bias=None: True)
|
||||
def impl_mm(
|
||||
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
tmp = x @ y
|
||||
return tmp + 50 if bias is None else tmp + bias + 100
|
||||
|
||||
x1 = torch.tensor([1, 2], dtype=torch.int32)
|
||||
x2 = torch.tensor([3, 4], dtype=torch.int32)
|
||||
|
||||
# Test that supports_args receives the defaulted args
|
||||
assert impl_mm.supports_args(x1, x2)
|
||||
with _custom_mm.set_priority(["impl_mm", "native"]):
|
||||
assert _custom_mm.dispatch(x1, x2) is impl_mm
|
||||
|
||||
|
||||
def test_bad_impl_registrations():
|
||||
# Check bad schema
|
||||
with pytest.raises(ValueError, match="does not match native schema"):
|
||||
|
||||
@_custom_mm.register_impl("impl_mm_bad_schema")
|
||||
def impl_mm_bad_schema(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x @ y - 1
|
||||
|
||||
with pytest.raises(ValueError, match="does not match native schema"):
|
||||
|
||||
@_custom_mm.register_impl("impl_mm_bad_schema_2")
|
||||
def impl_mm_bad_schema_2(
|
||||
x: torch.Tensor, y: torch.Tensor, b: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return x @ y + b - 2
|
||||
|
||||
with pytest.raises(ValueError, match="does not match native schema"):
|
||||
|
||||
@_custom_mm.register_impl("impl_mm_bad_schema_3")
|
||||
def impl_mm_bad_schema_3(
|
||||
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
return x @ y + bias - 5
|
||||
|
||||
# check supports_args with incorrect params
|
||||
with pytest.raises(ValueError, match="supports_args must be a callable"):
|
||||
|
||||
@_custom_mm.register_impl("impl_mm_bad_supports_args", supports_args=True)
|
||||
def impl_mm_bad_supports_args(
|
||||
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return x @ y + 10
|
||||
|
||||
with pytest.raises(ValueError, match="number of parameters"):
|
||||
|
||||
@_custom_mm.register_impl(
|
||||
"impl_mm_bad_supports_args_2", supports_args=lambda x, y: True
|
||||
)
|
||||
def impl_mm_bad_supports_args(
|
||||
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return x @ y + 10
|
||||
|
||||
with pytest.raises(ValueError, match="keyword-only parameters"):
|
||||
|
||||
@_custom_mm.register_impl(
|
||||
"impl_mm_bad_supports_args_3", supports_args=lambda x, y, *, b: True
|
||||
)
|
||||
def impl_mm_bad_supports_args_2(
|
||||
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return x @ y + 20
|
||||
|
||||
with pytest.raises(ValueError, match="does not match native parameter"):
|
||||
|
||||
@_custom_mm.register_impl(
|
||||
"impl_mm_bad_supports_args_4", supports_args=lambda x, y, b: True
|
||||
)
|
||||
def impl_mm_bad_supports_args_4(
|
||||
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return x @ y + 30
|
||||
|
||||
with pytest.raises(ValueError, match="does not match native default"):
|
||||
|
||||
@_custom_mm.register_impl(
|
||||
"impl_mm_bad_supports_args_5", supports_args=lambda x, y, bias=1: True
|
||||
)
|
||||
def impl_mm_bad_supports_args_5(
|
||||
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return x @ y + 40
|
||||
|
||||
assert set(_custom_mm.impls.keys()) == {"impl_mm", "native"}
|
||||
|
||||
|
||||
IMPL_OOT_SRC = """
|
||||
import torch
|
||||
|
||||
@_custom_mm.register_impl("impl_mm_oot")
|
||||
def impl_mm_oot(
|
||||
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return x @ y - 99
|
||||
"""
|
||||
|
||||
|
||||
def load_custom_mm_module(file_path: Path):
|
||||
spec = importlib.util.spec_from_file_location("_custom_mm_oot", file_path)
|
||||
assert spec is not None
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
|
||||
# Inject the variable into the module's global namespace
|
||||
# This allows the @_custom_mm.register_impl decorator to work
|
||||
module._custom_mm = _custom_mm # type: ignore[attr-defined]
|
||||
|
||||
# Execute the file; this triggers the decorator
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def test_uuid_and_oot(tmp_path: Path):
|
||||
file_path = tmp_path / "_custom_mm_oot.py"
|
||||
file_path.write_text(IMPL_OOT_SRC)
|
||||
|
||||
assert "impl_mm_oot" not in _custom_mm.impls
|
||||
_ = load_custom_mm_module(file_path)
|
||||
assert "impl_mm_oot" in _custom_mm.impls
|
||||
|
||||
uuid = _custom_mm.impls["impl_mm_oot"].uuid()
|
||||
del _custom_mm.impls["impl_mm_oot"]
|
||||
|
||||
# Replace file source
|
||||
file_path.write_text(IMPL_OOT_SRC + " # added file source")
|
||||
assert "impl_mm_oot" not in _custom_mm.impls
|
||||
_ = load_custom_mm_module(file_path)
|
||||
assert "impl_mm_oot" in _custom_mm.impls
|
||||
|
||||
uuid1 = _custom_mm.impls["impl_mm_oot"].uuid()
|
||||
assert uuid1 != uuid
|
||||
del _custom_mm.impls["impl_mm_oot"]
|
||||
|
||||
# Back to original
|
||||
file_path.write_text(IMPL_OOT_SRC)
|
||||
assert "impl_mm_oot" not in _custom_mm.impls
|
||||
_ = load_custom_mm_module(file_path)
|
||||
assert "impl_mm_oot" in _custom_mm.impls
|
||||
|
||||
uuid2 = _custom_mm.impls["impl_mm_oot"].uuid()
|
||||
assert uuid2 == uuid
|
||||
assert uuid2 != uuid1
|
||||
del _custom_mm.impls["impl_mm_oot"]
|
||||
129
tests/kernels/ir/test_layernorm.py
Normal file
129
tests/kernels/ir/test_layernorm.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# This registers op implementations
|
||||
import vllm.kernels # noqa: F401
|
||||
from tests.kernels.allclose_default import get_default_rtol
|
||||
from vllm import ir
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def rms_norm_inputs(n_tokens: int, hidden_size: int, dtype: torch.dtype):
|
||||
x = torch.randn(n_tokens, hidden_size, dtype=dtype)
|
||||
weight = torch.rand(hidden_size, dtype=dtype)
|
||||
return x, weight
|
||||
|
||||
|
||||
rms_norm_native = ir.ops.rms_norm.impls["native"].impl_fn
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike() and not current_platform.is_xpu(),
|
||||
reason="Currently only kernels on CUDA, ROCm and XPU",
|
||||
)
|
||||
def test_rms_norm_registration():
|
||||
expected = {
|
||||
"native": True,
|
||||
"vllm_c": current_platform.is_cuda_alike(),
|
||||
"aiter": current_platform.is_rocm(),
|
||||
"oink": False,
|
||||
"xpu_kernels": current_platform.is_xpu(),
|
||||
}
|
||||
|
||||
actual = {
|
||||
provider: impl.supported for provider, impl in ir.ops.rms_norm.impls.items()
|
||||
}
|
||||
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
|
||||
@pytest.mark.parametrize("n_tokens", [1, 8, 17])
|
||||
@pytest.mark.parametrize("hidden_size", [16, 4096, 8192])
|
||||
@pytest.mark.parametrize("epsilon", [1e-6, 1e-5])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike() and not current_platform.is_xpu(),
|
||||
reason="Currently only kernels on CUDA, ROCm and XPU",
|
||||
)
|
||||
class TestRMSNorm:
|
||||
@classmethod
|
||||
def setup_class(cls, **kwargs):
|
||||
torch.set_default_device(current_platform.device_type)
|
||||
|
||||
def test_native_semantics(self, dtype, n_tokens, hidden_size, epsilon):
|
||||
x, weight = rms_norm_inputs(4, 8, dtype)
|
||||
out = rms_norm_native(x, weight, epsilon=epsilon)
|
||||
|
||||
# Check shape, dtype, device
|
||||
assert out.shape == x.shape
|
||||
assert out.dtype == x.dtype
|
||||
assert out.device == x.device
|
||||
|
||||
# Check the scaling property of rms norm
|
||||
out2 = rms_norm_native(x * 2.0, weight, epsilon=epsilon)
|
||||
torch.testing.assert_close(out2, out, rtol=get_default_rtol(out), atol=1e-3)
|
||||
|
||||
# Check behavior with and without weight
|
||||
weight1 = torch.ones_like(weight)
|
||||
out3 = rms_norm_native(x, weight1, epsilon=epsilon)
|
||||
out4 = rms_norm_native(x, None, epsilon=epsilon)
|
||||
torch.testing.assert_close(out3, out4)
|
||||
|
||||
@pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels"])
|
||||
def test_impls(self, dtype, n_tokens, hidden_size, epsilon, provider):
|
||||
impl = ir.ops.rms_norm.impls[provider]
|
||||
if not impl.supported:
|
||||
pytest.skip(f"{provider} impl not supported on this platform")
|
||||
|
||||
x, weight = rms_norm_inputs(n_tokens, hidden_size, dtype)
|
||||
args = (x, weight, epsilon, None)
|
||||
|
||||
assert impl.supported
|
||||
|
||||
if provider == "aiter" and dtype not in [torch.float16, torch.bfloat16]:
|
||||
assert not impl.supports_args(*args)
|
||||
return
|
||||
|
||||
assert impl.supports_args(*args)
|
||||
|
||||
out_impl = impl.impl_fn(*args)
|
||||
out_native = rms_norm_native(*args)
|
||||
|
||||
torch.testing.assert_close(
|
||||
out_impl, out_native, rtol=get_default_rtol(out_impl), atol=1e-3
|
||||
)
|
||||
|
||||
# check that dispatched call matches direct call
|
||||
with ir.ops.rms_norm.set_priority([provider, "native"]):
|
||||
out_impl2 = ir.ops.rms_norm(*args)
|
||||
|
||||
# exact match
|
||||
torch.testing.assert_close(out_impl2, out_impl, rtol=0.0, atol=0.0)
|
||||
|
||||
# none of these support variance_size override
|
||||
assert not impl.supports_args(x, weight, epsilon, 4)
|
||||
assert not impl.supports_args(x, weight, epsilon, variance_size=4)
|
||||
|
||||
# test weight=None behavior
|
||||
out_impl_no_weight = impl.impl_fn(x, None, epsilon)
|
||||
out_impl_unit_weight = impl.impl_fn(x, torch.ones_like(weight), epsilon)
|
||||
torch.testing.assert_close(
|
||||
out_impl_no_weight,
|
||||
out_impl_unit_weight,
|
||||
rtol=get_default_rtol(out_impl_no_weight),
|
||||
atol=2e-4,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels", "native"])
|
||||
def test_torch_opcheck(self, dtype, n_tokens, hidden_size, epsilon, provider):
|
||||
if not ir.ops.rms_norm.impls[provider].supported:
|
||||
pytest.skip(f"{provider} impl not supported on this platform")
|
||||
|
||||
x, weight = rms_norm_inputs(n_tokens, hidden_size, dtype)
|
||||
args = (x, weight, epsilon, None)
|
||||
|
||||
# When checking the torch op, we have to set priority and use dispatch
|
||||
with ir.ops.rms_norm.set_priority([provider, "native"]):
|
||||
torch.library.opcheck(torch.ops.vllm_ir.rms_norm, args)
|
||||
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.layernorm import (
|
||||
RMSNorm,
|
||||
dispatch_rocm_rmsnorm_func,
|
||||
fused_add_rms_norm,
|
||||
rms_norm,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -156,7 +155,7 @@ def test_topk_sigmoid_dispatch(use_rocm_aiter: bool):
|
||||
assert topk_func == vllm_topk_sigmoid
|
||||
|
||||
|
||||
@pytest.mark.parametrize("add_residual", [True, False])
|
||||
@pytest.mark.parametrize("add_residual", [False])
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("use_rocm_aiter", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
@@ -165,7 +164,7 @@ def test_topk_sigmoid_dispatch(use_rocm_aiter: bool):
|
||||
def test_rms_norm_dispatch(
|
||||
add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool
|
||||
):
|
||||
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter)
|
||||
rms_norm_func = dispatch_rocm_rmsnorm_func(dtype, use_rocm_aiter)
|
||||
|
||||
should_use_rocm_aiter = (
|
||||
current_platform.is_rocm()
|
||||
@@ -173,11 +172,7 @@ def test_rms_norm_dispatch(
|
||||
and dtype in RMS_NORM_SUPPORTED_DTYPES
|
||||
)
|
||||
|
||||
if add_residual and should_use_rocm_aiter:
|
||||
if should_use_rocm_aiter:
|
||||
assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add
|
||||
elif should_use_rocm_aiter:
|
||||
assert rms_norm_func == rocm_aiter_ops.rms_norm
|
||||
elif add_residual:
|
||||
assert rms_norm_func == fused_add_rms_norm
|
||||
else:
|
||||
assert rms_norm_func == rms_norm
|
||||
assert rms_norm_func == fused_add_rms_norm
|
||||
|
||||
@@ -6,12 +6,14 @@ import os
|
||||
from dataclasses import MISSING, Field, asdict, dataclass, field
|
||||
from unittest.mock import patch
|
||||
|
||||
import pydantic
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
KernelConfig,
|
||||
ModelConfig,
|
||||
ParallelConfig,
|
||||
PoolerConfig,
|
||||
@@ -21,6 +23,7 @@ from vllm.config import (
|
||||
update_config,
|
||||
)
|
||||
from vllm.config.compilation import CompilationMode, CUDAGraphMode
|
||||
from vllm.config.kernel import IrOpPriorityConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.config.utils import get_field
|
||||
from vllm.config.vllm import (
|
||||
@@ -1077,6 +1080,39 @@ def test_vllm_config_explicit_overrides():
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
|
||||
|
||||
|
||||
def test_fusion_pass_op_priority():
|
||||
"""This test checks that custom op enablement & IR op priority
|
||||
correctly control default fusions"""
|
||||
|
||||
# Default config, O2, rms_norm+quant fusion disabled
|
||||
cfg1 = VllmConfig()
|
||||
assert not cfg1.compilation_config.pass_config.fuse_norm_quant
|
||||
|
||||
# rms_norm manually enabled, O1, rms_norm+quant fusion enabled
|
||||
cfg2 = VllmConfig(
|
||||
optimization_level=OptimizationLevel.O1,
|
||||
compilation_config=CompilationConfig(
|
||||
custom_ops=["+rms_norm"],
|
||||
),
|
||||
)
|
||||
assert cfg2.compilation_config.pass_config.fuse_norm_quant
|
||||
|
||||
# using custom kernel for RMSNorm via IR:
|
||||
# Note that vLLM IR only supports the non-residual rms_norm for now;
|
||||
# soon this will be resolved.
|
||||
cfg3 = VllmConfig(
|
||||
kernel_config=KernelConfig(
|
||||
ir_op_priority=IrOpPriorityConfig(rms_norm=["vllm_c"])
|
||||
)
|
||||
)
|
||||
assert cfg3.compilation_config.pass_config.fuse_norm_quant
|
||||
|
||||
# block-fp8 model should enable quant_fp8 automatically
|
||||
cfg4 = VllmConfig(model_config=ModelConfig("Qwen/Qwen3-4B-FP8"))
|
||||
assert "+quant_fp8" in cfg4.compilation_config.custom_ops
|
||||
assert cfg4.compilation_config.pass_config.fuse_norm_quant
|
||||
|
||||
|
||||
def test_scheduler_config_init():
|
||||
with pytest.raises(ValidationError):
|
||||
# Positional InitVars missing
|
||||
@@ -1171,3 +1207,35 @@ def test_eagle_draft_model_config():
|
||||
assert draft_model_config.hf_text_config.model_type == "eagle"
|
||||
assert draft_model_config.architectures == ["EagleLlamaForCausalLM"]
|
||||
assert draft_model_config.architecture == "EagleLlamaForCausalLM"
|
||||
|
||||
|
||||
def test_ir_op_priority_default():
|
||||
"""Test that IR op priority defaults are set correctly."""
|
||||
from vllm.config.kernel import IrOpPriorityConfig
|
||||
|
||||
# Assert default is applied to ops
|
||||
priority_config = IrOpPriorityConfig.with_default(["vllm_c", "native"])
|
||||
assert priority_config.rms_norm == ["vllm_c", "native"]
|
||||
|
||||
# Assert single ops override the default
|
||||
assert IrOpPriorityConfig.with_default(
|
||||
["vllm_c", "native"], rms_norm=["oink", "native"]
|
||||
) == IrOpPriorityConfig(rms_norm=["oink", "native"])
|
||||
|
||||
|
||||
def test_ir_op_priority_str():
|
||||
"""Test that passing a comma-delimited string works"""
|
||||
from vllm.config.kernel import IrOpPriorityConfig
|
||||
|
||||
priority_config = IrOpPriorityConfig(rms_norm="vllm_c")
|
||||
assert priority_config.rms_norm == ["vllm_c"]
|
||||
|
||||
priority_config = IrOpPriorityConfig(rms_norm="vllm_c,native")
|
||||
assert priority_config.rms_norm == ["vllm_c", "native"]
|
||||
|
||||
priority_config = IrOpPriorityConfig(rms_norm=" native, vllm_c ")
|
||||
assert priority_config.rms_norm == ["native", "vllm_c"]
|
||||
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
# must be list of only strings
|
||||
priority_config = IrOpPriorityConfig(rms_norm=["vllm_c", 4, "native"])
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import contextlib
|
||||
from importlib.util import find_spec
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
@@ -10,6 +11,7 @@ import torch.fx as fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
import vllm.ir.ops
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
@@ -28,7 +30,7 @@ from vllm.utils.torch_utils import (
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
@@ -258,6 +260,12 @@ class BasePattern:
|
||||
self.tp = get_tp_group()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, dtype=self.dtype, device=self.device, **kwargs)
|
||||
|
||||
def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)
|
||||
|
||||
|
||||
class AllReduceRMSNormPattern(BasePattern):
|
||||
"""
|
||||
@@ -276,20 +284,17 @@ class AllReduceRMSNormPattern(BasePattern):
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input, weight = self.rmsnorm_matcher.inputs()
|
||||
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [input.to(self.dtype), weight]
|
||||
# input, weight
|
||||
return [self.empty(5, 16), self.empty(16)]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
rms = self.rmsnorm_matcher(allreduce_output, weight)
|
||||
rms = vllm.ir.ops.rms_norm(allreduce_output, weight, self.epsilon)
|
||||
|
||||
return rms, allreduce_output
|
||||
|
||||
@@ -407,15 +412,13 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.quant_dtype = torch.float8_e4m3fn
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input, weight = self.rmsnorm_matcher.inputs()
|
||||
_, scale = self.quant_matcher.inputs()
|
||||
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [input.to(self.dtype), weight, scale]
|
||||
# input, weight
|
||||
return [self.empty(5, 16), self.empty(16), scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
@@ -424,7 +427,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = tensor_model_parallel_all_reduce(input)
|
||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||
rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
return quant, all_reduce
|
||||
|
||||
@@ -553,7 +556,6 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
|
||||
@@ -575,7 +577,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
output_scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
all_reduce = tensor_model_parallel_all_reduce(input)
|
||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||
rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP4_QUANT_OP,
|
||||
input=rms,
|
||||
|
||||
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
ROTARY_OP = torch.ops._C.rotary_embedding.default
|
||||
FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default
|
||||
@@ -160,69 +159,6 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
|
||||
return result
|
||||
|
||||
|
||||
class MatcherRMSNorm(MatcherCustomOp):
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
enabled: bool | None = None,
|
||||
match_rocm_aiter: bool = False,
|
||||
) -> None:
|
||||
if enabled is None:
|
||||
enabled = RMSNorm.enabled()
|
||||
|
||||
super().__init__(enabled)
|
||||
self.epsilon = epsilon
|
||||
self._rmsnorm_op = RMS_OP
|
||||
self.match_rocm_aiter = match_rocm_aiter
|
||||
|
||||
if match_rocm_aiter:
|
||||
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op()
|
||||
|
||||
def inputs(self) -> list[torch.Tensor]:
|
||||
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
|
||||
weight = self.empty(16)
|
||||
return [input, weight]
|
||||
|
||||
def forward_rocm_aiter(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return self._rmsnorm_op(
|
||||
x=input,
|
||||
weight=weight,
|
||||
variance_epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
def forward_custom(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if self.match_rocm_aiter:
|
||||
return self.forward_rocm_aiter(input, weight)
|
||||
|
||||
result = torch.empty_like(input)
|
||||
_, result = auto_functionalized(
|
||||
self._rmsnorm_op,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return RMSNorm.forward_static(
|
||||
input, self.epsilon, input.size(-1), self.model_dtype, weight
|
||||
)
|
||||
|
||||
|
||||
class MatcherFusedAddRMSNorm(MatcherCustomOp):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -10,6 +10,7 @@ from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
import vllm.ir.ops
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
@@ -17,7 +18,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding
|
||||
from .matcher_utils import MatcherRotaryEmbedding
|
||||
from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -64,7 +65,6 @@ class QkNormRopePattern:
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.eps = eps
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(eps)
|
||||
self.is_neox = is_neox
|
||||
self.rope_flashinfer = rope_flashinfer
|
||||
self.rope_matcher = MatcherRotaryEmbedding(
|
||||
@@ -129,14 +129,14 @@ class QkNormRopePattern:
|
||||
q_by_head = q.view(
|
||||
*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
|
||||
)
|
||||
q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight)
|
||||
q_normed_by_head = vllm.ir.ops.rms_norm(q_by_head, q_weight, self.eps)
|
||||
q_flat = q_normed_by_head.view(q.shape)
|
||||
|
||||
# K path: view -> RMS -> view back to k.shape
|
||||
k_by_head = k.view(
|
||||
*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
|
||||
)
|
||||
k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight)
|
||||
k_normed_by_head = vllm.ir.ops.rms_norm(k_by_head, k_weight, self.eps)
|
||||
k_flat = k_normed_by_head.view(k.shape)
|
||||
|
||||
# RoPE: apply to flattened q/k
|
||||
|
||||
@@ -9,6 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
import vllm.ir.ops
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
@@ -30,7 +31,6 @@ from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import (
|
||||
MatcherFusedAddRMSNorm,
|
||||
MatcherQuantFP8,
|
||||
MatcherRMSNorm,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -54,7 +54,6 @@ def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")
|
||||
|
||||
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
|
||||
QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
@@ -131,11 +130,9 @@ class RMSNormQuantPattern:
|
||||
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
|
||||
self.FUSED_OP = FUSED_OPS[key]
|
||||
|
||||
self.rmsnorm_matcher = (
|
||||
MatcherRMSNorm(epsilon)
|
||||
if not key.fused_add
|
||||
else MatcherFusedAddRMSNorm(epsilon)
|
||||
)
|
||||
if key.fused_add:
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
|
||||
self.quant_matcher = MatcherQuantFP8(
|
||||
key.quant,
|
||||
has_col_major_scales=has_col_major_scales,
|
||||
@@ -161,16 +158,12 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon)
|
||||
return self.quant_matcher(result_rms, scale)[0]
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty(
|
||||
input.shape, device=input.device, dtype=self.quant_dtype
|
||||
)
|
||||
@@ -187,8 +180,8 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
return at[1]
|
||||
|
||||
inputs = [
|
||||
# input, weight
|
||||
*self.rmsnorm_matcher.inputs(),
|
||||
empty_bf16(5, 16), # input
|
||||
empty_bf16(16), # weight
|
||||
self.quant_matcher.inputs()[1], # scale
|
||||
]
|
||||
pattern(*inputs)
|
||||
@@ -391,7 +384,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon)
|
||||
result = torch.empty(
|
||||
result_rms.shape,
|
||||
device=result_rms.device,
|
||||
@@ -442,12 +435,14 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
|
||||
# result, scale
|
||||
return at[1], at[2]
|
||||
|
||||
scale = self.quant_matcher.empty_f32(1, 1)
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
self.rmsnorm_matcher.inputs() + [scale],
|
||||
[
|
||||
empty_bf16(5, 16), # input
|
||||
empty_bf16(16), # weight
|
||||
self.quant_matcher.empty_f32(1, 1), # scale
|
||||
],
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
@@ -472,7 +467,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon)
|
||||
# result, scale
|
||||
return self.quant_matcher(result_rms) # type: ignore[no-any-return]
|
||||
|
||||
@@ -502,7 +497,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
self.rmsnorm_matcher.inputs(),
|
||||
[
|
||||
empty_bf16(5, 16), # input
|
||||
empty_bf16(16), # weight
|
||||
],
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
@@ -24,7 +25,6 @@ from .act_quant_fusion import ActivationQuantPattern
|
||||
from .matcher_utils import (
|
||||
MatcherFusedAddRMSNorm,
|
||||
MatcherQuantFP8,
|
||||
MatcherRMSNorm,
|
||||
MatcherSiluAndMul,
|
||||
)
|
||||
from .rms_quant_fusion import (
|
||||
@@ -41,17 +41,23 @@ class AiterRMSNormQuantPattern:
|
||||
):
|
||||
self.epsilon = epsilon
|
||||
self.quant_dtype = key.quant.dtype
|
||||
self.device = torch.device("cuda")
|
||||
|
||||
self.rmsnorm_matcher = (
|
||||
MatcherRMSNorm(epsilon, match_rocm_aiter=True)
|
||||
if not key.fused_add
|
||||
else MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)
|
||||
if key.fused_add:
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(
|
||||
epsilon, match_rocm_aiter=True
|
||||
)
|
||||
self.quant_matcher = MatcherQuantFP8(
|
||||
key.quant,
|
||||
match_rocm_aiter=match_aiter_quant,
|
||||
)
|
||||
|
||||
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, dtype=torch.bfloat16, device=self.device, **kwargs)
|
||||
|
||||
def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)
|
||||
|
||||
|
||||
class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
|
||||
"""AITER RMSNorm + Dynamic Quantization pattern."""
|
||||
@@ -79,7 +85,7 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
result_rms = torch.ops.vllm_ir.rms_norm(input, weight, self.epsilon)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
return result, scale
|
||||
|
||||
@@ -99,7 +105,8 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
self.rmsnorm_matcher.inputs(),
|
||||
# input, weight
|
||||
[self.empty(5, 16), self.empty(16)],
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
@@ -188,7 +195,7 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
result_rms = torch.ops.vllm_ir.rms_norm(input, weight, self.epsilon)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
return result, scale
|
||||
|
||||
@@ -206,7 +213,12 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
|
||||
return at[0], at[1]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
|
||||
pattern,
|
||||
replacement,
|
||||
# input, weight
|
||||
[self.empty(5, 16), self.empty(16)],
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
import vllm.ir.ops
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
@@ -22,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..utility.noop_elimination import NoOpEliminationPass
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -122,35 +123,38 @@ class _SequenceParallelPatternHelper:
|
||||
x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
|
||||
)
|
||||
|
||||
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, dtype=self.dtype, device=self.device, **kwargs)
|
||||
|
||||
def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [input, arg3_1]
|
||||
# input, weight
|
||||
return [self.empty([1, 8, 4]), self.empty([4])]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
arg3_1: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(input)
|
||||
rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)
|
||||
rmsnorm = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
|
||||
|
||||
return rmsnorm, all_reduce
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
arg3_1: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
reduce_scatter = self._reduce_scatter(input)
|
||||
|
||||
rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
|
||||
rmsnorm = vllm.ir.ops.rms_norm(reduce_scatter, weight, self.epsilon)
|
||||
all_gather = self._all_gather(rmsnorm)
|
||||
return all_gather, reduce_scatter
|
||||
|
||||
@@ -222,14 +226,11 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
device: str | None,
|
||||
) -> None:
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
weight = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
|
||||
return [input, weight, scale]
|
||||
# input, weight, scale
|
||||
return [self.empty([1, 8, 4]), self.empty([4]), self.empty_f32([1, 1])]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
@@ -238,7 +239,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(input)
|
||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||
rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
return quant, all_reduce
|
||||
|
||||
@@ -248,7 +249,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
reduce_scatter = self._reduce_scatter(input)
|
||||
rms = self.rmsnorm_matcher(reduce_scatter, weight)
|
||||
rms = vllm.ir.ops.rms_norm(reduce_scatter, weight, self.epsilon)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
all_gather = self._all_gather(quant)
|
||||
|
||||
|
||||
0
vllm/compilation/passes/ir/__init__.py
Normal file
0
vllm/compilation/passes/ir/__init__.py
Normal file
158
vllm/compilation/passes/ir/lowering_pass.py
Normal file
158
vllm/compilation/passes/ir/lowering_pass.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
|
||||
from torch import fx
|
||||
from torch._inductor.pattern_matcher import (
|
||||
CallFunctionVarArgs,
|
||||
Match,
|
||||
PatternMatcherPass,
|
||||
register_graph_pattern,
|
||||
)
|
||||
from torch._ops import OpOverload, OpOverloadPacket
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.ir.op import IrOp
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logging_utils import lazy
|
||||
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_default_overload(op: OpOverload | OpOverloadPacket) -> OpOverload:
|
||||
if isinstance(op, OpOverloadPacket):
|
||||
return op.default
|
||||
assert isinstance(op, OpOverload), "Expected an OpOverload or OpOverloadPacket"
|
||||
return op
|
||||
|
||||
|
||||
def get_ir_op(node: fx.Node) -> IrOp | None:
|
||||
if node.op != "call_function":
|
||||
return None
|
||||
|
||||
if not isinstance(node.target, (OpOverload, OpOverloadPacket)):
|
||||
return None
|
||||
|
||||
op_overload = get_default_overload(node.target)
|
||||
if op_overload.namespace != "vllm_ir":
|
||||
return None
|
||||
|
||||
op_name = op_overload._opname
|
||||
if op_name not in IrOp.registry:
|
||||
logger.warning(
|
||||
"Unknown vLLM IR op %s, there's likely an issue with torch registration, "
|
||||
"or a torch custom op was registered in the vllm_ir namespace by mistake.",
|
||||
op_name,
|
||||
)
|
||||
return None
|
||||
|
||||
ir_op = IrOp.registry[op_name]
|
||||
return ir_op
|
||||
|
||||
|
||||
class VllmIRLoweringPass(VllmInductorPass):
|
||||
"""
|
||||
This pass lowers vLLM IR ops to their implementations the priority list.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig) -> None:
|
||||
super().__init__(vllm_config)
|
||||
self.patterns = PatternMatcherPass(self.pass_name)
|
||||
self.selected_impls: dict[str, dict[str, str]] = defaultdict(lambda: {})
|
||||
self.ops = [ir_op.torch_op for ir_op in IrOp.registry.values()]
|
||||
|
||||
# Look for any call_function node where the target is a vLLM IR op.
|
||||
# Then, lower_matched_op will select, trace, and insert the implementation.
|
||||
register_graph_pattern(
|
||||
CallFunctionVarArgs(self.ops),
|
||||
pass_dict=self.patterns,
|
||||
)(self.lower_matched_op)
|
||||
|
||||
def lower_matched_op(self, match: Match, *args, **kwargs):
|
||||
# TODO(luka) I think args and kwargs are for the match, but just use the node?
|
||||
|
||||
assert len(match.nodes) == 1, "Expected single node match"
|
||||
node = match.nodes[0]
|
||||
ir_op = get_ir_op(node)
|
||||
assert ir_op is not None, "Expected vLLM IR op"
|
||||
assert not node.kwargs # I think there should never be kwargs here
|
||||
|
||||
# Select and record the implementation, using fake args
|
||||
fake_args = fx.map_arg(node.args, lambda arg: arg.meta["val"])
|
||||
ir_op_impl = ir_op.dispatch(*fake_args)
|
||||
self.selected_impls[ir_op.name][node.name] = ir_op_impl.provider
|
||||
|
||||
# replace_by_example wants node args, not the fake tensors
|
||||
# TODO(luka): Use aot_export_module to get functionalized graph
|
||||
# TODO(luka): Cache the fx_replacement to avoid re-tracing the same impl
|
||||
|
||||
# Defaults not present on node.args but required for replacement tracing
|
||||
bound_args = ir_op._py_signature.bind(*node.args)
|
||||
bound_args.apply_defaults()
|
||||
match.replace_by_example(ir_op_impl.impl_fn, bound_args.args)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
# clear at the beginning instead of end, so that tests can inspect
|
||||
self.selected_impls.clear()
|
||||
|
||||
count = self.patterns.apply(graph)
|
||||
logger.debug("VllmIRLoweringPass lowered %d vLLM IR nodes", count)
|
||||
|
||||
# TODO write self.selected_impls to depyf/tlparse dir
|
||||
def count_items(impls: Iterable[str]) -> dict[str, int]:
|
||||
counts: dict[str, int] = defaultdict(lambda: 0)
|
||||
for impl in impls:
|
||||
counts[impl] += 1
|
||||
return counts
|
||||
|
||||
def print_count(counts: dict[str, int]) -> str:
|
||||
# e.g., "impl1*3,impl2"
|
||||
impl_count = lambda i, c: f"{i}" if c == 1 else f"{i}*{c}"
|
||||
return ",".join(impl_count(impl, count) for impl, count in counts.items())
|
||||
|
||||
logger.debug(
|
||||
"Selected implementations: %s",
|
||||
lazy(
|
||||
lambda: ", ".join(
|
||||
f"{op}={print_count(count_items(impls_by_node.values()))}"
|
||||
for op, impls_by_node in self.selected_impls.items()
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
failed_nodes: list[fx.Node] = []
|
||||
failed_ops: set[str] = set()
|
||||
# Check no vllm_ir nodes were left in the graph
|
||||
for node in graph.nodes:
|
||||
if (ir_op := get_ir_op(node)) is None:
|
||||
continue
|
||||
|
||||
failed_nodes.append(node)
|
||||
failed_ops.add(ir_op.name)
|
||||
|
||||
if failed_nodes or failed_ops:
|
||||
logger.warning("Failed to lower vLLM IR ops: %s", ",".join(failed_ops))
|
||||
logger.warning("Full node list: %s", failed_nodes)
|
||||
|
||||
def uuid(self) -> str:
|
||||
"""
|
||||
IR op priority & impl sources affect lowering pass output,
|
||||
so we include them in the cache key.
|
||||
"""
|
||||
priorities = {name: op.get_priority() for name, op in IrOp.registry.items()}
|
||||
priorities_str = ";".join(
|
||||
f"{name}={','.join(p)}" for name, p in priorities.items()
|
||||
)
|
||||
|
||||
impl_uuids_str = ";".join(
|
||||
f"{name}={
|
||||
','.join(IrOp.registry[name].impls[provider].uuid() for provider in p)
|
||||
}"
|
||||
for name, p in priorities.items()
|
||||
)
|
||||
|
||||
return f"{super().uuid()}|{priorities_str}|{impl_uuids_str}"
|
||||
@@ -14,6 +14,7 @@ from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import set_env_var
|
||||
|
||||
from .ir.lowering_pass import VllmIRLoweringPass
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
@@ -99,8 +100,17 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
|
||||
else:
|
||||
logger.debug("Skipping %s with compile range %s", pass_, compile_range)
|
||||
|
||||
# post-cleanup goes before fix_functionalization
|
||||
# because it requires a functional graph
|
||||
# perform the first post-cleanup before IR lowering to clean up fusion artifacts
|
||||
# and make sure no dead IR ops are lowered.
|
||||
self.post_cleanup(graph)
|
||||
VllmInductorPass.dump_prefix += 1
|
||||
|
||||
# lowering before cleanup so DCE can clean up lowered ops.
|
||||
# DCE handles mutating ops correctly as well.
|
||||
self.ir_lowering(graph)
|
||||
VllmInductorPass.dump_prefix += 1
|
||||
|
||||
# clean up after lowering again
|
||||
self.post_cleanup(graph)
|
||||
VllmInductorPass.dump_prefix += 1
|
||||
|
||||
@@ -152,7 +162,7 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
|
||||
self.passes += [SplitCoalescingPass(config)]
|
||||
self.passes += [QKNormRoPEFusionPass(config)]
|
||||
|
||||
# needs a functional graph
|
||||
self.ir_lowering = VllmIRLoweringPass(config)
|
||||
self.post_cleanup = PostCleanupPass(config)
|
||||
self.fix_functionalization = FixFunctionalizationPass(config)
|
||||
|
||||
@@ -171,6 +181,10 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
|
||||
state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}
|
||||
for pass_ in self.passes:
|
||||
passes.append(pass_.uuid())
|
||||
|
||||
passes.append(self.post_cleanup.uuid())
|
||||
passes.append(self.ir_lowering.uuid())
|
||||
passes.append(self.post_cleanup.uuid())
|
||||
passes.append(self.fix_functionalization.uuid())
|
||||
|
||||
# Include the compile range in the uuid to ensure that inductor
|
||||
|
||||
@@ -152,6 +152,7 @@ class VllmPatternMatcherPass(VllmInductorPass):
|
||||
f"auto_functionalized as auto_functionalized\n"
|
||||
f"from torch._inductor.pattern_matcher import *\n"
|
||||
f"vllm = torch.ops.vllm",
|
||||
"vllm_ir = torch.ops.vllm_ir",
|
||||
file=f,
|
||||
)
|
||||
|
||||
|
||||
@@ -466,6 +466,15 @@ class CompilationConfig:
|
||||
disabled when running with Inductor: mode>CompilationMode.NONE and
|
||||
backend="inductor".
|
||||
Inductor generates (fused) Triton kernels for disabled custom ops."""
|
||||
|
||||
ir_enable_torch_wrap: bool = None # type: ignore[assignment]
|
||||
"""If True, enable vllm_ir torch custom op wrapping during the forward pass.
|
||||
When False, torch custom op wrapping is disabled, allowing Dynamo to trace the
|
||||
selected implementation directly or avoiding torch custom op overhead in eager mode.
|
||||
Defaults to True when using Inductor with vllm-compile
|
||||
(backend=="inductor" and mode == VLLM_COMPILE), False otherwise.
|
||||
"""
|
||||
|
||||
splitting_ops: list[str] | None = None
|
||||
"""A list of ops to exclude from cudagraphs, used in piecewise compilation.
|
||||
|
||||
@@ -830,6 +839,7 @@ class CompilationConfig:
|
||||
"cudagraph_mode",
|
||||
"max_cudagraph_capture_size",
|
||||
"use_inductor_graph_partition",
|
||||
"ir_enable_torch_wrap",
|
||||
mode="wrap",
|
||||
)
|
||||
@classmethod
|
||||
|
||||
@@ -1,13 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal
|
||||
from dataclasses import asdict, fields
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from pydantic import field_validator
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from vllm.config.utils import config, get_hash_factors, hash_factors
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@config
|
||||
class IrOpPriorityConfig:
|
||||
"""
|
||||
Configuration for vLLM IR op priority for dispatching/lowering during the
|
||||
forward pass. Each member is a list of strings, which will be passed to
|
||||
vllm.ir.ops.<op_name>.set_priority() for the duration of the forward pass.
|
||||
A single comma-separated string is accepted as well,
|
||||
|
||||
If specified manually, platform defaults will be appended to the lists.
|
||||
See KernelConfig.set_platform_defaults().
|
||||
"""
|
||||
|
||||
rms_norm: list[str] = Field(default_factory=list)
|
||||
"""Priority list for vllm.ir.ops.rms_norm"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
Produces a hash unique to the pass configuration.
|
||||
Any new fields that affect compilation should be added to the hash.
|
||||
Any future fields that don't affect compilation should be excluded.
|
||||
|
||||
Also, manually add IR op impl UUIDs to make sure they affect the compile cache.
|
||||
"""
|
||||
factors = get_hash_factors(self, set())
|
||||
|
||||
# Implementations are hidden from Dynamo,
|
||||
# so they don't show up in the traced files list.
|
||||
from vllm.ir.op import IrOp
|
||||
|
||||
assert "_impls" not in factors
|
||||
factors["_impls"] = {
|
||||
name: {
|
||||
provider: IrOp.registry[name].impls[provider].uuid() for provider in p
|
||||
}
|
||||
for name, p in asdict(self).items()
|
||||
}
|
||||
|
||||
return hash_factors(factors)
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
@classmethod
|
||||
def _to_list_str(cls, value: str | list[str]):
|
||||
if isinstance(value, str):
|
||||
value = value.replace(" ", "").split(",")
|
||||
|
||||
assert all(isinstance(v, str) for v in value)
|
||||
return value
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_priority(self):
|
||||
"""
|
||||
Context manager to set the IR op priority for all op members.
|
||||
It also imports vllm.kernels to ensure all implementations are made available.
|
||||
"""
|
||||
import vllm.kernels # noqa: F401, registers IR op implementations
|
||||
from vllm.ir.op import IrOp
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
for field in fields(self):
|
||||
op_priority = getattr(self, field.name)
|
||||
assert op_priority is not None, (
|
||||
f"IR op priority for {field.name} must be set"
|
||||
)
|
||||
logger.debug(
|
||||
"Setting IR op priority for %s to %s", field.name, op_priority
|
||||
)
|
||||
ir_op = IrOp.registry[field.name]
|
||||
stack.enter_context(ir_op.set_priority(op_priority))
|
||||
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
def with_default(
|
||||
cls, default: list[str], /, **kwargs: list[str]
|
||||
) -> "IrOpPriorityConfig":
|
||||
"""
|
||||
A helper to create an IrOpPriorityConfig where fields not specified in kwargs
|
||||
use the given default list.
|
||||
"""
|
||||
for field in fields(cls):
|
||||
if field.name not in kwargs:
|
||||
kwargs[field.name] = list(default)
|
||||
|
||||
return cls(**kwargs)
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
MoEBackend = Literal[
|
||||
"auto",
|
||||
@@ -26,6 +119,12 @@ MoEBackend = Literal[
|
||||
class KernelConfig:
|
||||
"""Configuration for kernel selection and warmup behavior."""
|
||||
|
||||
ir_op_priority: IrOpPriorityConfig = Field(default_factory=IrOpPriorityConfig)
|
||||
"""
|
||||
vLLM IR op priority for dispatching/lowering during the forward pass.
|
||||
Platform defaults appended automatically during VllmConfig.__post_init__.
|
||||
"""
|
||||
|
||||
enable_flashinfer_autotune: bool = None # type: ignore[assignment]
|
||||
"""If True, run FlashInfer autotuning during kernel warmup."""
|
||||
|
||||
@@ -51,21 +150,17 @@ class KernelConfig:
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
Produces a hash unique to the pass configuration.
|
||||
Any new fields that affect compilation should be added to the hash.
|
||||
Any future fields that don't affect compilation should be excluded.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
ignored_factors = {
|
||||
"enable_flashinfer_autotune",
|
||||
"ir_op_priority", # handled separately below
|
||||
}
|
||||
factors = get_hash_factors(self, ignored_factors)
|
||||
factors["ir_op_priority"] = self.ir_op_priority.compute_hash()
|
||||
return hash_factors(factors)
|
||||
|
||||
@field_validator("enable_flashinfer_autotune", mode="wrap")
|
||||
@classmethod
|
||||
@@ -74,3 +169,31 @@ class KernelConfig:
|
||||
if value is None:
|
||||
return value
|
||||
return handler(value)
|
||||
|
||||
def set_platform_defaults(self, vllm_config: "VllmConfig") -> None:
|
||||
"""Set platform-specific defaults for the kernel config."""
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
platform_op_priority = current_platform.get_default_ir_op_priority(vllm_config)
|
||||
logger.debug(
|
||||
"Setting platform-specific IR op priority defaults: %s, user-defined: %s",
|
||||
platform_op_priority,
|
||||
self.ir_op_priority,
|
||||
)
|
||||
for op_name, op_priority in asdict(platform_op_priority).items():
|
||||
current_op_priority: list[str] = getattr(self.ir_op_priority, op_name)
|
||||
if current_op_priority is None:
|
||||
setattr(self.ir_op_priority, op_name, op_priority)
|
||||
else:
|
||||
# Append platform-specific priorities
|
||||
# Must be idempotent because vllm_config.set_platform_defaults() may be
|
||||
# called multiple times (due to VllmConfig.__post_init__ manual call).
|
||||
unique_op_priority = [
|
||||
op for op in op_priority if op not in current_op_priority
|
||||
]
|
||||
current_op_priority.extend(unique_op_priority)
|
||||
|
||||
logger.info(
|
||||
"Final IR op priority after setting platform defaults: %s",
|
||||
self.ir_op_priority,
|
||||
)
|
||||
|
||||
@@ -95,9 +95,11 @@ def enable_norm_fusion(cfg: "VllmConfig") -> bool:
|
||||
"""Enable if either RMS norm or quant FP8 custom op is active;
|
||||
otherwise Inductor handles fusion."""
|
||||
|
||||
return cfg.compilation_config.is_custom_op_enabled(
|
||||
"rms_norm"
|
||||
) or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
|
||||
return (
|
||||
cfg.compilation_config.is_custom_op_enabled("rms_norm")
|
||||
or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
|
||||
or cfg.kernel_config.ir_op_priority.rms_norm[0] != "native"
|
||||
)
|
||||
|
||||
|
||||
def enable_act_fusion(cfg: "VllmConfig") -> bool:
|
||||
@@ -417,6 +419,10 @@ class VllmConfig:
|
||||
vllm_factors.append(self.compilation_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.kernel_config:
|
||||
vllm_factors.append(self.kernel_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append(None)
|
||||
if self.kv_transfer_config:
|
||||
vllm_factors.append(self.kv_transfer_config.compute_hash())
|
||||
else:
|
||||
@@ -890,6 +896,13 @@ class VllmConfig:
|
||||
else:
|
||||
self.compilation_config.mode = CompilationMode.NONE
|
||||
|
||||
# By default, enable torch wrapping only when using custom Inductor lowering
|
||||
if self.compilation_config.ir_enable_torch_wrap is None:
|
||||
self.compilation_config.ir_enable_torch_wrap = (
|
||||
self.compilation_config.mode == CompilationMode.VLLM_COMPILE
|
||||
and self.compilation_config.backend == "inductor"
|
||||
)
|
||||
|
||||
if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
|
||||
if (
|
||||
self.compilation_config.backend == "inductor"
|
||||
@@ -899,6 +912,11 @@ class VllmConfig:
|
||||
else:
|
||||
self.compilation_config.custom_ops.append("all")
|
||||
|
||||
# This populates IR op priorities,
|
||||
# must happen after compilation mode and backend are decided,
|
||||
# but before fusion defaults are applied as those may depend on op priority.
|
||||
self.kernel_config.set_platform_defaults(self)
|
||||
|
||||
default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
|
||||
self._apply_optimization_level_defaults(default_config)
|
||||
if self.kernel_config.enable_flashinfer_autotune is None:
|
||||
@@ -1706,7 +1724,8 @@ class VllmConfig:
|
||||
f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
|
||||
f"enable_chunked_prefill={self.scheduler_config.enable_chunked_prefill}, " # noqa
|
||||
f"pooler_config={self.model_config.pooler_config!r}, "
|
||||
f"compilation_config={self.compilation_config!r}"
|
||||
f"compilation_config={self.compilation_config!r}, "
|
||||
f"kernel_config={self.kernel_config!r}"
|
||||
)
|
||||
|
||||
def validate_block_size(self) -> None:
|
||||
|
||||
@@ -8,7 +8,7 @@ import functools
|
||||
import json
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from dataclasses import MISSING, dataclass, fields, is_dataclass
|
||||
from dataclasses import MISSING, asdict, dataclass, fields, is_dataclass
|
||||
from itertools import permutations
|
||||
from types import UnionType
|
||||
from typing import (
|
||||
@@ -70,7 +70,7 @@ from vllm.config.cache import (
|
||||
PrefixCachingHashAlgo,
|
||||
)
|
||||
from vllm.config.device import Device
|
||||
from vllm.config.kernel import MoEBackend
|
||||
from vllm.config.kernel import IrOpPriorityConfig, MoEBackend
|
||||
from vllm.config.lora import MaxLoRARanks
|
||||
from vllm.config.model import (
|
||||
ConvertOption,
|
||||
@@ -401,6 +401,7 @@ class EngineArgs:
|
||||
max_cudagraph_capture_size: int | None = get_field(
|
||||
CompilationConfig, "max_cudagraph_capture_size"
|
||||
)
|
||||
ir_op_priority: IrOpPriorityConfig = get_field(KernelConfig, "ir_op_priority")
|
||||
# Note: Specifying a custom executor backend by passing a class
|
||||
# is intended for expert use only. The API may change without
|
||||
# notice.
|
||||
@@ -657,6 +658,9 @@ class EngineArgs:
|
||||
self.weight_transfer_config = WeightTransferConfig(
|
||||
**self.weight_transfer_config
|
||||
)
|
||||
if isinstance(self.ir_op_priority, dict):
|
||||
self.ir_op_priority = IrOpPriorityConfig(**self.ir_op_priority)
|
||||
|
||||
# Setup plugins
|
||||
from vllm.plugins import load_general_plugins
|
||||
|
||||
@@ -1293,6 +1297,7 @@ class EngineArgs:
|
||||
title="KernelConfig",
|
||||
description=KernelConfig.__doc__,
|
||||
)
|
||||
kernel_group.add_argument("--ir-op-priority", **kernel_kwargs["ir_op_priority"])
|
||||
kernel_group.add_argument(
|
||||
"--enable-flashinfer-autotune",
|
||||
**kernel_kwargs["enable_flashinfer_autotune"],
|
||||
@@ -1917,6 +1922,22 @@ class EngineArgs:
|
||||
if self.moe_backend != "auto":
|
||||
kernel_config.moe_backend = self.moe_backend
|
||||
|
||||
# Transfer top-level ir_op_priority into KernelConfig.ir_op_priority
|
||||
for op_name, op_priority in asdict(self.ir_op_priority).items():
|
||||
# Empty means unset
|
||||
if not op_priority:
|
||||
continue
|
||||
|
||||
# Priority cannot be set 2x for the same op
|
||||
if getattr(kernel_config.ir_op_priority, op_name):
|
||||
raise ValueError(
|
||||
f"Op priority for {op_name} specified via both ir_op_priority "
|
||||
f"and KernelConfig.ir_op_priority, only one allowed at a time."
|
||||
)
|
||||
|
||||
# Set the attribute
|
||||
setattr(kernel_config.ir_op_priority, op_name, op_priority)
|
||||
|
||||
load_config = self.create_load_config()
|
||||
|
||||
# Pass reasoning_parser into StructuredOutputsConfig
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.ir
|
||||
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
@@ -378,7 +379,13 @@ def set_forward_context(
|
||||
)
|
||||
|
||||
try:
|
||||
with override_forward_context(forward_context):
|
||||
with (
|
||||
override_forward_context(forward_context),
|
||||
vllm_config.kernel_config.ir_op_priority.set_priority(),
|
||||
vllm.ir.enable_torch_wrap(
|
||||
vllm_config.compilation_config.ir_enable_torch_wrap
|
||||
),
|
||||
):
|
||||
yield
|
||||
finally:
|
||||
global last_logging_time, batchsize_logging_interval
|
||||
|
||||
6
vllm/ir/__init__.py
Normal file
6
vllm/ir/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from . import ops
|
||||
from .op import enable_torch_wrap, register_op
|
||||
|
||||
__all__ = ["enable_torch_wrap", "register_op", "ops"]
|
||||
414
vllm/ir/op.py
Normal file
414
vllm/ir/op.py
Normal file
@@ -0,0 +1,414 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, overload
|
||||
|
||||
import torch
|
||||
from torch.library import Library, infer_schema
|
||||
|
||||
from vllm.ir.util import hash_source, weak_cache
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logging_utils import lazy, tensors_str_no_data
|
||||
|
||||
vllm_ir_lib = Library("vllm_ir", "FRAGMENT")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
RESERVED_PROVIDERS = ["native", "unfused"]
|
||||
"""Providers that are reserved and cannot be used for custom implementations."""
|
||||
|
||||
_ENABLE_TORCH_WRAP: bool = True
|
||||
"""Global override flag to control torch op layer wrapping."""
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def enable_torch_wrap(enable: bool = True):
|
||||
"""
|
||||
Context manager to enable/disable torch custom op wrapping for vLLM IR ops.
|
||||
When torch wrapping is disabled, the torch custom op layer is skipped
|
||||
and IR ops dispatch directly to the implementation.
|
||||
Helpful for avoiding torch dispatch overhead in eager mode
|
||||
and avoiding the need for lowering for platforms not using Inductor.
|
||||
"""
|
||||
|
||||
global _ENABLE_TORCH_WRAP
|
||||
old = _ENABLE_TORCH_WRAP
|
||||
try:
|
||||
_ENABLE_TORCH_WRAP = enable
|
||||
yield
|
||||
finally:
|
||||
_ENABLE_TORCH_WRAP = old
|
||||
|
||||
|
||||
# 0-param decorator overload
|
||||
@overload
|
||||
def register_op(f: Callable[..., Any]) -> "IrOp": ...
|
||||
|
||||
|
||||
# parametrized decorator overload
|
||||
@overload
|
||||
def register_op(
|
||||
*,
|
||||
name: str | None = None,
|
||||
) -> Callable[[Callable[..., Any]], "IrOp"]: ...
|
||||
|
||||
|
||||
def register_op(
|
||||
f: Callable | None = None,
|
||||
*,
|
||||
name: str | None = None,
|
||||
) -> "IrOp | Callable[[Callable], IrOp]":
|
||||
"""
|
||||
Register a new vLLM IR op.
|
||||
|
||||
:param f: the native implementation of the op
|
||||
:param name: the name of the op, defaults to the function name
|
||||
:return: the IrOp object if f is provided, otherwise a decorator
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
@vllm.ir.register_op
|
||||
def my_op(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x + y
|
||||
|
||||
|
||||
@vllm.ir.register_op(name="custom_mul")
|
||||
def multiply(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x * y"""
|
||||
|
||||
def decorator(_f: Callable):
|
||||
op_name: str = _f.__name__ if name is None else name
|
||||
assert op_name not in IrOp.registry
|
||||
op = IrOp(op_name, _f)
|
||||
IrOp.registry[op_name] = op
|
||||
return op
|
||||
|
||||
if f is not None:
|
||||
return decorator(f)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class IrOp:
|
||||
registry: ClassVar[dict[str, "IrOp"]] = {}
|
||||
|
||||
name: str
|
||||
impls: dict[str, "IrOpImpl"]
|
||||
|
||||
def __init__(self, name: str, native_impl: Callable):
|
||||
self._py_signature = inspect.signature(native_impl)
|
||||
if any(
|
||||
p.kind == inspect.Parameter.KEYWORD_ONLY
|
||||
for p in self._py_signature.parameters.values()
|
||||
):
|
||||
raise ValueError(
|
||||
f"Op {name} has keyword-only arguments which are not currently "
|
||||
f"supported. That's because kwargs are not allowed during lowering."
|
||||
)
|
||||
|
||||
self.name = name
|
||||
self.impls: dict[str, IrOpImpl] = {}
|
||||
self._priority_impls: list[IrOpImpl] = []
|
||||
self._schema_str = infer_schema(native_impl, mutates_args=[])
|
||||
|
||||
# native implementation
|
||||
self.impls["native"] = IrOpImpl(
|
||||
self, "native", native_impl, supported=True, supports_args=None
|
||||
)
|
||||
|
||||
# By default, fake routes directly to native,
|
||||
# can be overridden by register_fake
|
||||
self._fake_fn = native_impl
|
||||
|
||||
# torch registration
|
||||
vllm_ir_lib.define(self.name + self._schema_str)
|
||||
# CompositeExplicitAutograd is not decomposed
|
||||
# by ATen IR normalization in AOTAutograd
|
||||
vllm_ir_lib.impl(
|
||||
self.name, self._inner_call, dispatch_key="CompositeExplicitAutograd"
|
||||
)
|
||||
vllm_ir_lib._register_fake(self.name, self._fake_call)
|
||||
assert hasattr(torch.ops.vllm_ir, name)
|
||||
self.torch_op: torch._ops.OpOverload = getattr(torch.ops.vllm_ir, name).default
|
||||
|
||||
def register_fake(self, fn: Callable) -> Callable:
|
||||
"""
|
||||
Register a fake impl for the torch custom op. If this method is not called,
|
||||
the native implementation is used directly for the fake implementation.
|
||||
"""
|
||||
self._fake_fn = fn
|
||||
return fn
|
||||
|
||||
def _fake_call(self, *args, **kwargs) -> Any:
|
||||
"""
|
||||
Call to the fake implementation of the op. We use indirection because we want
|
||||
users to be able to register fake later but also want it to fall back to native
|
||||
directly by default, instead of going through the dispatching mechanism.
|
||||
"""
|
||||
return self._fake_fn(*args, **kwargs)
|
||||
|
||||
def register_impl(
|
||||
self,
|
||||
provider: str,
|
||||
*,
|
||||
supported: bool = True,
|
||||
supports_args: Callable[..., bool] | None = None,
|
||||
):
|
||||
"""
|
||||
Register an implementation for this custom op.
|
||||
:param provider: The name of the provider, must be unique.
|
||||
:param supported: Static support check, use this to check platform support.
|
||||
:param supports_args: Dynamic arg support check, used for types and shapes.
|
||||
:return: A decorator that registers the implementation.
|
||||
|
||||
The decorated function must have the same semantics and signature as
|
||||
the native implementation.
|
||||
|
||||
The provider name must be unique and not one of the RESERVED_PROVIDERS.
|
||||
The supported and supports_args parameters should not be used to implement
|
||||
custom enablement logic based on global state (e.g. environment variables).
|
||||
Instead, supported param should only be used to check for platform support
|
||||
(e.g. whether a specific hardware or library is available).
|
||||
supports_args should be used to check whether the provided arguments are
|
||||
compatible with the implementation.
|
||||
For custom enablement logic, set op impl priority.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@my_op.register_impl("my_provider", supported=torch.cuda.is_available())
|
||||
def my_provider_impl(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
|
||||
```
|
||||
|
||||
"""
|
||||
assert provider not in RESERVED_PROVIDERS, (
|
||||
f"Provider name {provider} is reserved."
|
||||
)
|
||||
|
||||
def _register_impl(f: Callable):
|
||||
impl = IrOpImpl(self, provider, f, supported, supports_args)
|
||||
self.impls[provider] = impl
|
||||
|
||||
if self.get_priority():
|
||||
logger.warning(
|
||||
"Warning: registering new impl %s for op %s while priority is set.",
|
||||
provider,
|
||||
self.name,
|
||||
)
|
||||
|
||||
return impl
|
||||
|
||||
return _register_impl
|
||||
|
||||
def _inner_call(self, *args, **kwargs) -> Any:
|
||||
"""
|
||||
Eager call to torch op lands here. When torch wrapping is disabled,
|
||||
__call__ routes straight here instead of going through torch op dispatching.
|
||||
"""
|
||||
impl = self.dispatch(*args, **kwargs)
|
||||
return impl.impl_fn(*args, **kwargs)
|
||||
|
||||
def apply_arg_defaults(self, args) -> tuple:
|
||||
"""
|
||||
Return args with default values applied.
|
||||
Defaults are taken from the native implementation signature.
|
||||
|
||||
SHOULD NOT BE USED IN THE DISPATCH PATH (SLOW).
|
||||
Only for Inductor lowering.
|
||||
"""
|
||||
bound_args = self._py_signature.bind(*args)
|
||||
bound_args.apply_defaults()
|
||||
return bound_args.args
|
||||
|
||||
    def dispatch(self, *args, **kwargs) -> "IrOpImpl":
        """
        Dispatch to the appropriate implementation based on current priority
        and argument support checks. Returns the selected IrOpImpl.

        THIS FUNCTION IS ON THE HOT PATH (OP DISPATCH), MUST BE FAST.
        """
        # No priority configured: fall back to the native implementation.
        if not self._priority_impls:
            if not torch.compiler.is_compiling():
                # Logging not compatible with Dynamo tracing
                # (this code is exposed when torch wrapping is disabled)
                logger.warning_once(
                    "Priority not set for op %s, using native implementation.",
                    self.name,
                )
            return self.impls["native"]

        # First impl in priority order that accepts these args wins.
        for impl in self._priority_impls:
            if not impl.supported:
                raise ValueError(
                    f"Implementation {impl.provider} for op {self.name} not supported. "
                    f"All implementations in priority list must be supported."
                )
            if impl.supports_args(*args, **kwargs):
                return impl

            if not torch.compiler.is_compiling():
                # lazy() defers tensor stringification so it only happens
                # when debug logging is actually enabled.
                logger.debug(
                    "Skipping provider %s because it does not support "
                    "%s with args=%s kwargs=%s",
                    impl.provider,
                    self.name,
                    lazy(lambda: tensors_str_no_data(args)),
                    lazy(lambda: tensors_str_no_data(kwargs)),
                )

        # set_priority() guarantees the list ends with an impl that supports
        # all args, so falling through the loop indicates an internal bug.
        raise RuntimeError(
            "Priority set incorrectly: the last implementation must "
            "support all args (can be native). This is likely an internal bug"
        )
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
if not _ENABLE_TORCH_WRAP:
|
||||
return self._inner_call(*args, **kwargs)
|
||||
|
||||
return self.torch_op(*args, **kwargs)
|
||||
|
||||
def get_priority(self) -> list[str]:
|
||||
"""Get the current dispatch priority for implementations for this op."""
|
||||
return [p.provider for p in self._priority_impls]
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_priority(self, priority: list[str]):
|
||||
"""
|
||||
Context manager to set the dispatch priority for implementations for this op.
|
||||
"""
|
||||
assert all(p in self.impls for p in priority), (
|
||||
"All providers in priority must be registered implementations."
|
||||
)
|
||||
|
||||
def filter_priority_impls(p_list: list[str]) -> list[IrOpImpl]:
|
||||
filtered_impls = []
|
||||
for p in p_list:
|
||||
impl = self.impls[p]
|
||||
if not impl.supported:
|
||||
# Skip unsupported implementations
|
||||
continue
|
||||
|
||||
filtered_impls.append(impl)
|
||||
|
||||
# If all args are supported, skip other implementations
|
||||
if impl.supports_all_args:
|
||||
return filtered_impls
|
||||
|
||||
logger.warning_once(
|
||||
"Op %s: No implementation in priority list supports all args, "
|
||||
"execution fallback to native is possible. To silence this warning, "
|
||||
"explicitly add 'native' to the end of the priority list",
|
||||
self.name,
|
||||
)
|
||||
filtered_impls.append(self.impls["native"])
|
||||
return filtered_impls
|
||||
|
||||
# Temporarily set priority
|
||||
old_priority_impls = self._priority_impls
|
||||
try:
|
||||
self._priority_impls = filter_priority_impls(priority)
|
||||
yield
|
||||
finally:
|
||||
self._priority_impls = old_priority_impls
|
||||
|
||||
def supported_providers(self) -> list[str]:
|
||||
return [p.provider for p in self.impls.values() if p.supported]
|
||||
|
||||
|
||||
class IrOpImpl:
    """One provider's implementation of an IrOp.

    Construction validates that the implementation's schema matches the
    native op's schema exactly, and that any ``supports_args`` predicate
    can be called with the same calling convention as the op itself (which
    enables fast positional dispatch on the hot path).
    """

    def __init__(
        self,
        op: IrOp,
        provider: str,
        impl_fn: Callable,
        supported: bool,
        supports_args: Callable[..., bool] | None,
    ):
        """Validate and record an op implementation.

        :param op: The IrOp this implementation belongs to.
        :param provider: Unique provider name (must not already be registered).
        :param impl_fn: Implementation with the native op's exact schema.
        :param supported: Static platform-support flag.
        :param supports_args: Optional per-call argument-compatibility check,
            or None if the impl supports all arguments unconditionally.
        :raises ValueError: If the schema or the supports_args signature does
            not match the native implementation.
        """
        assert provider not in op.impls, (
            f"Implementation for provider {provider} already registered."
        )
        # Native also uses this path, so we allow it here.
        assert provider == "native" or provider not in RESERVED_PROVIDERS

        # Enforce the exact same schema as the native implementation.
        # This takes care of names, types, and defaults.
        schema = infer_schema(impl_fn, mutates_args=[])
        if schema != op._schema_str:
            raise ValueError(
                f"Implementation for provider {provider} has schema '{schema}' which "
                f"does not match native schema '{op._schema_str}' for op {op.name}."
            )

        if supports_args is not None:
            if not callable(supports_args):
                raise ValueError(
                    f"supports_args for provider {provider} must be a callable"
                )

            # We also manually validate the supports_args signature.
            # Matching signatures allow faster dispatch on the hotpath.

            # Check that supports_args does not have keyword-only parameters
            supports_args_signature = inspect.signature(supports_args)
            params = supports_args_signature.parameters
            if any(p.kind == inspect.Parameter.KEYWORD_ONLY for p in params.values()):
                raise ValueError(
                    f"supports_args for provider {provider} "
                    f"cannot have keyword-only parameters"
                )

            # Check that supports_args has the same total number of parameters
            op_params = op._py_signature.parameters
            if len(params) != len(op_params):
                raise ValueError(
                    f"supports_args for provider {provider} must have the same number "
                    f"of parameters ({len(params)}) as the native implementation "
                    f"({len(op_params)})"
                )

            # Check that names and defaults match for supports_args
            for p, op_p in zip(params.values(), op_params.values()):
                if p.name != op_p.name:
                    raise ValueError(
                        f"supports_args for provider {provider} has parameter "
                        f"'{p.name}' which does not match native parameter "
                        f"'{op_p.name}'"
                    )
                if p.default != op_p.default:
                    # Fixed: the original message had a stray trailing quote
                    # after the native default value.
                    raise ValueError(
                        f"supports_args for provider {provider} has parameter "
                        f"'{p.name}' with default {p.default} which does not match "
                        f"native default {op_p.default}"
                    )

        # Owning op.
        self.op = op
        # Unique provider name.
        self.provider = provider
        # The callable implementing the op.
        self.impl_fn = impl_fn
        # Static platform-support flag.
        self.supported = supported
        # Per-call arg check; None means all args are accepted.
        self._supports_args = supports_args

    @property
    def supports_all_args(self) -> bool:
        """Check if this implementation supports all args unconditionally."""
        return self._supports_args is None

    def supports_args(self, *args, **kwargs) -> bool:
        """True if this implementation can handle the given arguments."""
        if self._supports_args is None:
            return True

        return self._supports_args(*args, **kwargs)

    @weak_cache
    def uuid(self):
        """
        Compile-time hash to uniquely determine whether the implementation has changed.
        Used by vllm-compile hash mechanism and torch.compile lowering pass uuid to
        control the vLLM compile cache and AOTAutograd/Inductor caches respectively.

        Source file contents do not change so we cache uuid.
        TODO(luka): Cache the file hash as multiple impls are likely in the same file.
        """
        sources = [Path(inspect.getfile(self.impl_fn))]
        return hash_source(*sources)
|
||||
5
vllm/ir/ops/__init__.py
Normal file
5
vllm/ir/ops/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from .layernorm import rms_norm
|
||||
|
||||
__all__ = ["rms_norm"]
|
||||
22
vllm/ir/ops/layernorm.py
Normal file
22
vllm/ir/ops/layernorm.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from ..op import register_op
|
||||
|
||||
|
||||
@register_op
def rms_norm(
    x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
    """Weighted root-mean-square layer normalization.

    Statistics are computed in float32 for numerical stability. When
    variance_size is given, only the first variance_size elements of the
    last dim contribute to the variance. The optional weight is applied
    after casting back to the input dtype.
    """
    input_dtype = x.dtype
    x_fp32 = x.to(torch.float32)
    var_slice = x_fp32 if variance_size is None else x_fp32[..., :variance_size]
    inv_rms = torch.rsqrt(var_slice.pow(2).mean(dim=-1, keepdim=True) + epsilon)
    normed = (x_fp32 * inv_rms).to(input_dtype)
    return normed if weight is None else normed * weight
|
||||
61
vllm/ir/util.py
Normal file
61
vllm/ir/util.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import types
|
||||
import weakref
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
def hash_source(*srcs: str | Any) -> str:
|
||||
"""
|
||||
Utility method to hash the sources of functions or objects.
|
||||
:param srcs: strings or objects to add to the hash.
|
||||
Objects and functions have their source inspected.
|
||||
:return:
|
||||
"""
|
||||
hasher = hashlib.sha256()
|
||||
for src in srcs:
|
||||
if src is None:
|
||||
src_str = "None"
|
||||
elif isinstance(src, str):
|
||||
src_str = src
|
||||
elif isinstance(src, Path):
|
||||
src_str = src.read_text()
|
||||
elif isinstance(src, (types.FunctionType, type)):
|
||||
src_str = inspect.getsource(src)
|
||||
else:
|
||||
# object instance
|
||||
src_str = inspect.getsource(src.__class__)
|
||||
hasher.update(src_str.encode("utf-8"))
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
def weak_lru_cache(maxsize: int | None = 128, typed: bool = False):
|
||||
"""
|
||||
LRU Cache decorator that keeps a weak reference to 'self'.
|
||||
This avoids memory leakage, which happens when functools.lru_cache
|
||||
stores a reference to self in the global cache.
|
||||
|
||||
Taken from: https://stackoverflow.com/a/68052994/5082708
|
||||
"""
|
||||
|
||||
def wrapper(func):
|
||||
@functools.lru_cache(maxsize, typed)
|
||||
def _func(_self, *args, **kwargs):
|
||||
return func(_self(), *args, **kwargs)
|
||||
|
||||
@functools.wraps(func)
|
||||
def inner(self, *args, **kwargs):
|
||||
return _func(weakref.ref(self), *args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def weak_cache(user_function, /):
    """Unbounded weak_lru_cache — the weakref analogue of functools.cache."""
    return weak_lru_cache(None)(user_function)
|
||||
@@ -1,3 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Kernel implementations for vLLM."""
|
||||
|
||||
from . import aiter_ops, oink_ops, vllm_c, xpu_ops
|
||||
|
||||
__all__ = ["vllm_c", "aiter_ops", "oink_ops", "xpu_ops"]
|
||||
|
||||
79
vllm/kernels/aiter_ops.py
Normal file
79
vllm/kernels/aiter_ops.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.library import Library
|
||||
|
||||
from vllm import ir
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
current_platform.import_kernels()
|
||||
|
||||
|
||||
def is_aiter_found() -> bool:
    """Whether the ``aiter`` package is importable in this environment."""
    import importlib.util

    return importlib.util.find_spec("aiter") is not None
|
||||
|
||||
|
||||
aiter_lib = Library("vllm_aiter", "FRAGMENT")
|
||||
"""
|
||||
This library holds torch custom ops for wrapped AITER ops.
|
||||
Many AITER ops want to remain invisible to torch.compile even after lowering.
|
||||
They are thus wrapped into torch custom ops inside the IR op implementations.
|
||||
"""
|
||||
|
||||
direct_register_aiter_op = functools.partial(
|
||||
direct_register_custom_op, target_lib=aiter_lib
|
||||
)
|
||||
"""Syntactic sugar for registering AITER custom ops."""
|
||||
|
||||
AITER_SUPPORTED = is_aiter_found()
|
||||
"""Most kernels in this file are supported if AITER is installed."""
|
||||
|
||||
def rms_no_var_16bit_only(x, weight, epsilon, variance_size=None) -> bool:
    """AITER rms_norm only supports float16 and bfloat16 acts and no var_size override."""
    # PEP 8 (E731): a def instead of a named lambda; same signature so the
    # IrOpImpl supports_args signature validation still passes.
    return variance_size is None and x.dtype in (torch.float16, torch.bfloat16)
|
||||
|
||||
|
||||
@ir.ops.rms_norm.register_impl(
    "aiter", supports_args=rms_no_var_16bit_only, supported=AITER_SUPPORTED
)
def rms_norm(
    x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
    """AITER-backed rms_norm.

    Dispatch (via rms_no_var_16bit_only) guarantees a 16-bit dtype and no
    variance_size override; the asserts restate those invariants.
    """
    assert variance_size is None
    assert x.dtype in (torch.float16, torch.bfloat16)
    if weight is None:
        # The AITER kernel needs an explicit weight; substitute ones.
        weight = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
    return torch.ops.vllm_aiter.rms_norm(x, weight, epsilon)
|
||||
|
||||
|
||||
def _rms_norm_impl(x: Tensor, weight: Tensor, variance_epsilon: float) -> Tensor:
    """Call AITER rms_norm, flattening >2D inputs to 2D and restoring shape."""
    from aiter import rms_norm

    if x.dim() <= 2:
        return rms_norm(x, weight, variance_epsilon)

    original_shape = x.shape
    flat = x.reshape(-1, original_shape[-1])
    return rms_norm(flat, weight, variance_epsilon).reshape(original_shape)
|
||||
|
||||
|
||||
def _rms_norm_fake(x: Tensor, weight: Tensor, variance_epsilon: float) -> Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
direct_register_aiter_op(
|
||||
op_name="rms_norm", op_func=_rms_norm_impl, fake_impl=_rms_norm_fake
|
||||
)
|
||||
77
vllm/kernels/oink_ops.py
Normal file
77
vllm/kernels/oink_ops.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm import ir
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
OINK_AVAILABLE = current_platform.has_device_capability(100) and hasattr(
|
||||
torch.ops, "oink"
|
||||
)
|
||||
|
||||
|
||||
def has_oink_op(name: str) -> bool:
    """Check if a specific oink op is registered."""
    if not OINK_AVAILABLE:
        return False
    return hasattr(torch.ops.oink, name)
|
||||
|
||||
|
||||
def _can_view_as_2d(x: torch.Tensor) -> bool:
|
||||
"""Return True if x.view(-1, x.shape[-1]) is viewable (no copy)."""
|
||||
if x.dim() < 2:
|
||||
return False
|
||||
if x.dim() == 2:
|
||||
return True
|
||||
# For a view(-1, N) to be valid, all leading dims must be contiguous with
|
||||
# respect to each other (size-1 dims are ignored).
|
||||
for dim in range(x.dim() - 1):
|
||||
# Strides for size-1 dims are irrelevant and can be arbitrary.
|
||||
if x.size(dim + 1) != 1 and x.stride(dim) != x.stride(dim + 1) * x.size(
|
||||
dim + 1
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool:
|
||||
"""Return True if x_2d meets Oink's pointer-path stride constraints."""
|
||||
if x_2d.dim() != 2:
|
||||
return False
|
||||
if x_2d.stride(1) != 1:
|
||||
return False
|
||||
# Match Oink's vectorization constraint: stride(0) divisible by 256b.
|
||||
if x_2d.dtype in (torch.float16, torch.bfloat16):
|
||||
divby = 16
|
||||
elif x_2d.dtype == torch.float32:
|
||||
divby = 8
|
||||
else:
|
||||
return False
|
||||
return (x_2d.stride(0) % divby) == 0
|
||||
|
||||
|
||||
def oink_rms_supported(x, weight, epsilon, variance_size=None) -> bool:
    """
    Oink rms only supports 2d-like inputs with contiguous weight
    and no variance_size override.
    """
    # PEP 8 (E731): a def instead of a named lambda; same signature so the
    # IrOpImpl supports_args signature validation still passes.
    return (
        variance_size is None
        and weight is not None
        and x.dim() >= 2
        and x.dtype == weight.dtype
        and weight.is_contiguous()
        and _can_view_as_2d(x)
        and _is_oink_stride_compatible_2d(x.view(-1, x.shape[-1]))
    )
|
||||
|
||||
|
||||
@ir.ops.rms_norm.register_impl(
    "oink", supports_args=oink_rms_supported, supported=has_oink_op("rmsnorm")
)
def rms_norm(
    x: torch.Tensor,
    weight: torch.Tensor | None,
    epsilon: float,
    variance_size: int | None = None,
) -> torch.Tensor:
    """Oink rmsnorm over a zero-copy 2D view of x.

    Dispatch (via oink_rms_supported) guarantees the input is viewable as
    2D, the weight is present/contiguous, and variance_size is unset.
    """
    assert variance_size is None
    flat = x.view(-1, x.shape[-1])
    out_2d = torch.ops.oink.rmsnorm(flat, weight, epsilon)
    return out_2d.view_as(x)
|
||||
30
vllm/kernels/vllm_c.py
Normal file
30
vllm/kernels/vllm_c.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from vllm import ir
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
current_platform.import_kernels()
|
||||
|
||||
CUDA_ALIKE = current_platform.is_cuda_alike()
|
||||
"""Most kernels in this file are supported on all CUDA-alike platforms."""
|
||||
|
||||
def rms_no_var_size(x, weight, epsilon, variance_size=None) -> bool:
    """vLLM kernel does not support variance_size parameter."""
    # PEP 8 (E731): a def instead of a named lambda; same signature so the
    # IrOpImpl supports_args signature validation still passes.
    return variance_size is None
|
||||
|
||||
|
||||
@ir.ops.rms_norm.register_impl(
    "vllm_c", supports_args=rms_no_var_size, supported=CUDA_ALIKE
)
def rms_norm(
    x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
    """rms_norm via the vLLM C extension kernel (torch.ops._C.rms_norm)."""
    assert variance_size is None
    if weight is None:
        # Kernel requires weight tensor, pass ones
        weight = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
    result = torch.empty(x.shape, device=x.device, dtype=x.dtype)
    torch.ops._C.rms_norm(result, x, weight, epsilon)
    return result
|
||||
36
vllm/kernels/xpu_ops.py
Normal file
36
vllm/kernels/xpu_ops.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from vllm import ir
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
current_platform.import_kernels()
|
||||
|
||||
|
||||
def is_xpu_kernels_found() -> bool:
    """Whether the ``vllm_xpu_kernels`` package is importable."""
    import importlib.util

    return importlib.util.find_spec("vllm_xpu_kernels") is not None
|
||||
|
||||
|
||||
XPU_KERNELS_SUPPORTED = is_xpu_kernels_found()
|
||||
"""Kernels in this file are supported if vLLM XPU kernels are installed."""
|
||||
|
||||
def rms_no_var(x, weight, epsilon, variance_size=None) -> bool:
    """XPU rms_norm path does not support the variance_size override."""
    # PEP 8 (E731): a def instead of a named lambda; same signature so the
    # IrOpImpl supports_args signature validation still passes.
    return variance_size is None
|
||||
|
||||
|
||||
@ir.ops.rms_norm.register_impl(
    "xpu_kernels", supports_args=rms_no_var, supported=XPU_KERNELS_SUPPORTED
)
def rms_norm(
    x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
    """rms_norm backed by the vLLM XPU kernels.

    NOTE(review): this calls torch.ops._C.rms_norm, same as the vllm_c impl —
    presumably vllm_xpu_kernels registers its kernel under the _C namespace
    via current_platform.import_kernels(); confirm.
    """
    assert variance_size is None
    if weight is None:
        # Kernel requires weight tensor, pass ones
        weight = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
    result = torch.empty(x.shape, device=x.device, dtype=x.dtype)
    torch.ops._C.rms_norm(result, x, weight, epsilon)
    return result
|
||||
@@ -8,6 +8,7 @@ from vllm.logging_utils.access_log_filter import (
|
||||
from vllm.logging_utils.formatter import ColoredFormatter, NewLineFormatter
|
||||
from vllm.logging_utils.lazy import lazy
|
||||
from vllm.logging_utils.log_time import logtime
|
||||
from vllm.logging_utils.torch_tensor import tensors_str_no_data
|
||||
|
||||
__all__ = [
|
||||
"NewLineFormatter",
|
||||
@@ -16,4 +17,5 @@ __all__ = [
|
||||
"create_uvicorn_log_config",
|
||||
"lazy",
|
||||
"logtime",
|
||||
"tensors_str_no_data",
|
||||
]
|
||||
|
||||
10
vllm/logging_utils/torch_tensor.py
Normal file
10
vllm/logging_utils/torch_tensor.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
|
||||
def tensors_str_no_data(arg: Any):
|
||||
from torch._tensor_str import printoptions
|
||||
|
||||
with printoptions(threshold=1, edgeitems=0):
|
||||
return str(arg)
|
||||
@@ -6,7 +6,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import _oink_ops, envs
|
||||
# Import kernels
|
||||
import vllm.kernels # noqa: F401
|
||||
from vllm import _oink_ops, envs, ir
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
@@ -51,23 +53,6 @@ def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool:
|
||||
return (x_2d.stride(0) % divby) == 0
|
||||
|
||||
|
||||
def rms_norm(
|
||||
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
|
||||
) -> torch.Tensor:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
if envs.VLLM_BATCH_INVARIANT:
|
||||
return rms_norm_batch_invariant(x, weight, variance_epsilon)
|
||||
out = torch.empty_like(x)
|
||||
ops.rms_norm(
|
||||
out,
|
||||
x,
|
||||
weight,
|
||||
variance_epsilon,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def fused_add_rms_norm(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
@@ -105,23 +90,16 @@ def poly_norm(
|
||||
return out
|
||||
|
||||
|
||||
def dispatch_rocm_rmsnorm_func(
|
||||
with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
|
||||
):
|
||||
def dispatch_rocm_rmsnorm_func(dtype: torch.dtype, use_aiter: bool = False):
|
||||
use_aiter = use_aiter and dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
|
||||
if use_aiter and with_fused_add:
|
||||
return rocm_aiter_ops.rms_norm2d_with_add
|
||||
if use_aiter:
|
||||
return rocm_aiter_ops.rms_norm
|
||||
|
||||
# fall back to CUDA implementation
|
||||
if with_fused_add:
|
||||
return rocm_aiter_ops.rms_norm2d_with_add
|
||||
else:
|
||||
return fused_add_rms_norm
|
||||
return rms_norm
|
||||
|
||||
|
||||
# --8<-- [start:rms_norm]
|
||||
@@ -158,20 +136,14 @@ class RMSNorm(CustomOp):
|
||||
|
||||
if current_platform.is_rocm():
|
||||
aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
|
||||
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
|
||||
with_fused_add=False,
|
||||
dtype=weight_dtype,
|
||||
use_aiter=aiter_rmsnorm_enabled,
|
||||
)
|
||||
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
|
||||
with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
|
||||
dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
|
||||
)
|
||||
|
||||
# Optional: enable Oink Blackwell RMSNorm custom-op fast path on
|
||||
# compatible CUDA devices (e.g., SM100) when the external Oink
|
||||
# package is available. This is detected once at construction time
|
||||
# to avoid per-call device queries in the hot path.
|
||||
self._use_oink_rmsnorm = False
|
||||
self._use_oink_fused_add_rmsnorm = False
|
||||
if (
|
||||
not current_platform.is_rocm()
|
||||
@@ -203,7 +175,6 @@ class RMSNorm(CustomOp):
|
||||
try:
|
||||
device_index = torch.accelerator.current_device_index()
|
||||
if _oink_ops.is_oink_available_for_device(device_index):
|
||||
self._use_oink_rmsnorm = True
|
||||
self._use_oink_fused_add_rmsnorm = (
|
||||
_oink_ops.has_fused_add_rms_norm()
|
||||
)
|
||||
@@ -215,7 +186,6 @@ class RMSNorm(CustomOp):
|
||||
"RMSNorm; falling back to vLLM RMSNorm. Error: %s",
|
||||
e,
|
||||
)
|
||||
self._use_oink_rmsnorm = False
|
||||
self._use_oink_fused_add_rmsnorm = False
|
||||
|
||||
@staticmethod
|
||||
@@ -270,6 +240,10 @@ class RMSNorm(CustomOp):
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
if residual is None:
|
||||
return ir.ops.rms_norm(
|
||||
x, self.weight.data, self.variance_epsilon, self.variance_size_override
|
||||
)
|
||||
|
||||
return self.forward_static(
|
||||
x,
|
||||
@@ -286,35 +260,14 @@ class RMSNorm(CustomOp):
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is None and not envs.VLLM_BATCH_INVARIANT:
|
||||
return ir.ops.rms_norm(
|
||||
x, self.weight.data, self.variance_epsilon, self.variance_size_override
|
||||
)
|
||||
|
||||
if self.variance_size_override is not None:
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
# Optional Oink SM100 fast path (no residual). This path is
|
||||
# torch.compile-friendly via torch.ops.oink.rmsnorm and preserves
|
||||
# 2D layouts (including padded rows) when using the Oink
|
||||
# pointer-based kernel.
|
||||
if (
|
||||
residual is None
|
||||
and getattr(self, "_use_oink_rmsnorm", False)
|
||||
and x.is_cuda
|
||||
and x.dim() >= 2
|
||||
and self.has_weight
|
||||
and not envs.VLLM_BATCH_INVARIANT
|
||||
and self.weight.data.dtype == x.dtype
|
||||
and self.weight.data.is_contiguous()
|
||||
):
|
||||
orig_shape = x.shape
|
||||
hidden_size = orig_shape[-1]
|
||||
if _can_view_as_2d(x):
|
||||
x_2d = x.view(-1, hidden_size)
|
||||
if _is_oink_stride_compatible_2d(x_2d):
|
||||
y_2d = _oink_ops.rmsnorm(
|
||||
x_2d,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return y_2d.view(orig_shape)
|
||||
|
||||
# Optional Oink SM100 fast path (fused residual-add + RMSNorm, in-place).
|
||||
# This mirrors vLLM's fused_add_rms_norm semantics by mutating both
|
||||
# `x` (normalized output) and `residual` (residual-out buffer).
|
||||
@@ -356,29 +309,34 @@ class RMSNorm(CustomOp):
|
||||
)
|
||||
return x, residual
|
||||
|
||||
add_residual = residual is not None
|
||||
if add_residual:
|
||||
if residual is not None:
|
||||
return fused_add_rms_norm(
|
||||
x, residual, self.weight.data, self.variance_epsilon
|
||||
)
|
||||
else:
|
||||
return rms_norm(x, self.weight.data, self.variance_epsilon)
|
||||
assert envs.VLLM_BATCH_INVARIANT
|
||||
return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)
|
||||
|
||||
def forward_hip(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is None and not envs.VLLM_BATCH_INVARIANT:
|
||||
return ir.ops.rms_norm(
|
||||
x, self.weight.data, self.variance_epsilon, self.variance_size_override
|
||||
)
|
||||
|
||||
if self.variance_size_override is not None:
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
add_residual = residual is not None
|
||||
if add_residual:
|
||||
if residual is not None:
|
||||
return self.rocm_norm_func_with_add(
|
||||
x, residual, self.weight.data, self.variance_epsilon
|
||||
)
|
||||
else:
|
||||
return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)
|
||||
assert envs.VLLM_BATCH_INVARIANT
|
||||
return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)
|
||||
|
||||
def forward_xpu(
|
||||
self,
|
||||
|
||||
@@ -30,6 +30,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.config.kernel import IrOpPriorityConfig
|
||||
from vllm.v1.attention.selector import AttentionSelectorConfig
|
||||
else:
|
||||
VllmConfig = None
|
||||
@@ -550,6 +551,26 @@ class CudaPlatformBase(Platform):
|
||||
def use_custom_op_collectives(cls) -> bool:
|
||||
return True
|
||||
|
||||
    @classmethod
    def get_default_ir_op_priority(cls, vllm_config: VllmConfig) -> IrOpPriorityConfig:
        """Default IR op dispatch priority for CUDA-alike platforms.

        Prefers native (Inductor codegen) when compiling with the inductor
        backend; otherwise prefers the vLLM C kernels with native fallback.
        """
        from vllm.config.compilation import CompilationMode
        from vllm.config.kernel import IrOpPriorityConfig

        # Native used by default when compiling,
        # use vllm_c kernels where available when no codegen
        cc = vllm_config.compilation_config
        using_inductor = cc.backend == "inductor" and cc.mode != CompilationMode.NONE
        default = ["native"] if using_inductor else ["vllm_c", "native"]

        # Use oink if enabled for rms_norm
        # TODO(Laurawly/luka): remove this env var,
        # users can just use IR op priority directly
        rms_norm = default
        if envs.VLLM_USE_OINK_OPS:
            rms_norm = ["oink"] + default

        return IrOpPriorityConfig.with_default(default, rms_norm=rms_norm)
|
||||
|
||||
|
||||
# NVML utils
|
||||
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
|
||||
|
||||
@@ -17,6 +17,7 @@ if TYPE_CHECKING:
|
||||
from torch.distributed import PrefixStore, ProcessGroup
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.kernel import IrOpPriorityConfig
|
||||
from vllm.inputs import EngineInput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@@ -931,6 +932,16 @@ class Platform:
|
||||
"num_compute_units is not implemented for the current platform."
|
||||
)
|
||||
|
||||
    @classmethod
    def get_default_ir_op_priority(
        cls, vllm_config: "VllmConfig"
    ) -> "IrOpPriorityConfig":
        """Get the default IR op priority for the current platform.

        :param vllm_config: Full engine config; platform overrides may read
            the compilation config to choose between codegen and kernels.
        :return: Priority config mapping op names to provider lists.
        """
        from vllm.config.kernel import IrOpPriorityConfig

        # Native always used by default. Platforms can override this behavior.
        return IrOpPriorityConfig.with_default(["native"])
|
||||
|
||||
|
||||
class UnspecifiedPlatform(Platform):
|
||||
_enum = PlatformEnum.UNSPECIFIED
|
||||
|
||||
@@ -19,6 +19,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.kernel import IrOpPriorityConfig
|
||||
from vllm.v1.attention.selector import AttentionSelectorConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -903,3 +904,32 @@ class RocmPlatform(Platform):
|
||||
@classmethod
|
||||
def use_custom_op_collectives(cls) -> bool:
|
||||
return True
|
||||
|
||||
    @classmethod
    def get_default_ir_op_priority(
        cls, vllm_config: "VllmConfig"
    ) -> "IrOpPriorityConfig":
        """Default IR op dispatch priority for ROCm.

        Prefers native (Inductor codegen) when compiling; otherwise the vLLM
        C kernels with native fallback. Honors the legacy AITER env-var
        switches for rms_norm.
        """
        from vllm.config.compilation import CompilationMode
        from vllm.config.kernel import IrOpPriorityConfig

        # Native used by default when compiling,
        # use vllm_c kernels where available when no codegen
        # TODO(luka/TJ) use aiter, vllm_c, native by default on ROCm
        cc = vllm_config.compilation_config
        using_inductor = cc.backend == "inductor" and cc.mode != CompilationMode.NONE
        default = ["native"] if using_inductor else ["vllm_c", "native"]

        # This (mostly) preserves previous CustomOp behavior
        # Necessary on ROCm because it's common that users
        # enable rms_norm to use the aiter kernel.
        # TODO(luka/TJ) remove env vars completely
        if (
            cc.is_custom_op_enabled("rms_norm")
            and envs.VLLM_ROCM_USE_AITER
            and envs.VLLM_ROCM_USE_AITER_RMSNORM
        ):
            rms_norm = ["aiter"] + default
        else:
            rms_norm = default

        return IrOpPriorityConfig.with_default(default, rms_norm=rms_norm)
|
||||
|
||||
@@ -21,6 +21,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.kernel import IrOpPriorityConfig
|
||||
from vllm.v1.attention.selector import AttentionSelectorConfig
|
||||
else:
|
||||
VllmConfig = None
|
||||
@@ -257,6 +258,21 @@ class XPUPlatform(Platform):
|
||||
)
|
||||
return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa
|
||||
|
||||
    @classmethod
    def get_default_ir_op_priority(
        cls, vllm_config: "VllmConfig"
    ) -> "IrOpPriorityConfig":
        """Default IR op dispatch priority for XPU.

        Prefers native (Inductor codegen) when compiling; otherwise the
        fused XPU kernels with native fallback.
        """
        from vllm.config.compilation import CompilationMode
        from vllm.config.kernel import IrOpPriorityConfig

        # Native used by default when compiling,
        # use fused kernels where available when no codegen
        cc = vllm_config.compilation_config
        using_inductor = cc.backend == "inductor" and cc.mode != CompilationMode.NONE
        default = ["native"] if using_inductor else ["xpu_kernels", "native"]

        return IrOpPriorityConfig.with_default(default)
|
||||
|
||||
@classmethod
|
||||
def device_count(cls) -> int:
|
||||
return torch.xpu.device_count()
|
||||
|
||||
Reference in New Issue
Block a user