[vLLM IR] 1/N Implement IR skeleton and rms_norm op (#33825)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com> Signed-off-by: chzhang <chaojun.zhang@intel.com> Signed-off-by: Luka Govedic <luka.govedic@gmail.com> Co-authored-by: Xinyu Chen <xinyu1.chen@intel.com> Co-authored-by: Chaojun Zhang <chaojun.zhang@intel.com> Co-authored-by: Luka Govedič <ProExpertProg@h100-01.nemg-001.lab.rdu2.dc.redhat.com>
This commit is contained in:
@@ -2,6 +2,16 @@ group: Kernels
|
|||||||
depends_on:
|
depends_on:
|
||||||
- image-build
|
- image-build
|
||||||
steps:
|
steps:
|
||||||
|
- label: vLLM IR Tests
|
||||||
|
timeout_in_minutes: 10
|
||||||
|
working_dir: "/vllm-workspace/"
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/ir
|
||||||
|
- vllm/kernels
|
||||||
|
commands:
|
||||||
|
- pytest -v -s tests/ir
|
||||||
|
- pytest -v -s tests/kernels/ir
|
||||||
|
|
||||||
- label: Kernels Core Operation Test
|
- label: Kernels Core Operation Test
|
||||||
timeout_in_minutes: 75
|
timeout_in_minutes: 75
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
|
|||||||
4
.github/CODEOWNERS
vendored
4
.github/CODEOWNERS
vendored
@@ -13,6 +13,9 @@
|
|||||||
/vllm/model_executor/layers/rotary_embedding.py @vadiklyutiy
|
/vllm/model_executor/layers/rotary_embedding.py @vadiklyutiy
|
||||||
/vllm/model_executor/model_loader @22quinn
|
/vllm/model_executor/model_loader @22quinn
|
||||||
/vllm/model_executor/layers/batch_invariant.py @yewentao256
|
/vllm/model_executor/layers/batch_invariant.py @yewentao256
|
||||||
|
/vllm/ir @ProExpertProg
|
||||||
|
/vllm/kernels/ @ProExpertProg @tjtanaa
|
||||||
|
/vllm/kernels/helion @ProExpertProg @zou3519
|
||||||
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa
|
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa
|
||||||
/vllm/vllm_flash_attn @LucasWilkinson @MatthewBonanni
|
/vllm/vllm_flash_attn @LucasWilkinson @MatthewBonanni
|
||||||
CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||||
@@ -74,6 +77,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
|||||||
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche
|
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche
|
||||||
/tests/evals @mgoin @vadiklyutiy
|
/tests/evals @mgoin @vadiklyutiy
|
||||||
/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256
|
/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256
|
||||||
|
/tests/kernels/ir @ProExpertProg @tjtanaa
|
||||||
/tests/models @DarkLight1337 @ywang96
|
/tests/models @DarkLight1337 @ywang96
|
||||||
/tests/multimodal @DarkLight1337 @ywang96 @NickLucche
|
/tests/multimodal @DarkLight1337 @ywang96 @NickLucche
|
||||||
/tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 @pavanimajety
|
/tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 @pavanimajety
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from copy import deepcopy
|
|||||||
|
|
||||||
import depyf
|
import depyf
|
||||||
from torch import fx
|
from torch import fx
|
||||||
from torch._ops import OpOverload
|
from torch._ops import OpOverload, OpOverloadPacket
|
||||||
from torch.fx._utils import lazy_format_graph_code
|
from torch.fx._utils import lazy_format_graph_code
|
||||||
|
|
||||||
from vllm.compilation.passes.fx_utils import find_op_nodes
|
from vllm.compilation.passes.fx_utils import find_op_nodes
|
||||||
@@ -90,7 +90,9 @@ class TestBackend:
|
|||||||
# assign by reference, will reflect the final state of the graph
|
# assign by reference, will reflect the final state of the graph
|
||||||
self.final_graph = graph
|
self.final_graph = graph
|
||||||
|
|
||||||
def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True):
|
def check_before_ops(
|
||||||
|
self, ops: Sequence[OpOverload | OpOverloadPacket], fully_replaced=True
|
||||||
|
):
|
||||||
for op in ops:
|
for op in ops:
|
||||||
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
|
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
|
||||||
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
|
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
|
||||||
@@ -99,13 +101,19 @@ class TestBackend:
|
|||||||
if fully_replaced:
|
if fully_replaced:
|
||||||
assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"
|
assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"
|
||||||
|
|
||||||
def check_after_ops(self, ops: Sequence[OpOverload]):
|
def check_after_ops(self, ops: Sequence[OpOverload | OpOverloadPacket]):
|
||||||
for op in ops:
|
for op in ops:
|
||||||
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
|
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
|
||||||
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
|
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
|
||||||
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
|
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
|
||||||
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
|
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
|
||||||
|
|
||||||
def op_count(self, op: OpOverload, before=False) -> int:
|
def op_count(self, op: OpOverload | OpOverloadPacket, before=False) -> int:
|
||||||
graph = self.graph_pre_pass if before else self.graph_post_pass
|
graph = self.graph_pre_pass if before else self.graph_post_pass
|
||||||
return len(list(find_op_nodes(op, graph)))
|
return len(list(find_op_nodes(op, graph)))
|
||||||
|
|
||||||
|
def print_graphs(self):
    """Print the generated Python source of the pre- and post-pass graphs.

    Each graph is preceded by a banner line so the two dumps are easy to
    tell apart in captured test output.
    """
    sections = (
        ("=== Graph before custom passes ===", "graph_pre_pass"),
        ("=== Graph after custom passes ===", "graph_post_pass"),
    )
    for banner, attr_name in sections:
        print(banner)
        # Look the graph up only after its banner has been printed.
        graph = getattr(self, attr_name)
        rendered = graph.python_code(root_module="self", verbose=True)
        print(rendered.src)
|
||||||
|
|||||||
@@ -99,6 +99,8 @@ def test_tp1_fp8_fusions(
|
|||||||
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
||||||
model_kwargs["load_format"] = "dummy"
|
model_kwargs["load_format"] = "dummy"
|
||||||
model_kwargs["max_model_len"] = 1024
|
model_kwargs["max_model_len"] = 1024
|
||||||
|
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
|
||||||
|
|
||||||
compilation_config = dict(
|
compilation_config = dict(
|
||||||
use_inductor_graph_partition=inductor_graph_partition,
|
use_inductor_graph_partition=inductor_graph_partition,
|
||||||
custom_ops=custom_ops.split(","),
|
custom_ops=custom_ops.split(","),
|
||||||
@@ -166,6 +168,7 @@ def test_tp1_fp4_fusions(
|
|||||||
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
||||||
model_kwargs["load_format"] = "dummy"
|
model_kwargs["load_format"] = "dummy"
|
||||||
model_kwargs["max_model_len"] = 1024
|
model_kwargs["max_model_len"] = 1024
|
||||||
|
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
|
||||||
|
|
||||||
compilation_config = dict(
|
compilation_config = dict(
|
||||||
use_inductor_graph_partition=inductor_graph_partition,
|
use_inductor_graph_partition=inductor_graph_partition,
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ def test_tp2_ar_rms_fp8_fusions(
|
|||||||
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
||||||
model_kwargs["load_format"] = "dummy"
|
model_kwargs["load_format"] = "dummy"
|
||||||
model_kwargs["max_model_len"] = 1024
|
model_kwargs["max_model_len"] = 1024
|
||||||
|
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
|
||||||
|
|
||||||
compilation_config = dict(
|
compilation_config = dict(
|
||||||
use_inductor_graph_partition=inductor_graph_partition,
|
use_inductor_graph_partition=inductor_graph_partition,
|
||||||
@@ -128,6 +129,7 @@ def test_tp2_ar_rms_fp4_fusions(
|
|||||||
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
||||||
model_kwargs["load_format"] = "dummy"
|
model_kwargs["load_format"] = "dummy"
|
||||||
model_kwargs["max_model_len"] = 1024
|
model_kwargs["max_model_len"] = 1024
|
||||||
|
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
|
||||||
|
|
||||||
compilation_config = dict(
|
compilation_config = dict(
|
||||||
use_inductor_graph_partition=inductor_graph_partition,
|
use_inductor_graph_partition=inductor_graph_partition,
|
||||||
@@ -182,6 +184,7 @@ def test_tp2_ar_rms_fusions(
|
|||||||
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
||||||
model_kwargs["load_format"] = "dummy"
|
model_kwargs["load_format"] = "dummy"
|
||||||
model_kwargs["max_model_len"] = 1024
|
model_kwargs["max_model_len"] = 1024
|
||||||
|
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
|
||||||
|
|
||||||
compilation_config = dict(
|
compilation_config = dict(
|
||||||
use_inductor_graph_partition=inductor_graph_partition,
|
use_inductor_graph_partition=inductor_graph_partition,
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ def test_tp2_async_tp_fp8_fusions(
|
|||||||
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
||||||
model_kwargs["load_format"] = "dummy"
|
model_kwargs["load_format"] = "dummy"
|
||||||
model_kwargs["max_model_len"] = 1024
|
model_kwargs["max_model_len"] = 1024
|
||||||
|
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
|
||||||
|
|
||||||
compilation_config = dict(
|
compilation_config = dict(
|
||||||
use_inductor_graph_partition=inductor_graph_partition,
|
use_inductor_graph_partition=inductor_graph_partition,
|
||||||
@@ -121,6 +122,7 @@ def test_tp2_async_tp_fusions(
|
|||||||
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
||||||
model_kwargs["load_format"] = "dummy"
|
model_kwargs["load_format"] = "dummy"
|
||||||
model_kwargs["max_model_len"] = 1024
|
model_kwargs["max_model_len"] = 1024
|
||||||
|
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
|
||||||
|
|
||||||
compilation_config = dict(
|
compilation_config = dict(
|
||||||
use_inductor_graph_partition=inductor_graph_partition,
|
use_inductor_graph_partition=inductor_graph_partition,
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from tests.compile.backend import TestBackend
|
|||||||
from tests.utils import TestFP8Layer, multi_gpu_test
|
from tests.utils import TestFP8Layer, multi_gpu_test
|
||||||
from vllm.compilation.passes.fusion.rms_quant_fusion import RMSNormQuantFusionPass
|
from vllm.compilation.passes.fusion.rms_quant_fusion import RMSNormQuantFusionPass
|
||||||
from vllm.compilation.passes.fusion.sequence_parallelism import SequenceParallelismPass
|
from vllm.compilation.passes.fusion.sequence_parallelism import SequenceParallelismPass
|
||||||
from vllm.compilation.passes.fx_utils import find_auto_fn
|
|
||||||
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
|
||||||
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
||||||
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
|
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
|
||||||
@@ -86,13 +85,14 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def ops_in_model(self):
|
def ops_in_model(self):
|
||||||
if RMSNorm.enabled():
|
return (
|
||||||
return [
|
[torch.ops.vllm_ir.rms_norm]
|
||||||
torch.ops._C.rms_norm.default,
|
+ [
|
||||||
torch.ops._C.fused_add_rms_norm.default,
|
torch.ops._C.fused_add_rms_norm.default,
|
||||||
]
|
]
|
||||||
else:
|
if RMSNorm.enabled()
|
||||||
return []
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||||
@@ -321,4 +321,4 @@ def sequence_parallelism_pass_on_test_model(
|
|||||||
assert backend.op_count(op, before=False) == 4
|
assert backend.op_count(op, before=False) == 4
|
||||||
|
|
||||||
for op in model.ops_in_model():
|
for op in model.ops_in_model():
|
||||||
find_auto_fn(backend.graph_post_pass.nodes, op)
|
assert backend.op_count(op, before=False) > 0
|
||||||
|
|||||||
0
tests/compile/passes/ir/__init__.py
Normal file
0
tests/compile/passes/ir/__init__.py
Normal file
69
tests/compile/passes/ir/test_lowering.py
Normal file
69
tests/compile/passes/ir/test_lowering.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
import vllm.kernels # noqa: F401 to register kernels
|
||||||
|
from vllm import ir
|
||||||
|
from vllm.compilation.passes.ir.lowering_pass import (
|
||||||
|
VllmIRLoweringPass,
|
||||||
|
)
|
||||||
|
from vllm.config import get_current_vllm_config
|
||||||
|
from vllm.ir import ops
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
from ...backend import TestBackend
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
    """Toy module that chains three ``ops.rms_norm`` calls.

    The calls are separated by plain elementwise arithmetic and differ in
    their arguments: with a weight, without a weight, and with an explicit
    fourth argument of ``hidden_size // 2``.
    """

    def __init__(self, hidden_size=16, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight = torch.ones(hidden_size, dtype=torch.bfloat16)
        self.hidden_size = hidden_size

    def forward(self, x):
        eps = 1e-5

        weighted = ops.rms_norm(x + 4.0, self.weight, eps)

        # no weight
        unweighted = ops.rms_norm(weighted * 5.0, None, eps)

        # dispatch to native due to variance_size parameter
        half_variance = ops.rms_norm(
            unweighted / 2.0, self.weight, eps, self.hidden_size // 2
        )
        return half_variance + 3.0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("rms_provider", ops.rms_norm.supported_providers())
def test_lowering_rms_norm(rms_provider, default_vllm_config):
    """Compile ``Model`` with and without the IR lowering pass and compare.

    Checks which provider the lowering pass recorded for each of the three
    ``rms_norm`` nodes and that the lowered output matches the unlowered one.
    """
    torch.set_default_device(current_platform.device_type)

    ir_lowering = VllmIRLoweringPass(get_current_vllm_config())
    lowered_backend = TestBackend(ir_lowering)
    plain_backend = TestBackend()

    module = Model()
    inputs = torch.randn(8, 16, dtype=torch.bfloat16)
    with ops.rms_norm.set_priority([rms_provider, "native"]), ir.enable_torch_wrap(
        True
    ):
        lowered_fn = torch.compile(module, backend=lowered_backend, fullgraph=True)
        plain_fn = torch.compile(module, backend=plain_backend, fullgraph=True)
        lowered_out = lowered_fn(inputs)
        plain_out = plain_fn(inputs)

    chosen = ir_lowering.selected_impls["rms_norm"]
    assert len(chosen) == 3
    expected_by_node = {
        "rms_norm": rms_provider,
        "rms_norm_1": rms_provider,
        "rms_norm_2": "native",
    }
    for node_name, provider in expected_by_node.items():
        assert chosen[node_name] == provider

    # Compiled function guards on global value, avoid recompilation
    with ir.enable_torch_wrap(True):
        lowered_out_again = lowered_fn(inputs)

    for candidate in (lowered_out, lowered_out_again):
        torch.testing.assert_close(plain_out, candidate)
|
||||||
@@ -6,6 +6,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.config
|
import vllm.config
|
||||||
|
import vllm.ir.ops
|
||||||
import vllm.plugins
|
import vllm.plugins
|
||||||
from tests.compile.backend import TestBackend
|
from tests.compile.backend import TestBackend
|
||||||
from tests.utils import TestBlockFP8Layer, TestFP8Layer
|
from tests.utils import TestBlockFP8Layer, TestFP8Layer
|
||||||
@@ -51,7 +52,6 @@ from vllm.utils.deep_gemm import (
|
|||||||
|
|
||||||
FP8_DTYPE = current_platform.fp8_dtype()
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
|
|
||||||
RMS_OP = torch.ops._C.rms_norm.default
|
|
||||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||||
|
|
||||||
# Kernel and group_shape combinations: (kernel, group_shape)
|
# Kernel and group_shape combinations: (kernel, group_shape)
|
||||||
@@ -246,10 +246,8 @@ class TestModel(torch.nn.Module):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def ops_in_model_before_partial(self):
|
def ops_in_model_before_partial(self):
|
||||||
return (
|
return [torch.ops.vllm_ir.rms_norm] + (
|
||||||
[RMS_OP, RMS_ADD_OP]
|
[RMS_ADD_OP] if self.enable_rms_norm_custom_op else [torch.ops.aten.rsqrt]
|
||||||
if self.enable_rms_norm_custom_op
|
|
||||||
else [torch.ops.aten.rsqrt]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -340,7 +338,10 @@ def test_fusion_rmsnorm_quant(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
with vllm.config.set_current_vllm_config(vllm_config):
|
with (
|
||||||
|
vllm.config.set_current_vllm_config(vllm_config),
|
||||||
|
vllm_config.kernel_config.ir_op_priority.set_priority(),
|
||||||
|
):
|
||||||
# Setup device before model creation
|
# Setup device before model creation
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
@@ -370,8 +371,9 @@ def test_fusion_rmsnorm_quant(
|
|||||||
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
|
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
|
||||||
if not enable_rms_norm_custom_op:
|
if not enable_rms_norm_custom_op:
|
||||||
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
|
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
|
||||||
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
|
# rms_norm is IR, not included
|
||||||
assert n_add_nodes(backend.graph_pre_pass) == 7
|
# 6 = 3x2 (3xRMS_ADD, 2 each)
|
||||||
|
assert n_add_nodes(backend.graph_pre_pass) == 6
|
||||||
assert n_add_nodes(backend.graph_post_pass) == 2
|
assert n_add_nodes(backend.graph_post_pass) == 2
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,11 +3,11 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from torch._ops import OpOverload, OpOverloadPacket
|
||||||
|
|
||||||
from tests.compile.backend import TestBackend
|
from tests.compile.backend import TestBackend
|
||||||
from vllm.compilation.passes.fusion.matcher_utils import (
|
from vllm.compilation.passes.fusion.matcher_utils import (
|
||||||
FLASHINFER_ROTARY_OP,
|
FLASHINFER_ROTARY_OP,
|
||||||
RMS_OP,
|
|
||||||
ROTARY_OP,
|
ROTARY_OP,
|
||||||
)
|
)
|
||||||
from vllm.compilation.passes.fusion.qk_norm_rope_fusion import (
|
from vllm.compilation.passes.fusion.qk_norm_rope_fusion import (
|
||||||
@@ -100,13 +100,8 @@ class QKNormRoPETestModel(torch.nn.Module):
|
|||||||
q, k = self.rotary_emb(positions, q, k)
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
return q, k, v
|
return q, k, v
|
||||||
|
|
||||||
def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
|
def ops_in_model_before(self) -> list[OpOverload | OpOverloadPacket]:
|
||||||
ops = []
|
ops: list[OpOverload | OpOverloadPacket] = [torch.ops.vllm_ir.rms_norm]
|
||||||
if self.enable_rms_norm_custom_op:
|
|
||||||
ops.append(RMS_OP)
|
|
||||||
else:
|
|
||||||
ops.append(RSQRT_OP)
|
|
||||||
|
|
||||||
if self.enable_rope_custom_op:
|
if self.enable_rope_custom_op:
|
||||||
if self.rotary_emb.use_flashinfer:
|
if self.rotary_emb.use_flashinfer:
|
||||||
ops.append(FLASHINFER_ROTARY_OP)
|
ops.append(FLASHINFER_ROTARY_OP)
|
||||||
@@ -116,7 +111,7 @@ class QKNormRoPETestModel(torch.nn.Module):
|
|||||||
ops.append(INDEX_SELECT_OP)
|
ops.append(INDEX_SELECT_OP)
|
||||||
return ops
|
return ops
|
||||||
|
|
||||||
def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
|
def ops_in_model_after(self) -> list[OpOverload | OpOverloadPacket]:
|
||||||
return [FUSED_QK_ROPE_OP]
|
return [FUSED_QK_ROPE_OP]
|
||||||
|
|
||||||
|
|
||||||
@@ -166,7 +161,10 @@ def test_qk_norm_rope_fusion(
|
|||||||
num_heads, num_kv_heads, head_dim = 16, 4, 128
|
num_heads, num_kv_heads, head_dim = 16, 4, 128
|
||||||
T = 5
|
T = 5
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with (
|
||||||
|
set_current_vllm_config(vllm_config),
|
||||||
|
vllm_config.kernel_config.ir_op_priority.set_priority(),
|
||||||
|
):
|
||||||
model = QKNormRoPETestModel(
|
model = QKNormRoPETestModel(
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
num_kv_heads=num_kv_heads,
|
num_kv_heads=num_kv_heads,
|
||||||
|
|||||||
@@ -1622,3 +1622,26 @@ def fresh_vllm_cache(monkeypatch, use_fresh_inductor_cache):
|
|||||||
def enable_pickle(monkeypatch):
|
def enable_pickle(monkeypatch):
|
||||||
"""`LLM.apply_model` requires pickling a function."""
|
"""`LLM.apply_model` requires pickling a function."""
|
||||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
def disable_log_dedup(monkeypatch):
    """
    Disable log deduplication such that warning_once, info_once and
    debug_once always print.

    The three module-level printers in ``vllm.logger`` are replaced by their
    ``__wrapped__`` attribute for the duration of the test.
    """
    from vllm import logger

    # Patch through monkeypatch so every replacement is undone automatically,
    # including the case where a later replacement fails during setup.
    for printer_name in (
        "_print_warning_once",
        "_print_info_once",
        "_print_debug_once",
    ):
        cached_printer = getattr(logger, printer_name)
        monkeypatch.setattr(logger, printer_name, cached_printer.__wrapped__)
|
||||||
|
|||||||
@@ -523,3 +523,20 @@ def test_human_readable_model_len():
|
|||||||
for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]:
|
for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]:
|
||||||
with pytest.raises(ArgumentError):
|
with pytest.raises(ArgumentError):
|
||||||
parser.parse_args(["--max-model-len", invalid])
|
parser.parse_args(["--max-model-len", invalid])
|
||||||
|
|
||||||
|
|
||||||
|
def test_ir_op_priority():
    """Check the two ways of passing ``ir_op_priority`` to ``EngineArgs``.

    Passing it directly and passing it inside a ``KernelConfig`` must yield
    equal engine configs; passing it both ways at once must raise.
    """
    from vllm.config.kernel import IrOpPriorityConfig, KernelConfig

    ir_op_priority = IrOpPriorityConfig(rms_norm=["vllm_c"])

    direct_args = EngineArgs(ir_op_priority=ir_op_priority)
    direct_cfg = direct_args.create_engine_config()

    nested_args = EngineArgs(
        kernel_config=KernelConfig(ir_op_priority=ir_op_priority)
    )
    nested_cfg = nested_args.create_engine_config()

    direct_priority = direct_cfg.kernel_config.ir_op_priority
    nested_priority = nested_cfg.kernel_config.ir_op_priority
    assert direct_priority == nested_priority

    # Supplying the same setting through both routes is rejected.
    with pytest.raises(ValueError, match="rms_norm"):
        conflicting_args = EngineArgs(
            ir_op_priority=ir_op_priority,
            kernel_config=KernelConfig(ir_op_priority=ir_op_priority),
        )
        _ = conflicting_args.create_engine_config()
|
||||||
|
|||||||
497
tests/ir/test_op.py
Normal file
497
tests/ir/test_op.py
Normal file
@@ -0,0 +1,497 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import importlib.util
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from torch import fx
|
||||||
|
from torch.fx.experimental.proxy_tensor import make_fx
|
||||||
|
|
||||||
|
import vllm.ir.op
|
||||||
|
from vllm.ir.op import RESERVED_PROVIDERS, IrOp, IrOpImpl
|
||||||
|
|
||||||
|
# This should not exist
|
||||||
|
assert "_custom_add" not in IrOp.registry
|
||||||
|
|
||||||
|
|
||||||
|
class CustomError(Exception):
    """Marker exception raised by tests to check that contexts unwind."""
|
||||||
|
|
||||||
|
|
||||||
|
@vllm.ir.register_op
def _custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Native body of the test-only IR op: elementwise ``x + y``."""
    total = x + y
    return total
|
||||||
|
|
||||||
|
|
||||||
|
def test_registration_overloads():
    """Exercise the different ways an IR op can be created and registered."""
    for fresh_name in ("_custom_sub", "_custom_mul", "_custom_div"):
        assert fresh_name not in IrOp.registry

    # Calling with decorator
    @vllm.ir.register_op()
    def _custom_sub(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x - y

    assert IrOp.registry["_custom_sub"] is _custom_sub
    assert _custom_sub.name == "_custom_sub"

    # Custom name
    @vllm.ir.register_op(name="_custom_mul")
    def custom_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x * y

    assert IrOp.registry["_custom_mul"] is custom_mul
    assert custom_mul.name == "_custom_mul"

    # Direct construction does not register directly
    def _custom_div(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x / y

    unregistered_div = IrOp("_custom_div", _custom_div)
    assert unregistered_div.name == "_custom_div"
    assert "_custom_div" not in IrOp.registry

    # Duplicate op registration not allowed
    with pytest.raises(AssertionError):

        @vllm.ir.register_op
        def _custom_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x * y - 100
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_kw_only_args():
    """Registering an op with a keyword-only parameter must be rejected."""
    # kw-only args not supported
    rejection = pytest.raises(ValueError, match="keyword-only arguments")
    with rejection:

        @vllm.ir.register_op
        def _custom_kwarg_op(
            x: torch.Tensor, y: torch.Tensor, *, kwarg: int = 0
        ) -> torch.Tensor:
            return x + y + kwarg

    # The failed registration must not leave an entry behind.
    assert "_custom_kwarg_op" not in IrOp.registry
|
||||||
|
|
||||||
|
|
||||||
|
class TestIrOpCustomAdd:
    """Tests for the ``_custom_add`` IR op: registration, semantics, tracing."""

    # Registration invariants
    def test_decorated_object(self):
        """Make sure that referring directly to an op is correct"""
        assert "_custom_add" in IrOp.registry
        assert isinstance(_custom_add, IrOp)
        assert IrOp.registry["_custom_add"] is _custom_add

    def test_torch_op_is_registered(self):
        """The op must be reachable through the ``torch.ops.vllm_ir`` namespace."""
        namespace = torch.ops.vllm_ir
        assert hasattr(namespace, "_custom_add")
        assert callable(namespace._custom_add.default)

    # Semantic correctness
    def test_semantics_match_native(self):
        """Calling the op directly gives the same result as plain addition."""
        lhs = torch.randn(4, 5)
        rhs = torch.randn(4, 5)

        # Calls native by default
        actual = _custom_add(lhs, rhs)
        expected = lhs + rhs

        torch.testing.assert_close(actual, expected)

    # -------------------------
    # Implementation registration
    # -------------------------

    def test_register_impl_is_non_intrusive(self):
        """Registering an extra implementation must not change the default result."""

        @_custom_add.register_impl("dummy_provider")
        def dummy_impl(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y + 123

        assert "dummy_provider" in _custom_add.impls
        registered = _custom_add.impls["dummy_provider"]
        assert isinstance(registered, IrOpImpl)

        lhs = torch.ones(2, 2)
        rhs = torch.ones(2, 2)

        # Native semantics must still hold
        torch.testing.assert_close(_custom_add(lhs, rhs), lhs + rhs)

    def test_schema_contains_tensor_signature(self):
        """The schema string mentions tensors and a tensor return."""
        schema = _custom_add._schema_str

        for fragment in ("Tensor", "-> Tensor"):
            assert fragment in schema

    # -------------------------
    # FX visibility
    # -------------------------

    @pytest.mark.parametrize("enable_torch_wrap", [True, False])
    @pytest.mark.parametrize("symbolic_trace", [True, False])
    def test_trace_sees_single_custom_op(
        self, symbolic_trace: bool, enable_torch_wrap: bool
    ):
        """Trace a call to the op and count the IR nodes in the resulting graph."""

        def fn(x, y):
            return _custom_add(x, y)

        def trace() -> fx.GraphModule:
            # Same tracer choice is used inside and outside the context below.
            if symbolic_trace:
                return torch.fx.symbolic_trace(fn)
            return make_fx(fn)(torch.randn(2, 2), torch.randn(2, 2))

        def count_ir_nodes(gm: fx.GraphModule) -> int:
            target: Any = torch.ops.vllm_ir._custom_add.default
            return len(gm.graph.find_nodes(op="call_function", target=target))

        with pytest.raises(CustomError), vllm.ir.enable_torch_wrap(enable_torch_wrap):
            gm = trace()

            x1, y1 = torch.rand(5, 4), torch.rand(5, 4)
            out_fx = gm(x1, y1)
            out_eager = fn(x1, y1)

            # raise error to check enable_torch_wrap context restored correctly
            raise CustomError

        # check behavior matches eager in all cases
        torch.testing.assert_close(out_fx, out_eager)

        # check that IR nodes only appear if enable_torch_wrap=True
        expected_nodes = 1 if enable_torch_wrap else 0
        assert count_ir_nodes(gm) == expected_nodes, gm.code

        # with torch wrapping enabled (default), IR nodes appear
        gm = trace()
        assert count_ir_nodes(gm) == 1, gm.code
|
||||||
|
|
||||||
|
|
||||||
|
@_custom_add.register_impl("impl_a")
def impl_a(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Test implementation whose result is offset by 10."""
    base = x + y
    return base + 10
|
||||||
|
|
||||||
|
|
||||||
|
@_custom_add.register_impl("impl_b")
def impl_b(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Test implementation whose result is offset by 20."""
    base = x + y
    return base + 20
|
||||||
|
|
||||||
|
|
||||||
|
@_custom_add.register_impl("impl_even", supports_args=lambda x, y: x.size(1) % 2 == 0)
def impl_even(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Test implementation offset by 50; its predicate checks ``x.size(1)`` is even."""
    base = x + y
    return base + 50
|
||||||
|
|
||||||
|
|
||||||
|
class TestIrOpImplDispatch:
|
||||||
|
def test_register_impl(self):
    """Check the registered ``impl_a`` object and that duplicates are refused."""
    impls = _custom_add.impls
    assert "impl_a" in impls
    registered = impls["impl_a"]

    assert registered is impl_a
    assert registered.op is _custom_add
    assert registered.provider == "impl_a"
    assert callable(registered.impl_fn)

    # Test duplicate registration rejected
    with pytest.raises(AssertionError):

        @_custom_add.register_impl("impl_a")
        def impl_a_dup(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y + 30

    # Check the original impl is still intact
    assert _custom_add.impls["impl_a"] is impl_a

    # Check support all args
    for unconditional in (impl_a, impl_b):
        assert unconditional.supports_all_args
    assert not impl_even.supports_all_args
|
||||||
|
|
||||||
|
def test_reserved_provider_rejected(self):
    """Every name in ``RESERVED_PROVIDERS`` must be refused as a provider."""
    for reserved_name in RESERVED_PROVIDERS:
        rejection = pytest.raises(AssertionError)
        with rejection:

            @_custom_add.register_impl(reserved_name)
            def bad_impl(x, y):
                return x + y
|
||||||
|
|
||||||
|
def test_set_priority_scoped(self):
|
||||||
|
assert _custom_add.get_priority() == []
|
||||||
|
|
||||||
|
with _custom_add.set_priority(["impl_even", "impl_b"]):
|
||||||
|
assert _custom_add.get_priority() == ["impl_even", "impl_b"]
|
||||||
|
|
||||||
|
# Check nesting
|
||||||
|
with _custom_add.set_priority(["impl_b"]):
|
||||||
|
assert _custom_add.get_priority() == ["impl_b"]
|
||||||
|
|
||||||
|
# Restored
|
||||||
|
assert _custom_add.get_priority() == ["impl_even", "impl_b"]
|
||||||
|
|
||||||
|
# Check that exception restores priority
|
||||||
|
with pytest.raises(CustomError), _custom_add.set_priority(["impl_a"]):
|
||||||
|
assert _custom_add.get_priority() == ["impl_a"]
|
||||||
|
raise CustomError
|
||||||
|
|
||||||
|
# Restored again
|
||||||
|
assert _custom_add.get_priority() == ["impl_even", "impl_b"]
|
||||||
|
|
||||||
|
# Restored to empty
|
||||||
|
assert _custom_add.get_priority() == []
|
||||||
|
|
||||||
|
def test_dispatch_priority_order(self):
|
||||||
|
x = torch.tensor(1, dtype=torch.int32)
|
||||||
|
y = torch.tensor(2, dtype=torch.int32)
|
||||||
|
|
||||||
|
with _custom_add.set_priority(["impl_b", "impl_a"]):
|
||||||
|
assert _custom_add.dispatch(x, y) is impl_b
|
||||||
|
out1 = _custom_add(x, y)
|
||||||
|
out2 = torch.ops.vllm_ir._custom_add(x, y)
|
||||||
|
|
||||||
|
with _custom_add.set_priority(["impl_a"]):
|
||||||
|
assert _custom_add.dispatch(x, y) is impl_a
|
||||||
|
out3 = _custom_add(x, y)
|
||||||
|
out4 = torch.ops.vllm_ir._custom_add(x, y)
|
||||||
|
|
||||||
|
# impl_b
|
||||||
|
assert out1.item() == 1 + 2 + 20
|
||||||
|
assert out2.item() == 1 + 2 + 20
|
||||||
|
# impl_a
|
||||||
|
assert out3.item() == 1 + 2 + 10
|
||||||
|
assert out4.item() == 1 + 2 + 10
|
||||||
|
|
||||||
|
def test_unsupported_impl_filtered(self):
|
||||||
|
@_custom_add.register_impl("unsupported", supported=False)
|
||||||
|
def impl_bad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
return x + y + 999
|
||||||
|
|
||||||
|
x = torch.tensor(1, dtype=torch.int32)
|
||||||
|
y = torch.tensor(2, dtype=torch.int32)
|
||||||
|
|
||||||
|
with _custom_add.set_priority(["unsupported", "impl_a"]):
|
||||||
|
assert _custom_add.get_priority() == ["impl_a"]
|
||||||
|
out = _custom_add(x, y)
|
||||||
|
|
||||||
|
# impl_bad skipped → impl_a
|
||||||
|
assert out.item() == 1 + 2 + 10
|
||||||
|
|
||||||
|
def test_supports_args_runtime_dispatch_and_warning(
|
||||||
|
self, caplog_vllm: pytest.LogCaptureFixture
|
||||||
|
):
|
||||||
|
x1 = torch.ones((2, 2), dtype=torch.int32)
|
||||||
|
y1 = torch.full((2, 2), 2, dtype=torch.int32)
|
||||||
|
|
||||||
|
x2 = torch.ones((2, 3), dtype=torch.int32)
|
||||||
|
y2 = torch.full((2, 3), 2, dtype=torch.int32)
|
||||||
|
|
||||||
|
with (
|
||||||
|
caplog_vllm.at_level(logging.WARNING),
|
||||||
|
_custom_add.set_priority(["impl_even"]),
|
||||||
|
):
|
||||||
|
# Test the warning about native fallback is logged (before even dispatching)
|
||||||
|
assert len(caplog_vllm.records) == 1
|
||||||
|
message = caplog_vllm.records[0].message
|
||||||
|
assert "_custom_add" in message
|
||||||
|
assert "fallback to native" in message
|
||||||
|
assert "priority" in message
|
||||||
|
|
||||||
|
# Check dispatching
|
||||||
|
assert _custom_add.get_priority() == ["impl_even", "native"]
|
||||||
|
assert _custom_add.dispatch(x1, y1) is impl_even
|
||||||
|
assert _custom_add.dispatch(x2, y2) is _custom_add.impls["native"]
|
||||||
|
|
||||||
|
out1 = _custom_add(x1, y1) # size(1) == 2 → impl_even
|
||||||
|
out2 = _custom_add(x2, y2) # size(1) == 3 → native fallback
|
||||||
|
|
||||||
|
# no other warnings
|
||||||
|
assert len(caplog_vllm.records) == 1
|
||||||
|
assert torch.all(out1 == 1 + 2 + 50)
|
||||||
|
assert torch.all(out2 == 1 + 2)
|
||||||
|
|
||||||
|
def test_default_priority(
|
||||||
|
self, caplog_vllm: pytest.LogCaptureFixture, disable_log_dedup
|
||||||
|
):
|
||||||
|
# Make sure logs are not deduplicated to properly test the warning
|
||||||
|
x = torch.tensor([3], dtype=torch.int32)
|
||||||
|
y = torch.tensor([4], dtype=torch.int32)
|
||||||
|
|
||||||
|
# No priority set → falls back to native
|
||||||
|
assert _custom_add.get_priority() == []
|
||||||
|
with caplog_vllm.at_level(logging.WARNING):
|
||||||
|
# Native by default
|
||||||
|
assert _custom_add.dispatch(x, y) is _custom_add.impls["native"]
|
||||||
|
out = _custom_add(x, y)
|
||||||
|
|
||||||
|
# Check dispatching to native by default
|
||||||
|
assert out.item() == 3 + 4
|
||||||
|
|
||||||
|
# Check warning
|
||||||
|
assert len(caplog_vllm.records) == 2
|
||||||
|
message = caplog_vllm.records[0].message.lower()
|
||||||
|
assert "_custom_add" in message
|
||||||
|
assert "priority not set" in message
|
||||||
|
|
||||||
|
|
||||||
|
@vllm.ir.register_op
def _custom_mm(
    x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
    """Native ``x @ y`` with an optional additive ``bias`` (defaults to ``None``)."""
    product = x @ y
    if bias is None:
        return product
    return product + bias
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_args():
    """Default argument values are applied when checking support and dispatching."""

    @_custom_mm.register_impl("impl_mm", supports_args=lambda x, y, bias=None: True)
    def impl_mm(
        x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        product = x @ y
        if bias is None:
            return product + 50
        return product + bias + 100

    lhs = torch.tensor([1, 2], dtype=torch.int32)
    rhs = torch.tensor([3, 4], dtype=torch.int32)

    # supports_args is called without bias, so the default has to be filled in
    assert impl_mm.supports_args(lhs, rhs)
    with _custom_mm.set_priority(["impl_mm", "native"]):
        chosen = _custom_mm.dispatch(lhs, rhs)
        assert chosen is impl_mm
|
||||||
|
|
||||||
|
|
||||||
|
def test_bad_impl_registrations():
    """Every malformed registration raises and leaves the impl registry unchanged.

    Covers impl signatures that do not match the native schema and
    ``supports_args`` values that are not callables or whose signature does
    not match the native one.
    """
    # Snapshot the registry instead of hard-coding {"impl_mm", "native"}:
    # "impl_mm" is only present if test_default_args already ran in this
    # process, so a fixed set makes this test fail when it is selected alone.
    impls_before = set(_custom_mm.impls.keys())

    # Check bad schema: missing the `bias` parameter
    with pytest.raises(ValueError, match="does not match native schema"):

        @_custom_mm.register_impl("impl_mm_bad_schema")
        def impl_mm_bad_schema(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x @ y - 1

    # Check bad schema: third parameter named `b` instead of `bias`
    with pytest.raises(ValueError, match="does not match native schema"):

        @_custom_mm.register_impl("impl_mm_bad_schema_2")
        def impl_mm_bad_schema_2(
            x: torch.Tensor, y: torch.Tensor, b: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + b - 2

    # Check bad schema: `bias` is required and not optional
    with pytest.raises(ValueError, match="does not match native schema"):

        @_custom_mm.register_impl("impl_mm_bad_schema_3")
        def impl_mm_bad_schema_3(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor
        ) -> torch.Tensor:
            return x @ y + bias - 5

    # check supports_args with incorrect params: not a callable at all
    with pytest.raises(ValueError, match="supports_args must be a callable"):

        @_custom_mm.register_impl("impl_mm_bad_supports_args", supports_args=True)
        def impl_mm_bad_supports_args(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 10

    # supports_args takes too few parameters
    with pytest.raises(ValueError, match="number of parameters"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_2", supports_args=lambda x, y: True
        )
        def impl_mm_bad_supports_args_2(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 10

    # supports_args declares a keyword-only parameter
    with pytest.raises(ValueError, match="keyword-only parameters"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_3", supports_args=lambda x, y, *, b: True
        )
        def impl_mm_bad_supports_args_3(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 20

    # supports_args uses a different parameter name than the native op
    with pytest.raises(ValueError, match="does not match native parameter"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_4", supports_args=lambda x, y, b: True
        )
        def impl_mm_bad_supports_args_4(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 30

    # supports_args uses a different default value than the native op
    with pytest.raises(ValueError, match="does not match native default"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_5", supports_args=lambda x, y, bias=1: True
        )
        def impl_mm_bad_supports_args_5(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 40

    # None of the rejected registrations may have been added
    assert set(_custom_mm.impls.keys()) == impls_before
|
||||||
|
|
||||||
|
|
||||||
|
IMPL_OOT_SRC = """
|
||||||
|
import torch
|
||||||
|
|
||||||
|
@_custom_mm.register_impl("impl_mm_oot")
|
||||||
|
def impl_mm_oot(
|
||||||
|
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return x @ y - 99
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def load_custom_mm_module(file_path: Path):
    """Execute the Python file at ``file_path`` as module ``_custom_mm_oot``.

    ``_custom_mm`` is placed in the module's namespace before execution so
    that a ``@_custom_mm.register_impl`` decorator in the file resolves.
    Returns the executed module.
    """
    module_spec = importlib.util.spec_from_file_location("_custom_mm_oot", file_path)
    assert module_spec is not None
    oot_module = importlib.util.module_from_spec(module_spec)

    # Make the op visible inside the module before its body runs
    setattr(oot_module, "_custom_mm", _custom_mm)

    # Running the module body is what triggers the registration decorator
    loader = module_spec.loader
    assert loader is not None
    loader.exec_module(oot_module)
    return oot_module
|
||||||
|
|
||||||
|
|
||||||
|
def test_uuid_and_oot(tmp_path: Path):
    """The uuid of an out-of-tree impl changes with, and only with, its source file.

    The same impl is loaded from a temporary file three times: original
    source, modified source, original source again. The first and third uuid
    must match and the second must differ from both.
    """
    file_path = tmp_path / "_custom_mm_oot.py"

    def register_and_get_uuid(source: str):
        # Write `source`, load the file so the decorator registers
        # "impl_mm_oot", and return that impl's uuid. The impl is unregistered
        # in `finally` so a failing assertion cannot leak it into the shared
        # registry used by other tests.
        file_path.write_text(source)
        assert "impl_mm_oot" not in _custom_mm.impls
        try:
            _ = load_custom_mm_module(file_path)
            assert "impl_mm_oot" in _custom_mm.impls
            return _custom_mm.impls["impl_mm_oot"].uuid()
        finally:
            if "impl_mm_oot" in _custom_mm.impls:
                del _custom_mm.impls["impl_mm_oot"]

    uuid = register_and_get_uuid(IMPL_OOT_SRC)

    # Replace file source
    uuid1 = register_and_get_uuid(IMPL_OOT_SRC + " # added file source")
    assert uuid1 != uuid

    # Back to original
    uuid2 = register_and_get_uuid(IMPL_OOT_SRC)
    assert uuid2 == uuid
    assert uuid2 != uuid1
|
||||||
129
tests/kernels/ir/test_layernorm.py
Normal file
129
tests/kernels/ir/test_layernorm.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# This registers op implementations
|
||||||
|
import vllm.kernels # noqa: F401
|
||||||
|
from tests.kernels.allclose_default import get_default_rtol
|
||||||
|
from vllm import ir
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
|
def rms_norm_inputs(n_tokens: int, hidden_size: int, dtype: torch.dtype):
    """Return random rms_norm test inputs of the given ``dtype``.

    Returns:
        A pair ``(x, weight)``: ``x`` of shape ``(n_tokens, hidden_size)``
        drawn with ``torch.randn`` and ``weight`` of shape ``(hidden_size,)``
        drawn with ``torch.rand``.
    """
    activations = torch.randn((n_tokens, hidden_size), dtype=dtype)
    scale = torch.rand((hidden_size,), dtype=dtype)
    return activations, scale
|
||||||
|
|
||||||
|
|
||||||
|
# Callable behind the "native" provider of the rms_norm IR op; the tests below
# use it as the reference that other providers are compared against.
rms_norm_native = ir.ops.rms_norm.impls["native"].impl_fn
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not (current_platform.is_cuda_alike() or current_platform.is_xpu()),
    reason="Currently only kernels on CUDA, ROCm and XPU",
)
def test_rms_norm_registration():
    """Each rms_norm provider reports ``supported`` as expected for this platform."""
    expected_support = dict(
        native=True,
        vllm_c=current_platform.is_cuda_alike(),
        aiter=current_platform.is_rocm(),
        oink=False,
        xpu_kernels=current_platform.is_xpu(),
    )

    reported_support = {}
    for provider, impl in ir.ops.rms_norm.impls.items():
        reported_support[provider] = impl.supported

    assert reported_support == expected_support
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("n_tokens", [1, 8, 17])
@pytest.mark.parametrize("hidden_size", [16, 4096, 8192])
@pytest.mark.parametrize("epsilon", [1e-6, 1e-5])
@pytest.mark.skipif(
    not current_platform.is_cuda_alike() and not current_platform.is_xpu(),
    reason="Currently only kernels on CUDA, ROCm and XPU",
)
class TestRMSNorm:
    """Tests of the rms_norm IR op and its provider implementations.

    Every test method receives the class-level parameters ``dtype``,
    ``n_tokens``, ``hidden_size`` and ``epsilon``.
    """

    @classmethod
    def setup_class(cls, **kwargs):
        # Tensors created by the tests land on the current platform's device.
        # NOTE(review): the default device is not reset afterwards — confirm
        # that no later test module relies on the previous default.
        torch.set_default_device(current_platform.device_type)

    def test_native_semantics(self, dtype, n_tokens, hidden_size, epsilon):
        """Basic properties of the native implementation."""
        # Use the parametrized shape. A hard-coded (4, 8) input would make
        # every (n_tokens, hidden_size) combination repeat the same check.
        x, weight = rms_norm_inputs(n_tokens, hidden_size, dtype)
        out = rms_norm_native(x, weight, epsilon=epsilon)

        # Check shape, dtype, device
        assert out.shape == x.shape
        assert out.dtype == x.dtype
        assert out.device == x.device

        # Check the scaling property of rms norm: scaling the input leaves the
        # output (approximately) unchanged
        out2 = rms_norm_native(x * 2.0, weight, epsilon=epsilon)
        torch.testing.assert_close(out2, out, rtol=get_default_rtol(out), atol=1e-3)

        # Check behavior with and without weight: an all-ones weight must give
        # the same result as passing no weight
        weight1 = torch.ones_like(weight)
        out3 = rms_norm_native(x, weight1, epsilon=epsilon)
        out4 = rms_norm_native(x, None, epsilon=epsilon)
        torch.testing.assert_close(out3, out4)

    @pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels"])
    def test_impls(self, dtype, n_tokens, hidden_size, epsilon, provider):
        """A provider's kernel matches native, directly and through dispatch."""
        impl = ir.ops.rms_norm.impls[provider]
        if not impl.supported:
            pytest.skip(f"{provider} impl not supported on this platform")

        x, weight = rms_norm_inputs(n_tokens, hidden_size, dtype)
        # Fourth positional argument is variance_size (see the keyword use below)
        args = (x, weight, epsilon, None)

        assert impl.supported

        # aiter is expected to reject anything but fp16/bf16 inputs
        if provider == "aiter" and dtype not in [torch.float16, torch.bfloat16]:
            assert not impl.supports_args(*args)
            return

        assert impl.supports_args(*args)

        out_impl = impl.impl_fn(*args)
        out_native = rms_norm_native(*args)

        torch.testing.assert_close(
            out_impl, out_native, rtol=get_default_rtol(out_impl), atol=1e-3
        )

        # check that dispatched call matches direct call
        with ir.ops.rms_norm.set_priority([provider, "native"]):
            out_impl2 = ir.ops.rms_norm(*args)

        # exact match
        torch.testing.assert_close(out_impl2, out_impl, rtol=0.0, atol=0.0)

        # none of these support variance_size override
        assert not impl.supports_args(x, weight, epsilon, 4)
        assert not impl.supports_args(x, weight, epsilon, variance_size=4)

        # test weight=None behavior
        out_impl_no_weight = impl.impl_fn(x, None, epsilon)
        out_impl_unit_weight = impl.impl_fn(x, torch.ones_like(weight), epsilon)
        torch.testing.assert_close(
            out_impl_no_weight,
            out_impl_unit_weight,
            rtol=get_default_rtol(out_impl_no_weight),
            atol=2e-4,
        )

    @pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels", "native"])
    def test_torch_opcheck(self, dtype, n_tokens, hidden_size, epsilon, provider):
        """Run ``torch.library.opcheck`` on the torch op with ``provider`` first."""
        if not ir.ops.rms_norm.impls[provider].supported:
            pytest.skip(f"{provider} impl not supported on this platform")

        x, weight = rms_norm_inputs(n_tokens, hidden_size, dtype)
        args = (x, weight, epsilon, None)

        # When checking the torch op, we have to set priority and use dispatch
        with ir.ops.rms_norm.set_priority([provider, "native"]):
            torch.library.opcheck(torch.ops.vllm_ir.rms_norm, args)
|
||||||
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.layernorm import (
|
|||||||
RMSNorm,
|
RMSNorm,
|
||||||
dispatch_rocm_rmsnorm_func,
|
dispatch_rocm_rmsnorm_func,
|
||||||
fused_add_rms_norm,
|
fused_add_rms_norm,
|
||||||
rms_norm,
|
|
||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
@@ -156,7 +155,7 @@ def test_topk_sigmoid_dispatch(use_rocm_aiter: bool):
|
|||||||
assert topk_func == vllm_topk_sigmoid
|
assert topk_func == vllm_topk_sigmoid
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("add_residual", [True, False])
|
@pytest.mark.parametrize("add_residual", [False])
|
||||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
|
||||||
@pytest.mark.parametrize("use_rocm_aiter", [True, False])
|
@pytest.mark.parametrize("use_rocm_aiter", [True, False])
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
@@ -165,7 +164,7 @@ def test_topk_sigmoid_dispatch(use_rocm_aiter: bool):
|
|||||||
def test_rms_norm_dispatch(
|
def test_rms_norm_dispatch(
|
||||||
add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool
|
add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool
|
||||||
):
|
):
|
||||||
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter)
|
rms_norm_func = dispatch_rocm_rmsnorm_func(dtype, use_rocm_aiter)
|
||||||
|
|
||||||
should_use_rocm_aiter = (
|
should_use_rocm_aiter = (
|
||||||
current_platform.is_rocm()
|
current_platform.is_rocm()
|
||||||
@@ -173,11 +172,7 @@ def test_rms_norm_dispatch(
|
|||||||
and dtype in RMS_NORM_SUPPORTED_DTYPES
|
and dtype in RMS_NORM_SUPPORTED_DTYPES
|
||||||
)
|
)
|
||||||
|
|
||||||
if add_residual and should_use_rocm_aiter:
|
if should_use_rocm_aiter:
|
||||||
assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add
|
assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add
|
||||||
elif should_use_rocm_aiter:
|
|
||||||
assert rms_norm_func == rocm_aiter_ops.rms_norm
|
|
||||||
elif add_residual:
|
|
||||||
assert rms_norm_func == fused_add_rms_norm
|
|
||||||
else:
|
else:
|
||||||
assert rms_norm_func == rms_norm
|
assert rms_norm_func == fused_add_rms_norm
|
||||||
|
|||||||
@@ -6,12 +6,14 @@ import os
|
|||||||
from dataclasses import MISSING, Field, asdict, dataclass, field
|
from dataclasses import MISSING, Field, asdict, dataclass, field
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pydantic
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from vllm.compilation.backends import VllmBackend
|
from vllm.compilation.backends import VllmBackend
|
||||||
from vllm.config import (
|
from vllm.config import (
|
||||||
CompilationConfig,
|
CompilationConfig,
|
||||||
|
KernelConfig,
|
||||||
ModelConfig,
|
ModelConfig,
|
||||||
ParallelConfig,
|
ParallelConfig,
|
||||||
PoolerConfig,
|
PoolerConfig,
|
||||||
@@ -21,6 +23,7 @@ from vllm.config import (
|
|||||||
update_config,
|
update_config,
|
||||||
)
|
)
|
||||||
from vllm.config.compilation import CompilationMode, CUDAGraphMode
|
from vllm.config.compilation import CompilationMode, CUDAGraphMode
|
||||||
|
from vllm.config.kernel import IrOpPriorityConfig
|
||||||
from vllm.config.load import LoadConfig
|
from vllm.config.load import LoadConfig
|
||||||
from vllm.config.utils import get_field
|
from vllm.config.utils import get_field
|
||||||
from vllm.config.vllm import (
|
from vllm.config.vllm import (
|
||||||
@@ -1077,6 +1080,39 @@ def test_vllm_config_explicit_overrides():
|
|||||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
|
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
|
||||||
|
|
||||||
|
|
||||||
|
def test_fusion_pass_op_priority():
    """Custom op enablement and IR op priority both control the default
    value of the rms_norm+quant fusion flag (``fuse_norm_quant``)."""

    # Default config, O2, rms_norm+quant fusion disabled
    default_cfg = VllmConfig()
    default_passes = default_cfg.compilation_config.pass_config
    assert not default_passes.fuse_norm_quant

    # rms_norm manually enabled, O1, rms_norm+quant fusion enabled
    custom_op_compilation = CompilationConfig(custom_ops=["+rms_norm"])
    custom_op_cfg = VllmConfig(
        optimization_level=OptimizationLevel.O1,
        compilation_config=custom_op_compilation,
    )
    assert custom_op_cfg.compilation_config.pass_config.fuse_norm_quant

    # using custom kernel for RMSNorm via IR:
    # Note that vLLM IR only supports the non-residual rms_norm for now;
    # soon this will be resolved.
    ir_priority = IrOpPriorityConfig(rms_norm=["vllm_c"])
    ir_cfg = VllmConfig(kernel_config=KernelConfig(ir_op_priority=ir_priority))
    assert ir_cfg.compilation_config.pass_config.fuse_norm_quant

    # block-fp8 model should enable quant_fp8 automatically
    fp8_cfg = VllmConfig(model_config=ModelConfig("Qwen/Qwen3-4B-FP8"))
    fp8_compilation = fp8_cfg.compilation_config
    assert "+quant_fp8" in fp8_compilation.custom_ops
    assert fp8_compilation.pass_config.fuse_norm_quant
|
||||||
|
|
||||||
|
|
||||||
def test_scheduler_config_init():
|
def test_scheduler_config_init():
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
# Positional InitVars missing
|
# Positional InitVars missing
|
||||||
@@ -1171,3 +1207,35 @@ def test_eagle_draft_model_config():
|
|||||||
assert draft_model_config.hf_text_config.model_type == "eagle"
|
assert draft_model_config.hf_text_config.model_type == "eagle"
|
||||||
assert draft_model_config.architectures == ["EagleLlamaForCausalLM"]
|
assert draft_model_config.architectures == ["EagleLlamaForCausalLM"]
|
||||||
assert draft_model_config.architecture == "EagleLlamaForCausalLM"
|
assert draft_model_config.architecture == "EagleLlamaForCausalLM"
|
||||||
|
|
||||||
|
|
||||||
|
def test_ir_op_priority_default():
    """``IrOpPriorityConfig.with_default`` fills ops that have no explicit priority."""
    from vllm.config.kernel import IrOpPriorityConfig

    # The default list is applied to an op that was not given explicitly
    defaulted = IrOpPriorityConfig.with_default(["vllm_c", "native"])
    assert defaulted.rms_norm == ["vllm_c", "native"]

    # A priority passed for a single op wins over the default
    overridden = IrOpPriorityConfig.with_default(
        ["vllm_c", "native"], rms_norm=["oink", "native"]
    )
    explicit = IrOpPriorityConfig(rms_norm=["oink", "native"])
    assert overridden == explicit
|
||||||
|
|
||||||
|
|
||||||
|
def test_ir_op_priority_str():
    """A comma-delimited string is accepted in place of a list of providers."""
    from vllm.config.kernel import IrOpPriorityConfig

    # raw string → expected parsed priority (surrounding whitespace is stripped)
    string_cases = (
        ("vllm_c", ["vllm_c"]),
        ("vllm_c,native", ["vllm_c", "native"]),
        (" native, vllm_c ", ["native", "vllm_c"]),
    )
    for raw, parsed in string_cases:
        priority_config = IrOpPriorityConfig(rms_norm=raw)
        assert priority_config.rms_norm == parsed

    with pytest.raises(pydantic.ValidationError):
        # must be list of only strings
        IrOpPriorityConfig(rms_norm=["vllm_c", 4, "native"])
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
from importlib.util import find_spec
|
from importlib.util import find_spec
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._inductor.pattern_matcher as pm
|
import torch._inductor.pattern_matcher as pm
|
||||||
@@ -10,6 +11,7 @@ import torch.fx as fx
|
|||||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||||
|
|
||||||
|
import vllm.ir.ops
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.utils import Range
|
from vllm.config.utils import Range
|
||||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||||
@@ -28,7 +30,7 @@ from vllm.utils.torch_utils import (
|
|||||||
|
|
||||||
from ..inductor_pass import enable_fake_mode
|
from ..inductor_pass import enable_fake_mode
|
||||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8
|
||||||
|
|
||||||
FP8_DTYPE = current_platform.fp8_dtype()
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
|
|
||||||
@@ -258,6 +260,12 @@ class BasePattern:
|
|||||||
self.tp = get_tp_group()
|
self.tp = get_tp_group()
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
|
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
    """Return ``torch.empty(*args, **kwargs)`` with this object's dtype and device."""
    dtype, device = self.dtype, self.device
    return torch.empty(
        *args,
        dtype=dtype,
        device=device,
        **kwargs,
    )
|
||||||
|
|
||||||
|
def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
    """Return ``torch.empty(*args, **kwargs)`` as float32 on this object's device."""
    device = self.device
    return torch.empty(
        *args,
        dtype=torch.float32,
        device=device,
        **kwargs,
    )
|
||||||
|
|
||||||
|
|
||||||
class AllReduceRMSNormPattern(BasePattern):
|
class AllReduceRMSNormPattern(BasePattern):
|
||||||
"""
|
"""
|
||||||
@@ -276,20 +284,17 @@ class AllReduceRMSNormPattern(BasePattern):
|
|||||||
super().__init__(dtype, device)
|
super().__init__(dtype, device)
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
self.allreduce_params = allreduce_params
|
self.allreduce_params = allreduce_params
|
||||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
|
||||||
|
|
||||||
def get_inputs(self) -> list[torch.Tensor]:
|
def get_inputs(self) -> list[torch.Tensor]:
|
||||||
input, weight = self.rmsnorm_matcher.inputs()
|
# input, weight
|
||||||
|
return [self.empty(5, 16), self.empty(16)]
|
||||||
# input goes through allreduce first, always 16-bit
|
|
||||||
return [input.to(self.dtype), weight]
|
|
||||||
|
|
||||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||||
def pattern(
|
def pattern(
|
||||||
input: torch.Tensor, weight: torch.Tensor
|
input: torch.Tensor, weight: torch.Tensor
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||||
rms = self.rmsnorm_matcher(allreduce_output, weight)
|
rms = vllm.ir.ops.rms_norm(allreduce_output, weight, self.epsilon)
|
||||||
|
|
||||||
return rms, allreduce_output
|
return rms, allreduce_output
|
||||||
|
|
||||||
@@ -407,15 +412,13 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
|
|||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
self.allreduce_params = allreduce_params
|
self.allreduce_params = allreduce_params
|
||||||
self.quant_dtype = torch.float8_e4m3fn
|
self.quant_dtype = torch.float8_e4m3fn
|
||||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
|
||||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||||
|
|
||||||
def get_inputs(self) -> list[torch.Tensor]:
|
def get_inputs(self) -> list[torch.Tensor]:
|
||||||
input, weight = self.rmsnorm_matcher.inputs()
|
|
||||||
_, scale = self.quant_matcher.inputs()
|
_, scale = self.quant_matcher.inputs()
|
||||||
|
|
||||||
# input goes through allreduce first, always 16-bit
|
# input, weight
|
||||||
return [input.to(self.dtype), weight, scale]
|
return [self.empty(5, 16), self.empty(16), scale]
|
||||||
|
|
||||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||||
def pattern(
|
def pattern(
|
||||||
@@ -424,7 +427,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
|
|||||||
scale: torch.Tensor,
|
scale: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
all_reduce = tensor_model_parallel_all_reduce(input)
|
all_reduce = tensor_model_parallel_all_reduce(input)
|
||||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
|
||||||
quant, _ = self.quant_matcher(rms, scale)
|
quant, _ = self.quant_matcher(rms, scale)
|
||||||
return quant, all_reduce
|
return quant, all_reduce
|
||||||
|
|
||||||
@@ -553,7 +556,6 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
|||||||
super().__init__(dtype, device)
|
super().__init__(dtype, device)
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
self.allreduce_params = allreduce_params
|
self.allreduce_params = allreduce_params
|
||||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
|
||||||
|
|
||||||
def get_inputs(self) -> list[torch.Tensor]:
|
def get_inputs(self) -> list[torch.Tensor]:
|
||||||
input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
|
input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
|
||||||
@@ -575,7 +577,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
|||||||
output_scale: torch.Tensor,
|
output_scale: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
all_reduce = tensor_model_parallel_all_reduce(input)
|
all_reduce = tensor_model_parallel_all_reduce(input)
|
||||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
|
||||||
quant_out_tuple = auto_functionalized(
|
quant_out_tuple = auto_functionalized(
|
||||||
STATIC_FP4_QUANT_OP,
|
STATIC_FP4_QUANT_OP,
|
||||||
input=rms,
|
input=rms,
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
RMS_OP = torch.ops._C.rms_norm.default
|
|
||||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||||
ROTARY_OP = torch.ops._C.rotary_embedding.default
|
ROTARY_OP = torch.ops._C.rotary_embedding.default
|
||||||
FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default
|
FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default
|
||||||
@@ -160,69 +159,6 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class MatcherRMSNorm(MatcherCustomOp):
    """Matcher helper that emits an RMSNorm call in one of three forms.

    ``forward_custom`` emits either the ROCm AITER op or the functionalized
    ``RMS_OP`` custom op, and ``forward_native`` delegates to
    ``RMSNorm.forward_static``.
    """

    def __init__(
        self,
        epsilon: float,
        enabled: bool | None = None,
        match_rocm_aiter: bool = False,
    ) -> None:
        """Store the epsilon and choose the op that ``forward_custom`` calls.

        Args:
            epsilon: Epsilon forwarded to every RMSNorm variant.
            enabled: Passed to the base class; ``RMSNorm.enabled()`` is used
                when this is ``None``.
            match_rocm_aiter: When true, use the op returned by
                ``rocm_aiter_ops.get_rmsnorm_op()`` instead of ``RMS_OP``.
        """
        super().__init__(RMSNorm.enabled() if enabled is None else enabled)
        self.epsilon = epsilon
        self.match_rocm_aiter = match_rocm_aiter
        self._rmsnorm_op = (
            rocm_aiter_ops.get_rmsnorm_op() if match_rocm_aiter else RMS_OP
        )

    def inputs(self) -> list[torch.Tensor]:
        """Return example ``[input, weight]`` tensors for pattern tracing."""
        # The input comes from empty_f32 when the matcher is not enabled.
        make_input = self.empty if self.enabled else self.empty_f32
        return [make_input(5, 16), self.empty(16)]

    def forward_rocm_aiter(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Call the selected op with the AITER keyword names."""
        aiter_kwargs = {
            "x": input,
            "weight": weight,
            "variance_epsilon": self.epsilon,
        }
        return self._rmsnorm_op(**aiter_kwargs)

    def forward_custom(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Emit the custom-op form: AITER if requested, else functionalized."""
        if self.match_rocm_aiter:
            return self.forward_rocm_aiter(input, weight)

        # The op writes into ``result``; auto_functionalized returns the
        # written buffer as the second element of its output.
        out_buffer = torch.empty_like(input)
        _, normed = auto_functionalized(
            self._rmsnorm_op,
            result=out_buffer,
            input=input,
            weight=weight,
            epsilon=self.epsilon,
        )
        return normed

    def forward_native(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Emit the native form via ``RMSNorm.forward_static``."""
        hidden_size = input.size(-1)
        return RMSNorm.forward_static(
            input, self.epsilon, hidden_size, self.model_dtype, weight
        )
|
|
||||||
|
|
||||||
|
|
||||||
class MatcherFusedAddRMSNorm(MatcherCustomOp):
|
class MatcherFusedAddRMSNorm(MatcherCustomOp):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from torch import fx
|
|||||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||||
|
|
||||||
|
import vllm.ir.ops
|
||||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.attention import Attention
|
from vllm.model_executor.layers.attention import Attention
|
||||||
@@ -17,7 +18,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
|||||||
|
|
||||||
from ..inductor_pass import enable_fake_mode
|
from ..inductor_pass import enable_fake_mode
|
||||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||||
from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding
|
from .matcher_utils import MatcherRotaryEmbedding
|
||||||
from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64
|
from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@@ -64,7 +65,6 @@ class QkNormRopePattern:
|
|||||||
self.q_size = self.num_heads * self.head_dim
|
self.q_size = self.num_heads * self.head_dim
|
||||||
self.kv_size = self.num_kv_heads * self.head_dim
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.rmsnorm_matcher = MatcherRMSNorm(eps)
|
|
||||||
self.is_neox = is_neox
|
self.is_neox = is_neox
|
||||||
self.rope_flashinfer = rope_flashinfer
|
self.rope_flashinfer = rope_flashinfer
|
||||||
self.rope_matcher = MatcherRotaryEmbedding(
|
self.rope_matcher = MatcherRotaryEmbedding(
|
||||||
@@ -129,14 +129,14 @@ class QkNormRopePattern:
|
|||||||
q_by_head = q.view(
|
q_by_head = q.view(
|
||||||
*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
|
*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
|
||||||
)
|
)
|
||||||
q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight)
|
q_normed_by_head = vllm.ir.ops.rms_norm(q_by_head, q_weight, self.eps)
|
||||||
q_flat = q_normed_by_head.view(q.shape)
|
q_flat = q_normed_by_head.view(q.shape)
|
||||||
|
|
||||||
# K path: view -> RMS -> view back to k.shape
|
# K path: view -> RMS -> view back to k.shape
|
||||||
k_by_head = k.view(
|
k_by_head = k.view(
|
||||||
*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
|
*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
|
||||||
)
|
)
|
||||||
k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight)
|
k_normed_by_head = vllm.ir.ops.rms_norm(k_by_head, k_weight, self.eps)
|
||||||
k_flat = k_normed_by_head.view(k.shape)
|
k_flat = k_normed_by_head.view(k.shape)
|
||||||
|
|
||||||
# RoPE: apply to flattened q/k
|
# RoPE: apply to flattened q/k
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
|||||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||||
from torch._ops import OpOverload
|
from torch._ops import OpOverload
|
||||||
|
|
||||||
|
import vllm.ir.ops
|
||||||
from vllm.config import VllmConfig, get_current_vllm_config
|
from vllm.config import VllmConfig, get_current_vllm_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
@@ -30,7 +31,6 @@ from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
|||||||
from .matcher_utils import (
|
from .matcher_utils import (
|
||||||
MatcherFusedAddRMSNorm,
|
MatcherFusedAddRMSNorm,
|
||||||
MatcherQuantFP8,
|
MatcherQuantFP8,
|
||||||
MatcherRMSNorm,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@@ -54,7 +54,6 @@ def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor:
|
|||||||
return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")
|
return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")
|
||||||
|
|
||||||
|
|
||||||
RMS_OP = torch.ops._C.rms_norm.default
|
|
||||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||||
|
|
||||||
QUANT_OPS: dict[QuantKey, OpOverload] = {
|
QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||||
@@ -131,11 +130,9 @@ class RMSNormQuantPattern:
|
|||||||
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
|
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
|
||||||
self.FUSED_OP = FUSED_OPS[key]
|
self.FUSED_OP = FUSED_OPS[key]
|
||||||
|
|
||||||
self.rmsnorm_matcher = (
|
if key.fused_add:
|
||||||
MatcherRMSNorm(epsilon)
|
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||||
if not key.fused_add
|
|
||||||
else MatcherFusedAddRMSNorm(epsilon)
|
|
||||||
)
|
|
||||||
self.quant_matcher = MatcherQuantFP8(
|
self.quant_matcher = MatcherQuantFP8(
|
||||||
key.quant,
|
key.quant,
|
||||||
has_col_major_scales=has_col_major_scales,
|
has_col_major_scales=has_col_major_scales,
|
||||||
@@ -161,16 +158,12 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
|||||||
def pattern(
|
def pattern(
|
||||||
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
|
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
result_rms = self.rmsnorm_matcher(input, weight)
|
result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon)
|
||||||
return self.quant_matcher(result_rms, scale)[0]
|
return self.quant_matcher(result_rms, scale)[0]
|
||||||
|
|
||||||
def replacement(
|
def replacement(
|
||||||
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
|
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# In case we're matching native rms-norm, conversions might be
|
|
||||||
# optimized out. We convert here just to be safe.
|
|
||||||
input = input.to(dtype=self.model_dtype)
|
|
||||||
|
|
||||||
result = torch.empty(
|
result = torch.empty(
|
||||||
input.shape, device=input.device, dtype=self.quant_dtype
|
input.shape, device=input.device, dtype=self.quant_dtype
|
||||||
)
|
)
|
||||||
@@ -187,8 +180,8 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
|||||||
return at[1]
|
return at[1]
|
||||||
|
|
||||||
inputs = [
|
inputs = [
|
||||||
# input, weight
|
empty_bf16(5, 16), # input
|
||||||
*self.rmsnorm_matcher.inputs(),
|
empty_bf16(16), # weight
|
||||||
self.quant_matcher.inputs()[1], # scale
|
self.quant_matcher.inputs()[1], # scale
|
||||||
]
|
]
|
||||||
pattern(*inputs)
|
pattern(*inputs)
|
||||||
@@ -391,7 +384,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
|
|||||||
def pattern(
|
def pattern(
|
||||||
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
|
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
result_rms = self.rmsnorm_matcher(input, weight)
|
result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon)
|
||||||
result = torch.empty(
|
result = torch.empty(
|
||||||
result_rms.shape,
|
result_rms.shape,
|
||||||
device=result_rms.device,
|
device=result_rms.device,
|
||||||
@@ -442,12 +435,14 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
|
|||||||
# result, scale
|
# result, scale
|
||||||
return at[1], at[2]
|
return at[1], at[2]
|
||||||
|
|
||||||
scale = self.quant_matcher.empty_f32(1, 1)
|
|
||||||
|
|
||||||
pm.register_replacement(
|
pm.register_replacement(
|
||||||
pattern,
|
pattern,
|
||||||
replacement,
|
replacement,
|
||||||
self.rmsnorm_matcher.inputs() + [scale],
|
[
|
||||||
|
empty_bf16(5, 16), # input
|
||||||
|
empty_bf16(16), # weight
|
||||||
|
self.quant_matcher.empty_f32(1, 1), # scale
|
||||||
|
],
|
||||||
pm.fwd_only,
|
pm.fwd_only,
|
||||||
pm_pass,
|
pm_pass,
|
||||||
)
|
)
|
||||||
@@ -472,7 +467,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
|||||||
def pattern(
|
def pattern(
|
||||||
input: torch.Tensor, weight: torch.Tensor
|
input: torch.Tensor, weight: torch.Tensor
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
result_rms = self.rmsnorm_matcher(input, weight)
|
result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon)
|
||||||
# result, scale
|
# result, scale
|
||||||
return self.quant_matcher(result_rms) # type: ignore[no-any-return]
|
return self.quant_matcher(result_rms) # type: ignore[no-any-return]
|
||||||
|
|
||||||
@@ -502,7 +497,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
|||||||
pm.register_replacement(
|
pm.register_replacement(
|
||||||
pattern,
|
pattern,
|
||||||
replacement,
|
replacement,
|
||||||
self.rmsnorm_matcher.inputs(),
|
[
|
||||||
|
empty_bf16(5, 16), # input
|
||||||
|
empty_bf16(16), # weight
|
||||||
|
],
|
||||||
pm.fwd_only,
|
pm.fwd_only,
|
||||||
pm_pass,
|
pm_pass,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._inductor.pattern_matcher as pm
|
import torch._inductor.pattern_matcher as pm
|
||||||
@@ -24,7 +25,6 @@ from .act_quant_fusion import ActivationQuantPattern
|
|||||||
from .matcher_utils import (
|
from .matcher_utils import (
|
||||||
MatcherFusedAddRMSNorm,
|
MatcherFusedAddRMSNorm,
|
||||||
MatcherQuantFP8,
|
MatcherQuantFP8,
|
||||||
MatcherRMSNorm,
|
|
||||||
MatcherSiluAndMul,
|
MatcherSiluAndMul,
|
||||||
)
|
)
|
||||||
from .rms_quant_fusion import (
|
from .rms_quant_fusion import (
|
||||||
@@ -41,17 +41,23 @@ class AiterRMSNormQuantPattern:
|
|||||||
):
|
):
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
self.quant_dtype = key.quant.dtype
|
self.quant_dtype = key.quant.dtype
|
||||||
|
self.device = torch.device("cuda")
|
||||||
|
|
||||||
self.rmsnorm_matcher = (
|
if key.fused_add:
|
||||||
MatcherRMSNorm(epsilon, match_rocm_aiter=True)
|
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(
|
||||||
if not key.fused_add
|
epsilon, match_rocm_aiter=True
|
||||||
else MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)
|
)
|
||||||
)
|
|
||||||
self.quant_matcher = MatcherQuantFP8(
|
self.quant_matcher = MatcherQuantFP8(
|
||||||
key.quant,
|
key.quant,
|
||||||
match_rocm_aiter=match_aiter_quant,
|
match_rocm_aiter=match_aiter_quant,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||||
|
return torch.empty(*args, dtype=torch.bfloat16, device=self.device, **kwargs)
|
||||||
|
|
||||||
|
def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||||
|
return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
|
class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
|
||||||
"""AITER RMSNorm + Dynamic Quantization pattern."""
|
"""AITER RMSNorm + Dynamic Quantization pattern."""
|
||||||
@@ -79,7 +85,7 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
|
|||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
result_rms = self.rmsnorm_matcher(input, weight)
|
result_rms = torch.ops.vllm_ir.rms_norm(input, weight, self.epsilon)
|
||||||
result, scale = self.quant_matcher(result_rms)
|
result, scale = self.quant_matcher(result_rms)
|
||||||
return result, scale
|
return result, scale
|
||||||
|
|
||||||
@@ -99,7 +105,8 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
|
|||||||
pm.register_replacement(
|
pm.register_replacement(
|
||||||
pattern,
|
pattern,
|
||||||
replacement,
|
replacement,
|
||||||
self.rmsnorm_matcher.inputs(),
|
# input, weight
|
||||||
|
[self.empty(5, 16), self.empty(16)],
|
||||||
pm.fwd_only,
|
pm.fwd_only,
|
||||||
pm_pass,
|
pm_pass,
|
||||||
)
|
)
|
||||||
@@ -188,7 +195,7 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
|
|||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
result_rms = self.rmsnorm_matcher(input, weight)
|
result_rms = torch.ops.vllm_ir.rms_norm(input, weight, self.epsilon)
|
||||||
result, scale = self.quant_matcher(result_rms)
|
result, scale = self.quant_matcher(result_rms)
|
||||||
return result, scale
|
return result, scale
|
||||||
|
|
||||||
@@ -206,7 +213,12 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
|
|||||||
return at[0], at[1]
|
return at[0], at[1]
|
||||||
|
|
||||||
pm.register_replacement(
|
pm.register_replacement(
|
||||||
pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
|
pattern,
|
||||||
|
replacement,
|
||||||
|
# input, weight
|
||||||
|
[self.empty(5, 16), self.empty(16)],
|
||||||
|
pm.fwd_only,
|
||||||
|
pm_pass,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import torch._inductor.pattern_matcher as pm
|
|||||||
import torch.fx as fx
|
import torch.fx as fx
|
||||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||||
|
|
||||||
|
import vllm.ir.ops
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.utils import Range
|
from vllm.config.utils import Range
|
||||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||||
@@ -22,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
from ..inductor_pass import enable_fake_mode
|
from ..inductor_pass import enable_fake_mode
|
||||||
from ..utility.noop_elimination import NoOpEliminationPass
|
from ..utility.noop_elimination import NoOpEliminationPass
|
||||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -122,35 +123,38 @@ class _SequenceParallelPatternHelper:
|
|||||||
x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
|
x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||||
|
return torch.empty(*args, dtype=self.dtype, device=self.device, **kwargs)
|
||||||
|
|
||||||
|
def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||||
|
return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
|
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
|
||||||
super().__init__(epsilon, dtype, device)
|
super().__init__(epsilon, dtype, device)
|
||||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
|
||||||
|
|
||||||
def get_inputs(self) -> list[torch.Tensor]:
|
def get_inputs(self) -> list[torch.Tensor]:
|
||||||
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
# input, weight
|
||||||
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
|
return [self.empty([1, 8, 4]), self.empty([4])]
|
||||||
|
|
||||||
return [input, arg3_1]
|
|
||||||
|
|
||||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||||
def pattern(
|
def pattern(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
arg3_1: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
all_reduce = self._all_reduce(input)
|
all_reduce = self._all_reduce(input)
|
||||||
rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)
|
rmsnorm = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
|
||||||
|
|
||||||
return rmsnorm, all_reduce
|
return rmsnorm, all_reduce
|
||||||
|
|
||||||
def replacement(
|
def replacement(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
arg3_1: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
reduce_scatter = self._reduce_scatter(input)
|
reduce_scatter = self._reduce_scatter(input)
|
||||||
|
|
||||||
rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
|
rmsnorm = vllm.ir.ops.rms_norm(reduce_scatter, weight, self.epsilon)
|
||||||
all_gather = self._all_gather(rmsnorm)
|
all_gather = self._all_gather(rmsnorm)
|
||||||
return all_gather, reduce_scatter
|
return all_gather, reduce_scatter
|
||||||
|
|
||||||
@@ -222,14 +226,11 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
|||||||
device: str | None,
|
device: str | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(epsilon, dtype, device)
|
super().__init__(epsilon, dtype, device)
|
||||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
|
||||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||||
|
|
||||||
def get_inputs(self) -> list[torch.Tensor]:
|
def get_inputs(self) -> list[torch.Tensor]:
|
||||||
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
|
# input, weight, scale
|
||||||
weight = torch.empty([4], device=self.device, dtype=self.dtype)
|
return [self.empty([1, 8, 4]), self.empty([4]), self.empty_f32([1, 1])]
|
||||||
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
|
|
||||||
return [input, weight, scale]
|
|
||||||
|
|
||||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||||
def pattern(
|
def pattern(
|
||||||
@@ -238,7 +239,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
|||||||
scale: torch.Tensor,
|
scale: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
all_reduce = self._all_reduce(input)
|
all_reduce = self._all_reduce(input)
|
||||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
|
||||||
quant, _ = self.quant_matcher(rms, scale)
|
quant, _ = self.quant_matcher(rms, scale)
|
||||||
return quant, all_reduce
|
return quant, all_reduce
|
||||||
|
|
||||||
@@ -248,7 +249,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
|||||||
scale: torch.Tensor,
|
scale: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
reduce_scatter = self._reduce_scatter(input)
|
reduce_scatter = self._reduce_scatter(input)
|
||||||
rms = self.rmsnorm_matcher(reduce_scatter, weight)
|
rms = vllm.ir.ops.rms_norm(reduce_scatter, weight, self.epsilon)
|
||||||
quant, _ = self.quant_matcher(rms, scale)
|
quant, _ = self.quant_matcher(rms, scale)
|
||||||
all_gather = self._all_gather(quant)
|
all_gather = self._all_gather(quant)
|
||||||
|
|
||||||
|
|||||||
0
vllm/compilation/passes/ir/__init__.py
Normal file
0
vllm/compilation/passes/ir/__init__.py
Normal file
158
vllm/compilation/passes/ir/lowering_pass.py
Normal file
158
vllm/compilation/passes/ir/lowering_pass.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
from torch import fx
|
||||||
|
from torch._inductor.pattern_matcher import (
|
||||||
|
CallFunctionVarArgs,
|
||||||
|
Match,
|
||||||
|
PatternMatcherPass,
|
||||||
|
register_graph_pattern,
|
||||||
|
)
|
||||||
|
from torch._ops import OpOverload, OpOverloadPacket
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.ir.op import IrOp
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.logging_utils import lazy
|
||||||
|
|
||||||
|
from ..vllm_inductor_pass import VllmInductorPass
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_overload(op: OpOverload | OpOverloadPacket) -> OpOverload:
|
||||||
|
if isinstance(op, OpOverloadPacket):
|
||||||
|
return op.default
|
||||||
|
assert isinstance(op, OpOverload), "Expected an OpOverload or OpOverloadPacket"
|
||||||
|
return op
|
||||||
|
|
||||||
|
|
||||||
|
def get_ir_op(node: fx.Node) -> IrOp | None:
    """Return the registered ``IrOp`` targeted by ``node``, or ``None``.

    ``None`` is returned when ``node`` is not a ``call_function`` node, when its
    target is not a torch op overload (or overload packet), when the op lives
    outside the ``vllm_ir`` namespace, or when the op name is missing from
    ``IrOp.registry``. The last case also logs a warning.
    """
    targets_torch_op = node.op == "call_function" and isinstance(
        node.target, (OpOverload, OpOverloadPacket)
    )
    if not targets_torch_op:
        return None

    overload = get_default_overload(node.target)
    if overload.namespace != "vllm_ir":
        return None

    name = overload._opname
    if name in IrOp.registry:
        return IrOp.registry[name]

    # The op is in the vllm_ir namespace but the registry does not know it.
    logger.warning(
        "Unknown vLLM IR op %s, there's likely an issue with torch registration, "
        "or a torch custom op was registered in the vllm_ir namespace by mistake.",
        name,
    )
    return None
|
||||||
|
|
||||||
|
|
||||||
|
class VllmIRLoweringPass(VllmInductorPass):
    """
    This pass lowers vLLM IR ops to the implementations selected by dispatch.

    Every ``call_function`` node whose target is a registered vLLM IR op is
    replaced by tracing the implementation returned by ``IrOp.dispatch``.
    The provider chosen for each node is recorded in ``selected_impls``
    (IR op name -> node name -> provider) so that tests can inspect it.
    """

    def __init__(self, vllm_config: VllmConfig) -> None:
        """Register a single-node pattern covering every registered IR op."""
        super().__init__(vllm_config)
        self.patterns = PatternMatcherPass(self.pass_name)
        self.selected_impls: dict[str, dict[str, str]] = defaultdict(dict)
        self.ops = [ir_op.torch_op for ir_op in IrOp.registry.values()]

        # Look for any call_function node where the target is a vLLM IR op.
        # Then, lower_matched_op will select, trace, and insert the implementation.
        register_graph_pattern(
            CallFunctionVarArgs(self.ops),
            pass_dict=self.patterns,
        )(self.lower_matched_op)

    def lower_matched_op(self, match: Match, *args, **kwargs):
        """Replace one matched IR node with its dispatched implementation.

        ``args`` and ``kwargs`` are accepted because the pattern matcher passes
        them to the handler; the node itself is used instead.
        """
        # TODO(luka) I think args and kwargs are for the match, but just use the node?

        assert len(match.nodes) == 1, "Expected single node match"
        node = match.nodes[0]
        ir_op = get_ir_op(node)
        assert ir_op is not None, "Expected vLLM IR op"
        assert not node.kwargs  # I think there should never be kwargs here

        # Select and record the implementation, using fake args
        fake_args = fx.map_arg(node.args, lambda arg: arg.meta["val"])
        ir_op_impl = ir_op.dispatch(*fake_args)
        self.selected_impls[ir_op.name][node.name] = ir_op_impl.provider

        # replace_by_example wants node args, not the fake tensors
        # TODO(luka): Use aot_export_module to get functionalized graph
        # TODO(luka): Cache the fx_replacement to avoid re-tracing the same impl

        # Defaults not present on node.args but required for replacement tracing
        bound_args = ir_op._py_signature.bind(*node.args)
        bound_args.apply_defaults()
        match.replace_by_example(ir_op_impl.impl_fn, bound_args.args)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Lower all vLLM IR nodes in ``graph`` and log the selected providers.

        Any vLLM IR node still present afterwards is reported with a warning.
        """
        # clear at the beginning instead of end, so that tests can inspect
        self.selected_impls.clear()

        count = self.patterns.apply(graph)
        logger.debug("VllmIRLoweringPass lowered %d vLLM IR nodes", count)

        # TODO write self.selected_impls to depyf/tlparse dir
        def count_items(impls: Iterable[str]) -> dict[str, int]:
            counts: dict[str, int] = defaultdict(int)
            for impl in impls:
                counts[impl] += 1
            return counts

        def print_count(counts: dict[str, int]) -> str:
            # e.g., "impl1*3,impl2"
            return ",".join(
                f"{impl}" if num == 1 else f"{impl}*{num}"
                for impl, num in counts.items()
            )

        logger.debug(
            "Selected implementations: %s",
            lazy(
                lambda: ", ".join(
                    f"{op}={print_count(count_items(impls_by_node.values()))}"
                    for op, impls_by_node in self.selected_impls.items()
                )
            ),
        )

        failed_nodes: list[fx.Node] = []
        failed_ops: set[str] = set()
        # Check no vllm_ir nodes were left in the graph
        for node in graph.nodes:
            if (ir_op := get_ir_op(node)) is None:
                continue

            failed_nodes.append(node)
            failed_ops.add(ir_op.name)

        # failed_ops is only ever populated together with failed_nodes.
        if failed_nodes:
            # Sort the names: iterating a set of strings is not order-stable
            # between runs, which would make the message vary.
            logger.warning(
                "Failed to lower vLLM IR ops: %s", ",".join(sorted(failed_ops))
            )
            logger.warning("Full node list: %s", failed_nodes)

    def uuid(self) -> str:
        """
        IR op priority & impl sources affect lowering pass output,
        so we include them in the cache key.
        """
        priorities = {name: op.get_priority() for name, op in IrOp.registry.items()}
        priorities_str = ";".join(
            f"{name}={','.join(providers)}" for name, providers in priorities.items()
        )

        # Build each entry with a plain join: a replacement field spanning
        # several lines inside a single-quoted f-string only parses on
        # Python 3.12+.
        impl_uuid_entries: list[str] = []
        for name, providers in priorities.items():
            provider_uuids = ",".join(
                IrOp.registry[name].impls[provider].uuid() for provider in providers
            )
            impl_uuid_entries.append(f"{name}={provider_uuids}")
        impl_uuids_str = ";".join(impl_uuid_entries)

        return f"{super().uuid()}|{priorities_str}|{impl_uuids_str}"
|
||||||
@@ -14,6 +14,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.system_utils import set_env_var
|
from vllm.utils.system_utils import set_env_var
|
||||||
|
|
||||||
|
from .ir.lowering_pass import VllmIRLoweringPass
|
||||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||||
|
|
||||||
if rocm_aiter_ops.is_enabled():
|
if rocm_aiter_ops.is_enabled():
|
||||||
@@ -99,8 +100,17 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
|
|||||||
else:
|
else:
|
||||||
logger.debug("Skipping %s with compile range %s", pass_, compile_range)
|
logger.debug("Skipping %s with compile range %s", pass_, compile_range)
|
||||||
|
|
||||||
# post-cleanup goes before fix_functionalization
|
# perform the first post-cleanup before IR lowering to clean up fusion artifacts
|
||||||
# because it requires a functional graph
|
# and make sure no dead IR ops are lowered.
|
||||||
|
self.post_cleanup(graph)
|
||||||
|
VllmInductorPass.dump_prefix += 1
|
||||||
|
|
||||||
|
# lowering before cleanup so DCE can clean up lowered ops.
|
||||||
|
# DCE handles mutating ops correctly as well.
|
||||||
|
self.ir_lowering(graph)
|
||||||
|
VllmInductorPass.dump_prefix += 1
|
||||||
|
|
||||||
|
# clean up after lowering again
|
||||||
self.post_cleanup(graph)
|
self.post_cleanup(graph)
|
||||||
VllmInductorPass.dump_prefix += 1
|
VllmInductorPass.dump_prefix += 1
|
||||||
|
|
||||||
@@ -152,7 +162,7 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
|
|||||||
self.passes += [SplitCoalescingPass(config)]
|
self.passes += [SplitCoalescingPass(config)]
|
||||||
self.passes += [QKNormRoPEFusionPass(config)]
|
self.passes += [QKNormRoPEFusionPass(config)]
|
||||||
|
|
||||||
# needs a functional graph
|
self.ir_lowering = VllmIRLoweringPass(config)
|
||||||
self.post_cleanup = PostCleanupPass(config)
|
self.post_cleanup = PostCleanupPass(config)
|
||||||
self.fix_functionalization = FixFunctionalizationPass(config)
|
self.fix_functionalization = FixFunctionalizationPass(config)
|
||||||
|
|
||||||
@@ -171,6 +181,10 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
|
|||||||
state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}
|
state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}
|
||||||
for pass_ in self.passes:
|
for pass_ in self.passes:
|
||||||
passes.append(pass_.uuid())
|
passes.append(pass_.uuid())
|
||||||
|
|
||||||
|
passes.append(self.post_cleanup.uuid())
|
||||||
|
passes.append(self.ir_lowering.uuid())
|
||||||
|
passes.append(self.post_cleanup.uuid())
|
||||||
passes.append(self.fix_functionalization.uuid())
|
passes.append(self.fix_functionalization.uuid())
|
||||||
|
|
||||||
# Include the compile range in the uuid to ensure that inductor
|
# Include the compile range in the uuid to ensure that inductor
|
||||||
|
|||||||
@@ -152,6 +152,7 @@ class VllmPatternMatcherPass(VllmInductorPass):
|
|||||||
f"auto_functionalized as auto_functionalized\n"
|
f"auto_functionalized as auto_functionalized\n"
|
||||||
f"from torch._inductor.pattern_matcher import *\n"
|
f"from torch._inductor.pattern_matcher import *\n"
|
||||||
f"vllm = torch.ops.vllm",
|
f"vllm = torch.ops.vllm",
|
||||||
|
"vllm_ir = torch.ops.vllm_ir",
|
||||||
file=f,
|
file=f,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -466,6 +466,15 @@ class CompilationConfig:
|
|||||||
disabled when running with Inductor: mode>CompilationMode.NONE and
|
disabled when running with Inductor: mode>CompilationMode.NONE and
|
||||||
backend="inductor".
|
backend="inductor".
|
||||||
Inductor generates (fused) Triton kernels for disabled custom ops."""
|
Inductor generates (fused) Triton kernels for disabled custom ops."""
|
||||||
|
|
||||||
|
ir_enable_torch_wrap: bool = None # type: ignore[assignment]
|
||||||
|
"""If True, enable vllm_ir torch custom op wrapping during the forward pass.
|
||||||
|
When False, torch custom op wrapping is disabled, allowing Dynamo to trace the
|
||||||
|
selected implementation directly or avoiding torch custom op overhead in eager mode.
|
||||||
|
Defaults to True when using Inductor with vllm-compile
|
||||||
|
(backend=="inductor" and mode == VLLM_COMPILE), False otherwise.
|
||||||
|
"""
|
||||||
|
|
||||||
splitting_ops: list[str] | None = None
|
splitting_ops: list[str] | None = None
|
||||||
"""A list of ops to exclude from cudagraphs, used in piecewise compilation.
|
"""A list of ops to exclude from cudagraphs, used in piecewise compilation.
|
||||||
|
|
||||||
@@ -830,6 +839,7 @@ class CompilationConfig:
|
|||||||
"cudagraph_mode",
|
"cudagraph_mode",
|
||||||
"max_cudagraph_capture_size",
|
"max_cudagraph_capture_size",
|
||||||
"use_inductor_graph_partition",
|
"use_inductor_graph_partition",
|
||||||
|
"ir_enable_torch_wrap",
|
||||||
mode="wrap",
|
mode="wrap",
|
||||||
)
|
)
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -1,13 +1,106 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import contextlib
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any, Literal
|
from dataclasses import asdict, fields
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
from pydantic import field_validator
|
from pydantic import Field, field_validator
|
||||||
|
|
||||||
|
from vllm.config.utils import config, get_hash_factors, hash_factors
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@config
class IrOpPriorityConfig:
    """
    Configuration for vLLM IR op priority for dispatching/lowering during the
    forward pass. Each member is a list of strings, which will be passed to
    vllm.ir.ops.<op_name>.set_priority() for the duration of the forward pass.
    A single comma-separated string (e.g. "a,b") is accepted as well.

    If specified manually, platform defaults will be appended to the lists.
    See KernelConfig.set_platform_defaults().
    """

    rms_norm: list[str] = Field(default_factory=list)
    """Priority list for vllm.ir.ops.rms_norm"""

    def compute_hash(self) -> str:
        """
        Produces a hash unique to the IR op priority configuration.
        Any new fields that affect compilation should be added to the hash.
        Any future fields that don't affect compilation should be excluded.

        Also, manually add IR op impl UUIDs to make sure they affect the compile cache.
        """
        factors = get_hash_factors(self, set())

        # Implementations are hidden from Dynamo,
        # so they don't show up in the traced files list.
        from vllm.ir.op import IrOp

        # "_impls" is a synthetic key; make sure no real field collides with it.
        assert "_impls" not in factors
        factors["_impls"] = {
            name: {
                provider: IrOp.registry[name].impls[provider].uuid() for provider in p
            }
            for name, p in asdict(self).items()
        }

        return hash_factors(factors)

    @field_validator("*", mode="before")
    @classmethod
    def _to_list_str(cls, value: str | list[str]):
        """
        Normalize a priority value to a list of provider names.

        A string is split on commas after removing spaces; empty tokens are
        dropped, so "" yields [] (i.e. unset) and "a,,b" / "a," never produce
        an empty provider name. A list is returned unchanged.

        :raises ValueError: if any element is not a string.
        """
        if isinstance(value, str):
            value = [p for p in value.replace(" ", "").split(",") if p]

        # Raise instead of assert so the check survives `python -O`.
        if not all(isinstance(v, str) for v in value):
            raise ValueError(
                "IR op priority must be a list of strings or a "
                f"comma-separated string, got {value!r}"
            )
        return value

    @contextlib.contextmanager
    def set_priority(self):
        """
        Context manager to set the IR op priority for all op members.
        It also imports vllm.kernels to ensure all implementations are made available.
        """
        import vllm.kernels  # noqa: F401, registers IR op implementations
        from vllm.ir.op import IrOp

        # ExitStack unwinds every per-op priority override when the block exits.
        with contextlib.ExitStack() as stack:
            for field in fields(self):
                op_priority = getattr(self, field.name)
                assert op_priority is not None, (
                    f"IR op priority for {field.name} must be set"
                )
                logger.debug(
                    "Setting IR op priority for %s to %s", field.name, op_priority
                )
                ir_op = IrOp.registry[field.name]
                stack.enter_context(ir_op.set_priority(op_priority))

            yield

    @classmethod
    def with_default(
        cls, default: list[str], /, **kwargs: list[str]
    ) -> "IrOpPriorityConfig":
        """
        A helper to create an IrOpPriorityConfig where fields not specified in kwargs
        use the given default list.
        """
        for field in fields(cls):
            if field.name not in kwargs:
                # Copy so that fields never share one list object.
                kwargs[field.name] = list(default)

        return cls(**kwargs)
|
||||||
|
|
||||||
from vllm.config.utils import config
|
|
||||||
from vllm.utils.hashing import safe_hash
|
|
||||||
|
|
||||||
MoEBackend = Literal[
|
MoEBackend = Literal[
|
||||||
"auto",
|
"auto",
|
||||||
@@ -26,6 +119,12 @@ MoEBackend = Literal[
|
|||||||
class KernelConfig:
|
class KernelConfig:
|
||||||
"""Configuration for kernel selection and warmup behavior."""
|
"""Configuration for kernel selection and warmup behavior."""
|
||||||
|
|
||||||
|
ir_op_priority: IrOpPriorityConfig = Field(default_factory=IrOpPriorityConfig)
|
||||||
|
"""
|
||||||
|
vLLM IR op priority for dispatching/lowering during the forward pass.
|
||||||
|
Platform defaults appended automatically during VllmConfig.__post_init__.
|
||||||
|
"""
|
||||||
|
|
||||||
enable_flashinfer_autotune: bool = None # type: ignore[assignment]
|
enable_flashinfer_autotune: bool = None # type: ignore[assignment]
|
||||||
"""If True, run FlashInfer autotuning during kernel warmup."""
|
"""If True, run FlashInfer autotuning during kernel warmup."""
|
||||||
|
|
||||||
@@ -51,21 +150,17 @@ class KernelConfig:
|
|||||||
|
|
||||||
def compute_hash(self) -> str:
|
def compute_hash(self) -> str:
|
||||||
"""
|
"""
|
||||||
WARNING: Whenever a new field is added to this config,
|
Produces a hash unique to the pass configuration.
|
||||||
ensure that it is included in the factors list if
|
Any new fields that affect compilation should be added to the hash.
|
||||||
it affects the computation graph.
|
Any future fields that don't affect compilation should be excluded.
|
||||||
|
|
||||||
Provide a hash that uniquely identifies all the configs
|
|
||||||
that affect the structure of the computation
|
|
||||||
graph from input ids/embeddings to the final hidden states,
|
|
||||||
excluding anything before input ids/embeddings and after
|
|
||||||
the final hidden states.
|
|
||||||
"""
|
"""
|
||||||
# no factors to consider.
|
ignored_factors = {
|
||||||
# this config will not affect the computation graph.
|
"enable_flashinfer_autotune",
|
||||||
factors: list[Any] = []
|
"ir_op_priority", # handled separately below
|
||||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
}
|
||||||
return hash_str
|
factors = get_hash_factors(self, ignored_factors)
|
||||||
|
factors["ir_op_priority"] = self.ir_op_priority.compute_hash()
|
||||||
|
return hash_factors(factors)
|
||||||
|
|
||||||
@field_validator("enable_flashinfer_autotune", mode="wrap")
|
@field_validator("enable_flashinfer_autotune", mode="wrap")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -74,3 +169,31 @@ class KernelConfig:
|
|||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
return handler(value)
|
return handler(value)
|
||||||
|
|
||||||
|
def set_platform_defaults(self, vllm_config: "VllmConfig") -> None:
    """
    Merge the current platform's default IR op priorities into this config.

    For each op reported by the platform, an unset (None) priority is replaced
    by the platform list; otherwise the platform providers that are not
    already present are appended after the existing entries.
    """
    from vllm.platforms import current_platform

    defaults = current_platform.get_default_ir_op_priority(vllm_config)
    configured = self.ir_op_priority
    logger.debug(
        "Setting platform-specific IR op priority defaults: %s, user-defined: %s",
        defaults,
        configured,
    )

    for op_name, platform_providers in asdict(defaults).items():
        existing: list[str] = getattr(configured, op_name)
        if existing is None:
            setattr(configured, op_name, platform_providers)
            continue

        # Only append providers that are missing: this keeps the method
        # idempotent, since it may be called multiple times
        # (due to VllmConfig.__post_init__ manual call).
        missing = [
            provider for provider in platform_providers if provider not in existing
        ]
        existing.extend(missing)

    logger.info(
        "Final IR op priority after setting platform defaults: %s",
        configured,
    )
|
||||||
|
|||||||
@@ -95,9 +95,11 @@ def enable_norm_fusion(cfg: "VllmConfig") -> bool:
|
|||||||
"""Enable if either RMS norm or quant FP8 custom op is active;
|
"""Enable if either RMS norm or quant FP8 custom op is active;
|
||||||
otherwise Inductor handles fusion."""
|
otherwise Inductor handles fusion."""
|
||||||
|
|
||||||
return cfg.compilation_config.is_custom_op_enabled(
|
return (
|
||||||
"rms_norm"
|
cfg.compilation_config.is_custom_op_enabled("rms_norm")
|
||||||
) or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
|
or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
|
||||||
|
or cfg.kernel_config.ir_op_priority.rms_norm[0] != "native"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def enable_act_fusion(cfg: "VllmConfig") -> bool:
|
def enable_act_fusion(cfg: "VllmConfig") -> bool:
|
||||||
@@ -417,6 +419,10 @@ class VllmConfig:
|
|||||||
vllm_factors.append(self.compilation_config.compute_hash())
|
vllm_factors.append(self.compilation_config.compute_hash())
|
||||||
else:
|
else:
|
||||||
vllm_factors.append("None")
|
vllm_factors.append("None")
|
||||||
|
if self.kernel_config:
|
||||||
|
vllm_factors.append(self.kernel_config.compute_hash())
|
||||||
|
else:
|
||||||
|
vllm_factors.append(None)
|
||||||
if self.kv_transfer_config:
|
if self.kv_transfer_config:
|
||||||
vllm_factors.append(self.kv_transfer_config.compute_hash())
|
vllm_factors.append(self.kv_transfer_config.compute_hash())
|
||||||
else:
|
else:
|
||||||
@@ -890,6 +896,13 @@ class VllmConfig:
|
|||||||
else:
|
else:
|
||||||
self.compilation_config.mode = CompilationMode.NONE
|
self.compilation_config.mode = CompilationMode.NONE
|
||||||
|
|
||||||
|
# By default, enable torch wrapping only when using custom Inductor lowering
|
||||||
|
if self.compilation_config.ir_enable_torch_wrap is None:
|
||||||
|
self.compilation_config.ir_enable_torch_wrap = (
|
||||||
|
self.compilation_config.mode == CompilationMode.VLLM_COMPILE
|
||||||
|
and self.compilation_config.backend == "inductor"
|
||||||
|
)
|
||||||
|
|
||||||
if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
|
if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
|
||||||
if (
|
if (
|
||||||
self.compilation_config.backend == "inductor"
|
self.compilation_config.backend == "inductor"
|
||||||
@@ -899,6 +912,11 @@ class VllmConfig:
|
|||||||
else:
|
else:
|
||||||
self.compilation_config.custom_ops.append("all")
|
self.compilation_config.custom_ops.append("all")
|
||||||
|
|
||||||
|
# This populates IR op priorities,
|
||||||
|
# must happen after compilation mode and backend are decided,
|
||||||
|
# but before fusion defaults are applied as those may depend on op priority.
|
||||||
|
self.kernel_config.set_platform_defaults(self)
|
||||||
|
|
||||||
default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
|
default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
|
||||||
self._apply_optimization_level_defaults(default_config)
|
self._apply_optimization_level_defaults(default_config)
|
||||||
if self.kernel_config.enable_flashinfer_autotune is None:
|
if self.kernel_config.enable_flashinfer_autotune is None:
|
||||||
@@ -1706,7 +1724,8 @@ class VllmConfig:
|
|||||||
f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
|
f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
|
||||||
f"enable_chunked_prefill={self.scheduler_config.enable_chunked_prefill}, " # noqa
|
f"enable_chunked_prefill={self.scheduler_config.enable_chunked_prefill}, " # noqa
|
||||||
f"pooler_config={self.model_config.pooler_config!r}, "
|
f"pooler_config={self.model_config.pooler_config!r}, "
|
||||||
f"compilation_config={self.compilation_config!r}"
|
f"compilation_config={self.compilation_config!r}, "
|
||||||
|
f"kernel_config={self.kernel_config!r}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate_block_size(self) -> None:
|
def validate_block_size(self) -> None:
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import functools
|
|||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import MISSING, dataclass, fields, is_dataclass
|
from dataclasses import MISSING, asdict, dataclass, fields, is_dataclass
|
||||||
from itertools import permutations
|
from itertools import permutations
|
||||||
from types import UnionType
|
from types import UnionType
|
||||||
from typing import (
|
from typing import (
|
||||||
@@ -70,7 +70,7 @@ from vllm.config.cache import (
|
|||||||
PrefixCachingHashAlgo,
|
PrefixCachingHashAlgo,
|
||||||
)
|
)
|
||||||
from vllm.config.device import Device
|
from vllm.config.device import Device
|
||||||
from vllm.config.kernel import MoEBackend
|
from vllm.config.kernel import IrOpPriorityConfig, MoEBackend
|
||||||
from vllm.config.lora import MaxLoRARanks
|
from vllm.config.lora import MaxLoRARanks
|
||||||
from vllm.config.model import (
|
from vllm.config.model import (
|
||||||
ConvertOption,
|
ConvertOption,
|
||||||
@@ -401,6 +401,7 @@ class EngineArgs:
|
|||||||
max_cudagraph_capture_size: int | None = get_field(
|
max_cudagraph_capture_size: int | None = get_field(
|
||||||
CompilationConfig, "max_cudagraph_capture_size"
|
CompilationConfig, "max_cudagraph_capture_size"
|
||||||
)
|
)
|
||||||
|
ir_op_priority: IrOpPriorityConfig = get_field(KernelConfig, "ir_op_priority")
|
||||||
# Note: Specifying a custom executor backend by passing a class
|
# Note: Specifying a custom executor backend by passing a class
|
||||||
# is intended for expert use only. The API may change without
|
# is intended for expert use only. The API may change without
|
||||||
# notice.
|
# notice.
|
||||||
@@ -657,6 +658,9 @@ class EngineArgs:
|
|||||||
self.weight_transfer_config = WeightTransferConfig(
|
self.weight_transfer_config = WeightTransferConfig(
|
||||||
**self.weight_transfer_config
|
**self.weight_transfer_config
|
||||||
)
|
)
|
||||||
|
if isinstance(self.ir_op_priority, dict):
|
||||||
|
self.ir_op_priority = IrOpPriorityConfig(**self.ir_op_priority)
|
||||||
|
|
||||||
# Setup plugins
|
# Setup plugins
|
||||||
from vllm.plugins import load_general_plugins
|
from vllm.plugins import load_general_plugins
|
||||||
|
|
||||||
@@ -1293,6 +1297,7 @@ class EngineArgs:
|
|||||||
title="KernelConfig",
|
title="KernelConfig",
|
||||||
description=KernelConfig.__doc__,
|
description=KernelConfig.__doc__,
|
||||||
)
|
)
|
||||||
|
kernel_group.add_argument("--ir-op-priority", **kernel_kwargs["ir_op_priority"])
|
||||||
kernel_group.add_argument(
|
kernel_group.add_argument(
|
||||||
"--enable-flashinfer-autotune",
|
"--enable-flashinfer-autotune",
|
||||||
**kernel_kwargs["enable_flashinfer_autotune"],
|
**kernel_kwargs["enable_flashinfer_autotune"],
|
||||||
@@ -1917,6 +1922,22 @@ class EngineArgs:
|
|||||||
if self.moe_backend != "auto":
|
if self.moe_backend != "auto":
|
||||||
kernel_config.moe_backend = self.moe_backend
|
kernel_config.moe_backend = self.moe_backend
|
||||||
|
|
||||||
|
# Transfer top-level ir_op_priority into KernelConfig.ir_op_priority
|
||||||
|
for op_name, op_priority in asdict(self.ir_op_priority).items():
|
||||||
|
# Empty means unset
|
||||||
|
if not op_priority:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Priority cannot be set 2x for the same op
|
||||||
|
if getattr(kernel_config.ir_op_priority, op_name):
|
||||||
|
raise ValueError(
|
||||||
|
f"Op priority for {op_name} specified via both ir_op_priority "
|
||||||
|
f"and KernelConfig.ir_op_priority, only one allowed at a time."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set the attribute
|
||||||
|
setattr(kernel_config.ir_op_priority, op_name, op_priority)
|
||||||
|
|
||||||
load_config = self.create_load_config()
|
load_config = self.create_load_config()
|
||||||
|
|
||||||
# Pass reasoning_parser into StructuredOutputsConfig
|
# Pass reasoning_parser into StructuredOutputsConfig
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from typing import Any
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
import vllm.ir
|
||||||
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
|
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@@ -378,7 +379,13 @@ def set_forward_context(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with override_forward_context(forward_context):
|
with (
|
||||||
|
override_forward_context(forward_context),
|
||||||
|
vllm_config.kernel_config.ir_op_priority.set_priority(),
|
||||||
|
vllm.ir.enable_torch_wrap(
|
||||||
|
vllm_config.compilation_config.ir_enable_torch_wrap
|
||||||
|
),
|
||||||
|
):
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
global last_logging_time, batchsize_logging_interval
|
global last_logging_time, batchsize_logging_interval
|
||||||
|
|||||||
6
vllm/ir/__init__.py
Normal file
6
vllm/ir/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from . import ops
|
||||||
|
from .op import enable_torch_wrap, register_op
|
||||||
|
|
||||||
|
__all__ = ["enable_torch_wrap", "register_op", "ops"]
|
||||||
414
vllm/ir/op.py
Normal file
414
vllm/ir/op.py
Normal file
@@ -0,0 +1,414 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import contextlib
|
||||||
|
import inspect
|
||||||
|
from collections.abc import Callable
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, ClassVar, overload
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.library import Library, infer_schema
|
||||||
|
|
||||||
|
from vllm.ir.util import hash_source, weak_cache
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.logging_utils import lazy, tensors_str_no_data
|
||||||
|
|
||||||
|
vllm_ir_lib = Library("vllm_ir", "FRAGMENT")
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
RESERVED_PROVIDERS = ["native", "unfused"]
|
||||||
|
"""Providers that are reserved and cannot be used for custom implementations."""
|
||||||
|
|
||||||
|
_ENABLE_TORCH_WRAP: bool = True
|
||||||
|
"""Global override flag to control torch op layer wrapping."""
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
def enable_torch_wrap(enable: bool = True):
    """
    Temporarily override the module-level torch-wrapping flag for vLLM IR ops.

    While the context is active, ``_ENABLE_TORCH_WRAP`` holds ``enable``; the
    previous value is restored on exit, including when the body raises.
    With wrapping disabled, IR ops skip the torch custom op layer and dispatch
    straight to the selected implementation.

    :param enable: value the flag takes for the duration of the context.
    """
    global _ENABLE_TORCH_WRAP

    # Swap in the new value and remember the old one in a single step.
    previous, _ENABLE_TORCH_WRAP = _ENABLE_TORCH_WRAP, enable
    try:
        yield
    finally:
        _ENABLE_TORCH_WRAP = previous
|
||||||
|
|
||||||
|
|
||||||
|
# 0-param decorator overload
@overload
def register_op(f: Callable[..., Any]) -> "IrOp": ...


# parametrized decorator overload
@overload
def register_op(
    *,
    name: str | None = None,
) -> Callable[[Callable[..., Any]], "IrOp"]: ...


def register_op(
    f: Callable | None = None,
    *,
    name: str | None = None,
) -> "IrOp | Callable[[Callable], IrOp]":
    """
    Register a new vLLM IR op.

    The created IrOp is stored in IrOp.registry under the op name.
    Registering the same name twice fails with an AssertionError.

    :param f: the native implementation of the op
    :param name: the name of the op, defaults to the function name
    :return: the IrOp object if f is provided, otherwise a decorator

    Example usage:
    ```python
    @vllm.ir.register_op
    def my_op(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


    @vllm.ir.register_op(name="custom_mul")
    def multiply(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x * y
    ```
    """

    def decorator(_f: Callable):
        op_name: str = _f.__name__ if name is None else name
        # Name the clashing op so a duplicate registration is easy to trace.
        assert op_name not in IrOp.registry, (
            f"vLLM IR op {op_name} is already registered."
        )
        op = IrOp(op_name, _f)
        IrOp.registry[op_name] = op
        return op

    # Bare usage (@register_op) receives the function directly.
    if f is not None:
        return decorator(f)

    # Parametrized usage (@register_op(name=...)) returns the decorator.
    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class IrOp:
|
||||||
|
registry: ClassVar[dict[str, "IrOp"]] = {}
|
||||||
|
|
||||||
|
name: str
|
||||||
|
impls: dict[str, "IrOpImpl"]
|
||||||
|
|
||||||
|
def __init__(self, name: str, native_impl: Callable):
|
||||||
|
self._py_signature = inspect.signature(native_impl)
|
||||||
|
if any(
|
||||||
|
p.kind == inspect.Parameter.KEYWORD_ONLY
|
||||||
|
for p in self._py_signature.parameters.values()
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Op {name} has keyword-only arguments which are not currently "
|
||||||
|
f"supported. That's because kwargs are not allowed during lowering."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.name = name
|
||||||
|
self.impls: dict[str, IrOpImpl] = {}
|
||||||
|
self._priority_impls: list[IrOpImpl] = []
|
||||||
|
self._schema_str = infer_schema(native_impl, mutates_args=[])
|
||||||
|
|
||||||
|
# native implementation
|
||||||
|
self.impls["native"] = IrOpImpl(
|
||||||
|
self, "native", native_impl, supported=True, supports_args=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# By default, fake routes directly to native,
|
||||||
|
# can be overridden by register_fake
|
||||||
|
self._fake_fn = native_impl
|
||||||
|
|
||||||
|
# torch registration
|
||||||
|
vllm_ir_lib.define(self.name + self._schema_str)
|
||||||
|
# CompositeExplicitAutograd is not decomposed
|
||||||
|
# by ATen IR normalization in AOTAutograd
|
||||||
|
vllm_ir_lib.impl(
|
||||||
|
self.name, self._inner_call, dispatch_key="CompositeExplicitAutograd"
|
||||||
|
)
|
||||||
|
vllm_ir_lib._register_fake(self.name, self._fake_call)
|
||||||
|
assert hasattr(torch.ops.vllm_ir, name)
|
||||||
|
self.torch_op: torch._ops.OpOverload = getattr(torch.ops.vllm_ir, name).default
|
||||||
|
|
||||||
|
def register_fake(self, fn: Callable) -> Callable:
|
||||||
|
"""
|
||||||
|
Register a fake impl for the torch custom op. If this method is not called,
|
||||||
|
the native implementation is used directly for the fake implementation.
|
||||||
|
"""
|
||||||
|
self._fake_fn = fn
|
||||||
|
return fn
|
||||||
|
|
||||||
|
def _fake_call(self, *args, **kwargs) -> Any:
|
||||||
|
"""
|
||||||
|
Call to the fake implementation of the op. We use indirection because we want
|
||||||
|
users to be able to register fake later but also want it to fall back to native
|
||||||
|
directly by default, instead of going through the dispatching mechanism.
|
||||||
|
"""
|
||||||
|
return self._fake_fn(*args, **kwargs)
|
||||||
|
|
||||||
|
def register_impl(
|
||||||
|
self,
|
||||||
|
provider: str,
|
||||||
|
*,
|
||||||
|
supported: bool = True,
|
||||||
|
supports_args: Callable[..., bool] | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Register an implementation for this custom op.
|
||||||
|
:param provider: The name of the provider, must be unique.
|
||||||
|
:param supported: Static support check, use this to check platform support.
|
||||||
|
:param supports_args: Dynamic arg support check, used for types and shapes.
|
||||||
|
:return: A decorator that registers the implementation.
|
||||||
|
|
||||||
|
The decorated function must have the same semantics and signature as
|
||||||
|
the native implementation.
|
||||||
|
|
||||||
|
The provider name must be unique and not one of the RESERVED_PROVIDERS.
|
||||||
|
The supported and supports_args parameters should not be used to implement
|
||||||
|
custom enablement logic based on global state (e.g. environment variables).
|
||||||
|
Instead, supported param should only be used to check for platform support
|
||||||
|
(e.g. whether a specific hardware or library is available).
|
||||||
|
supports_args should be used to check whether the provided arguments are
|
||||||
|
compatible with the implementation.
|
||||||
|
For custom enablement logic, set op impl priority.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@my_op.register_impl("my_provider", supported=torch.cuda.is_available())
|
||||||
|
def my_provider_impl(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert provider not in RESERVED_PROVIDERS, (
|
||||||
|
f"Provider name {provider} is reserved."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _register_impl(f: Callable):
|
||||||
|
impl = IrOpImpl(self, provider, f, supported, supports_args)
|
||||||
|
self.impls[provider] = impl
|
||||||
|
|
||||||
|
if self.get_priority():
|
||||||
|
logger.warning(
|
||||||
|
"Warning: registering new impl %s for op %s while priority is set.",
|
||||||
|
provider,
|
||||||
|
self.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
return impl
|
||||||
|
|
||||||
|
return _register_impl
|
||||||
|
|
||||||
|
def _inner_call(self, *args, **kwargs) -> Any:
|
||||||
|
"""
|
||||||
|
Eager call to torch op lands here. When torch wrapping is disabled,
|
||||||
|
__call__ routes straight here instead of going through torch op dispatching.
|
||||||
|
"""
|
||||||
|
impl = self.dispatch(*args, **kwargs)
|
||||||
|
return impl.impl_fn(*args, **kwargs)
|
||||||
|
|
||||||
|
def apply_arg_defaults(self, args) -> tuple:
|
||||||
|
"""
|
||||||
|
Return args with default values applied.
|
||||||
|
Defaults are taken from the native implementation signature.
|
||||||
|
|
||||||
|
SHOULD NOT BE USED IN THE DISPATCH PATH (SLOW).
|
||||||
|
Only for Inductor lowering.
|
||||||
|
"""
|
||||||
|
bound_args = self._py_signature.bind(*args)
|
||||||
|
bound_args.apply_defaults()
|
||||||
|
return bound_args.args
|
||||||
|
|
||||||
|
def dispatch(self, *args, **kwargs) -> "IrOpImpl":
|
||||||
|
"""
|
||||||
|
Dispatch to the appropriate implementation based on current priority
|
||||||
|
and argument support checks. Returns the selected IrOpImpl.
|
||||||
|
|
||||||
|
THIS FUNCTION IS ON THE HOT PATH (OP DISPATCH), MUST BE FAST.
|
||||||
|
"""
|
||||||
|
if not self._priority_impls:
|
||||||
|
if not torch.compiler.is_compiling():
|
||||||
|
# Logging not compatible with Dynamo tracing
|
||||||
|
# (this code is exposed when torch wrapping is disabled)
|
||||||
|
logger.warning_once(
|
||||||
|
"Priority not set for op %s, using native implementation.",
|
||||||
|
self.name,
|
||||||
|
)
|
||||||
|
return self.impls["native"]
|
||||||
|
|
||||||
|
for impl in self._priority_impls:
|
||||||
|
if not impl.supported:
|
||||||
|
raise ValueError(
|
||||||
|
f"Implementation {impl.provider} for op {self.name} not supported. "
|
||||||
|
f"All implementations in priority list must be supported."
|
||||||
|
)
|
||||||
|
if impl.supports_args(*args, **kwargs):
|
||||||
|
return impl
|
||||||
|
|
||||||
|
if not torch.compiler.is_compiling():
|
||||||
|
logger.debug(
|
||||||
|
"Skipping provider %s because it does not support "
|
||||||
|
"%s with args=%s kwargs=%s",
|
||||||
|
impl.provider,
|
||||||
|
self.name,
|
||||||
|
lazy(lambda: tensors_str_no_data(args)),
|
||||||
|
lazy(lambda: tensors_str_no_data(kwargs)),
|
||||||
|
)
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
"Priority set incorrectly: the last implementation must "
|
||||||
|
"support all args (can be native). This is likely an internal bug"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs) -> Any:
|
||||||
|
if not _ENABLE_TORCH_WRAP:
|
||||||
|
return self._inner_call(*args, **kwargs)
|
||||||
|
|
||||||
|
return self.torch_op(*args, **kwargs)
|
||||||
|
|
||||||
|
def get_priority(self) -> list[str]:
|
||||||
|
"""Get the current dispatch priority for implementations for this op."""
|
||||||
|
return [p.provider for p in self._priority_impls]
|
||||||
|
|
||||||
|
@contextlib.contextmanager
def set_priority(self, priority: list[str]):
    """
    Temporarily override the dispatch priority of this op.

    :param priority: provider names in order of preference; every name must
        be a key of ``self.impls``.

    On entry the priority list is resolved to implementations (see
    ``resolve`` below); on exit the previous list is restored, also when
    the body raises.
    """
    assert all(p in self.impls for p in priority), (
        "All providers in priority must be registered implementations."
    )

    def resolve(providers: list[str]) -> list[IrOpImpl]:
        """Map provider names to the implementations that can be dispatched."""
        chosen: list[IrOpImpl] = []
        for name in providers:
            candidate = self.impls[name]
            if candidate.supported:
                chosen.append(candidate)
                # An implementation that accepts all args always wins over
                # whatever follows it, so the rest of the list is dropped.
                if candidate.supports_all_args:
                    return chosen

        # Nothing in the list accepts all args: append native as the fallback.
        logger.warning_once(
            "Op %s: No implementation in priority list supports all args, "
            "execution fallback to native is possible. To silence this warning, "
            "explicitly add 'native' to the end of the priority list",
            self.name,
        )
        chosen.append(self.impls["native"])
        return chosen

    previous_impls = self._priority_impls
    try:
        self._priority_impls = resolve(priority)
        yield
    finally:
        self._priority_impls = previous_impls
|
||||||
|
|
||||||
|
def supported_providers(self) -> list[str]:
    """Return the provider names of all registered impls flagged as supported."""
    names: list[str] = []
    for impl in self.impls.values():
        if impl.supported:
            names.append(impl.provider)
    return names
|
||||||
|
|
||||||
|
|
||||||
|
class IrOpImpl:
    """
    A single provider-specific implementation of an ``IrOp``.

    Construction validates that the implementation (and the optional
    ``supports_args`` predicate) match the signature of the op's native
    implementation; it does not add the instance to ``op.impls`` itself.
    """

    def __init__(
        self,
        op: IrOp,
        provider: str,
        impl_fn: Callable,
        supported: bool,
        supports_args: Callable[..., bool] | None,
    ):
        """
        :param op: the op this implementation belongs to.
        :param provider: provider name; must not already be in ``op.impls``.
        :param impl_fn: the implementation; its inferred schema must equal
            the op's schema string.
        :param supported: whether this implementation may be dispatched to.
        :param supports_args: optional predicate with the same parameter
            names and defaults as the native implementation; ``None`` means
            all arguments are supported.
        :raises ValueError: on a schema or ``supports_args`` signature mismatch.
        """
        assert provider not in op.impls, (
            f"Implementation for provider {provider} already registered."
        )
        # Native also uses this path, so we allow it here.
        assert provider == "native" or provider not in RESERVED_PROVIDERS

        # Enforce the exact same schema as the native implementation.
        # This takes care of names, types, and defaults.
        schema = infer_schema(impl_fn, mutates_args=[])
        if schema != op._schema_str:
            raise ValueError(
                f"Implementation for provider {provider} has schema '{schema}' which "
                f"does not match native schema '{op._schema_str}' for op {op.name}."
            )

        if supports_args is not None:
            if not callable(supports_args):
                raise ValueError(
                    f"supports_args for provider {provider} must be a callable"
                )

            # We also manually validate the supports_args signature.
            # Matching signatures allow faster dispatch on the hotpath.

            # Check that supports_args does not have keyword-only parameters
            supports_args_signature = inspect.signature(supports_args)
            params = supports_args_signature.parameters
            if any(p.kind == inspect.Parameter.KEYWORD_ONLY for p in params.values()):
                raise ValueError(
                    f"supports_args for provider {provider} "
                    f"cannot have keyword-only parameters"
                )

            # Check that supports_args has the same total number of parameters
            op_params = op._py_signature.parameters
            if len(params) != len(op_params):
                raise ValueError(
                    f"supports_args for provider {provider} must have the same number "
                    f"of parameters ({len(params)}) as the native implementation "
                    f"({len(op_params)})"
                )

            # Check that names and defaults match for supports_args
            for p, op_p in zip(params.values(), op_params.values()):
                if p.name != op_p.name:
                    raise ValueError(
                        f"supports_args for provider {provider} has parameter "
                        f"'{p.name}' which does not match native parameter "
                        f"'{op_p.name}'"
                    )
                if p.default != op_p.default:
                    # Message previously ended with a stray, unbalanced quote.
                    raise ValueError(
                        f"supports_args for provider {provider} has parameter "
                        f"'{p.name}' with default {p.default} which does not match "
                        f"native default {op_p.default}"
                    )

        self.op = op
        self.provider = provider
        self.impl_fn = impl_fn
        self.supported = supported
        self._supports_args = supports_args

    @property
    def supports_all_args(self) -> bool:
        """Check if this implementation supports all args unconditionally."""
        return self._supports_args is None

    def supports_args(self, *args, **kwargs) -> bool:
        """Return whether this implementation accepts the given call arguments."""
        if self._supports_args is None:
            return True

        return self._supports_args(*args, **kwargs)

    @weak_cache
    def uuid(self):
        """
        Compile-time hash to uniquely determine whether the implementation has changed.
        Used by vllm-compile hash mechanism and torch.compile lowering pass uuid to
        control the vLLM compile cache and AOTAutograd/Inductor caches respectively.

        Source file contents do not change so we cache uuid.
        TODO(luka): Cache the file hash as multiple impls are likely in the same file.
        """
        # The hash covers the whole file that defines impl_fn, not just impl_fn.
        sources = [Path(inspect.getfile(self.impl_fn))]
        return hash_source(*sources)
|
||||||
5
vllm/ir/ops/__init__.py
Normal file
5
vllm/ir/ops/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from .layernorm import rms_norm
|
||||||
|
|
||||||
|
__all__ = ["rms_norm"]
|
||||||
22
vllm/ir/ops/layernorm.py
Normal file
22
vllm/ir/ops/layernorm.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from ..op import register_op
|
||||||
|
|
||||||
|
|
||||||
|
@register_op
def rms_norm(
    x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
    """
    Weighted root-mean-square layer normalization.

    The statistics are computed in float32 over the last dimension (only its
    first ``variance_size`` entries when that is given). The normalized
    result is cast back to the input dtype before the optional ``weight``
    multiplication.
    """
    input_dtype = x.dtype
    x_fp32 = x.to(torch.float32)

    if variance_size is None:
        stat_input = x_fp32
    else:
        stat_input = x_fp32[..., :variance_size]

    mean_square = stat_input.pow(2).mean(dim=-1, keepdim=True)
    normalized = (x_fp32 * torch.rsqrt(mean_square + epsilon)).to(input_dtype)

    if weight is None:
        return normalized
    return normalized * weight
|
||||||
61
vllm/ir/util.py
Normal file
61
vllm/ir/util.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import functools
|
||||||
|
import hashlib
|
||||||
|
import inspect
|
||||||
|
import types
|
||||||
|
import weakref
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def hash_source(*srcs: str | Any) -> str:
|
||||||
|
"""
|
||||||
|
Utility method to hash the sources of functions or objects.
|
||||||
|
:param srcs: strings or objects to add to the hash.
|
||||||
|
Objects and functions have their source inspected.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
hasher = hashlib.sha256()
|
||||||
|
for src in srcs:
|
||||||
|
if src is None:
|
||||||
|
src_str = "None"
|
||||||
|
elif isinstance(src, str):
|
||||||
|
src_str = src
|
||||||
|
elif isinstance(src, Path):
|
||||||
|
src_str = src.read_text()
|
||||||
|
elif isinstance(src, (types.FunctionType, type)):
|
||||||
|
src_str = inspect.getsource(src)
|
||||||
|
else:
|
||||||
|
# object instance
|
||||||
|
src_str = inspect.getsource(src.__class__)
|
||||||
|
hasher.update(src_str.encode("utf-8"))
|
||||||
|
return hasher.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def weak_lru_cache(maxsize: int | None = 128, typed: bool = False):
|
||||||
|
"""
|
||||||
|
LRU Cache decorator that keeps a weak reference to 'self'.
|
||||||
|
This avoids memory leakage, which happens when functools.lru_cache
|
||||||
|
stores a reference to self in the global cache.
|
||||||
|
|
||||||
|
Taken from: https://stackoverflow.com/a/68052994/5082708
|
||||||
|
"""
|
||||||
|
|
||||||
|
def wrapper(func):
|
||||||
|
@functools.lru_cache(maxsize, typed)
|
||||||
|
def _func(_self, *args, **kwargs):
|
||||||
|
return func(_self(), *args, **kwargs)
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def inner(self, *args, **kwargs):
|
||||||
|
return _func(weakref.ref(self), *args, **kwargs)
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def weak_cache(user_function, /):
    """Unbounded variant of weak_lru_cache (weak equivalent to functools.cache)."""
    unbounded = weak_lru_cache(maxsize=None)
    return unbounded(user_function)
|
||||||
@@ -1,3 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
"""Kernel implementations for vLLM."""
|
"""Kernel implementations for vLLM."""
|
||||||
|
|
||||||
|
from . import aiter_ops, oink_ops, vllm_c, xpu_ops
|
||||||
|
|
||||||
|
__all__ = ["vllm_c", "aiter_ops", "oink_ops", "xpu_ops"]
|
||||||
|
|||||||
79
vllm/kernels/aiter_ops.py
Normal file
79
vllm/kernels/aiter_ops.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import functools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.library import Library
|
||||||
|
|
||||||
|
from vllm import ir
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
|
current_platform.import_kernels()
|
||||||
|
|
||||||
|
|
||||||
|
def is_aiter_found() -> bool:
    """Return True if an importable module spec named "aiter" can be found."""
    import importlib.util

    spec = importlib.util.find_spec("aiter")
    return spec is not None
|
||||||
|
|
||||||
|
|
||||||
|
aiter_lib = Library("vllm_aiter", "FRAGMENT")
|
||||||
|
"""
|
||||||
|
This library holds torch custom ops for wrapped AITER ops.
|
||||||
|
Many AITER ops want to remain invisible to torch.compile even after lowering.
|
||||||
|
They are thus wrapped into torch custom ops inside the IR op implementations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
direct_register_aiter_op = functools.partial(
|
||||||
|
direct_register_custom_op, target_lib=aiter_lib
|
||||||
|
)
|
||||||
|
"""Syntactic sugar for registering AITER custom ops."""
|
||||||
|
|
||||||
|
AITER_SUPPORTED = is_aiter_found()
|
||||||
|
"""Most kernels in this file are supported if AITER is installed."""
|
||||||
|
|
||||||
|
def rms_no_var_16bit_only(x, weight, epsilon, variance_size=None):
    """AITER rms_norm only supports float16 and bfloat16 acts and no var_size override."""
    if variance_size is not None:
        return False
    return x.dtype in (torch.float16, torch.bfloat16)
|
||||||
|
|
||||||
|
|
||||||
|
@ir.ops.rms_norm.register_impl(
    "aiter", supports_args=rms_no_var_16bit_only, supported=AITER_SUPPORTED
)
def rms_norm(
    x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
    """
    AITER implementation of rms_norm, routed through torch.ops.vllm_aiter.rms_norm.

    Requires ``variance_size`` to be None and ``x`` to be float16 or bfloat16.
    A missing weight is replaced by a tensor of ones matching the last
    dimension, device and dtype of ``x``.
    """
    assert variance_size is None
    assert x.dtype in (torch.float16, torch.bfloat16)
    scale = (
        weight
        if weight is not None
        else torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
    )
    return torch.ops.vllm_aiter.rms_norm(x, scale, epsilon)
|
||||||
|
|
||||||
|
|
||||||
|
def _rms_norm_impl(x: Tensor, weight: Tensor, variance_epsilon: float) -> Tensor:
    """
    Call aiter.rms_norm, flattening inputs with more than two dimensions.

    Inputs with more than 2 dims are reshaped to (-1, last_dim) for the call
    and the result is reshaped back to the original shape.
    """
    from aiter import rms_norm

    if x.dim() <= 2:
        return rms_norm(x, weight, variance_epsilon)

    original_shape = x.shape
    flat = x.reshape(-1, original_shape[-1])
    result = rms_norm(flat, weight, variance_epsilon)
    return result.reshape(original_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def _rms_norm_fake(x: Tensor, weight: Tensor, variance_epsilon: float) -> Tensor:
|
||||||
|
return torch.empty_like(x)
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_aiter_op(
|
||||||
|
op_name="rms_norm", op_func=_rms_norm_impl, fake_impl=_rms_norm_fake
|
||||||
|
)
|
||||||
77
vllm/kernels/oink_ops.py
Normal file
77
vllm/kernels/oink_ops.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm import ir
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
OINK_AVAILABLE = current_platform.has_device_capability(100) and hasattr(
|
||||||
|
torch.ops, "oink"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def has_oink_op(name: str) -> bool:
    """Check if a specific oink op is registered."""
    if not OINK_AVAILABLE:
        # Propagate the falsy availability value itself.
        return OINK_AVAILABLE
    return hasattr(torch.ops.oink, name)
|
||||||
|
|
||||||
|
|
||||||
|
def _can_view_as_2d(x: torch.Tensor) -> bool:
|
||||||
|
"""Return True if x.view(-1, x.shape[-1]) is viewable (no copy)."""
|
||||||
|
if x.dim() < 2:
|
||||||
|
return False
|
||||||
|
if x.dim() == 2:
|
||||||
|
return True
|
||||||
|
# For a view(-1, N) to be valid, all leading dims must be contiguous with
|
||||||
|
# respect to each other (size-1 dims are ignored).
|
||||||
|
for dim in range(x.dim() - 1):
|
||||||
|
# Strides for size-1 dims are irrelevant and can be arbitrary.
|
||||||
|
if x.size(dim + 1) != 1 and x.stride(dim) != x.stride(dim + 1) * x.size(
|
||||||
|
dim + 1
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool:
|
||||||
|
"""Return True if x_2d meets Oink's pointer-path stride constraints."""
|
||||||
|
if x_2d.dim() != 2:
|
||||||
|
return False
|
||||||
|
if x_2d.stride(1) != 1:
|
||||||
|
return False
|
||||||
|
# Match Oink's vectorization constraint: stride(0) divisible by 256b.
|
||||||
|
if x_2d.dtype in (torch.float16, torch.bfloat16):
|
||||||
|
divby = 16
|
||||||
|
elif x_2d.dtype == torch.float32:
|
||||||
|
divby = 8
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return (x_2d.stride(0) % divby) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def oink_rms_supported(x, weight, epsilon, variance_size=None):
    """
    Oink rms only supports 2d-like inputs with contiguous weight
    and no variance_size override.
    """
    if variance_size is not None or weight is None:
        return False
    if x.dim() < 2 or x.dtype != weight.dtype:
        return False
    if not weight.is_contiguous():
        return False
    if not _can_view_as_2d(x):
        return False
    # Only reached when the 2-D view check above passed.
    return _is_oink_stride_compatible_2d(x.view(-1, x.shape[-1]))
|
||||||
|
|
||||||
|
|
||||||
|
@ir.ops.rms_norm.register_impl(
    "oink", supports_args=oink_rms_supported, supported=has_oink_op("rmsnorm")
)
def rms_norm(
    x: torch.Tensor,
    weight: torch.Tensor | None,
    epsilon: float,
    variance_size: int | None = None,
) -> torch.Tensor:
    """
    Oink implementation of rms_norm via torch.ops.oink.rmsnorm.

    The input is viewed as (-1, last_dim) for the kernel call and the result
    is viewed back to the shape of ``x``. ``variance_size`` must be None.
    """
    assert variance_size is None
    hidden_size = x.shape[-1]
    rows = x.view(-1, hidden_size)
    normalized = torch.ops.oink.rmsnorm(rows, weight, epsilon)
    return normalized.view_as(x)
|
||||||
30
vllm/kernels/vllm_c.py
Normal file
30
vllm/kernels/vllm_c.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from vllm import ir
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
current_platform.import_kernels()
|
||||||
|
|
||||||
|
CUDA_ALIKE = current_platform.is_cuda_alike()
|
||||||
|
"""Most kernels in this file are supported on all CUDA-alike platforms."""
|
||||||
|
|
||||||
|
def rms_no_var_size(x, weight, epsilon, variance_size=None):
    """vLLM kernel does not support variance_size parameter."""
    return variance_size is None
|
||||||
|
|
||||||
|
|
||||||
|
@ir.ops.rms_norm.register_impl(
    "vllm_c", supports_args=rms_no_var_size, supported=CUDA_ALIKE
)
def rms_norm(
    x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
    """
    rms_norm implementation backed by torch.ops._C.rms_norm.

    The kernel writes into a freshly allocated output tensor with the shape,
    device and dtype of ``x``. ``variance_size`` must be None.
    """
    # Kernel requires weight tensor, pass ones
    scale = (
        torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
        if weight is None
        else weight
    )
    assert variance_size is None
    result = torch.empty(x.shape, device=x.device, dtype=x.dtype)
    torch.ops._C.rms_norm(result, x, scale, epsilon)
    return result
|
||||||
36
vllm/kernels/xpu_ops.py
Normal file
36
vllm/kernels/xpu_ops.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from vllm import ir
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
current_platform.import_kernels()
|
||||||
|
|
||||||
|
|
||||||
|
def is_xpu_kernels_found() -> bool:
    """Return True if an importable module spec named "vllm_xpu_kernels" exists."""
    import importlib.util

    spec = importlib.util.find_spec("vllm_xpu_kernels")
    return spec is not None
|
||||||
|
|
||||||
|
|
||||||
|
XPU_KERNELS_SUPPORTED = is_xpu_kernels_found()
|
||||||
|
"""Kernels in this file are supported if vLLM XPU kernels are installed."""
|
||||||
|
|
||||||
|
def rms_no_var(x, weight, epsilon, variance_size=None):
    """Accept a call only when no variance_size override is given."""
    return variance_size is None
|
||||||
|
|
||||||
|
|
||||||
|
@ir.ops.rms_norm.register_impl(
    "xpu_kernels", supports_args=rms_no_var, supported=XPU_KERNELS_SUPPORTED
)
def rms_norm(
    x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
    """
    XPU rms_norm implementation backed by torch.ops._C.rms_norm.

    The kernel writes into a freshly allocated output tensor with the shape,
    device and dtype of ``x``. ``variance_size`` must be None.
    """
    # Kernel requires weight tensor, pass ones
    scale = (
        torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
        if weight is None
        else weight
    )
    assert variance_size is None
    result = torch.empty(x.shape, device=x.device, dtype=x.dtype)
    torch.ops._C.rms_norm(result, x, scale, epsilon)
    return result
|
||||||
@@ -8,6 +8,7 @@ from vllm.logging_utils.access_log_filter import (
|
|||||||
from vllm.logging_utils.formatter import ColoredFormatter, NewLineFormatter
|
from vllm.logging_utils.formatter import ColoredFormatter, NewLineFormatter
|
||||||
from vllm.logging_utils.lazy import lazy
|
from vllm.logging_utils.lazy import lazy
|
||||||
from vllm.logging_utils.log_time import logtime
|
from vllm.logging_utils.log_time import logtime
|
||||||
|
from vllm.logging_utils.torch_tensor import tensors_str_no_data
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"NewLineFormatter",
|
"NewLineFormatter",
|
||||||
@@ -16,4 +17,5 @@ __all__ = [
|
|||||||
"create_uvicorn_log_config",
|
"create_uvicorn_log_config",
|
||||||
"lazy",
|
"lazy",
|
||||||
"logtime",
|
"logtime",
|
||||||
|
"tensors_str_no_data",
|
||||||
]
|
]
|
||||||
|
|||||||
10
vllm/logging_utils/torch_tensor.py
Normal file
10
vllm/logging_utils/torch_tensor.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def tensors_str_no_data(arg: Any):
|
||||||
|
from torch._tensor_str import printoptions
|
||||||
|
|
||||||
|
with printoptions(threshold=1, edgeitems=0):
|
||||||
|
return str(arg)
|
||||||
@@ -6,7 +6,9 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from vllm import _oink_ops, envs
|
# Import kernels
|
||||||
|
import vllm.kernels # noqa: F401
|
||||||
|
from vllm import _oink_ops, envs, ir
|
||||||
from vllm._aiter_ops import rocm_aiter_ops
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
@@ -51,23 +53,6 @@ def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool:
|
|||||||
return (x_2d.stride(0) % divby) == 0
|
return (x_2d.stride(0) % divby) == 0
|
||||||
|
|
||||||
|
|
||||||
def rms_norm(
|
|
||||||
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
|
|
||||||
) -> torch.Tensor:
|
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
|
|
||||||
if envs.VLLM_BATCH_INVARIANT:
|
|
||||||
return rms_norm_batch_invariant(x, weight, variance_epsilon)
|
|
||||||
out = torch.empty_like(x)
|
|
||||||
ops.rms_norm(
|
|
||||||
out,
|
|
||||||
x,
|
|
||||||
weight,
|
|
||||||
variance_epsilon,
|
|
||||||
)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def fused_add_rms_norm(
|
def fused_add_rms_norm(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
residual: torch.Tensor,
|
residual: torch.Tensor,
|
||||||
@@ -105,23 +90,16 @@ def poly_norm(
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def dispatch_rocm_rmsnorm_func(
|
def dispatch_rocm_rmsnorm_func(dtype: torch.dtype, use_aiter: bool = False):
|
||||||
with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
|
|
||||||
):
|
|
||||||
use_aiter = use_aiter and dtype in [
|
use_aiter = use_aiter and dtype in [
|
||||||
torch.float16,
|
torch.float16,
|
||||||
torch.bfloat16,
|
torch.bfloat16,
|
||||||
]
|
]
|
||||||
|
|
||||||
if use_aiter and with_fused_add:
|
|
||||||
return rocm_aiter_ops.rms_norm2d_with_add
|
|
||||||
if use_aiter:
|
if use_aiter:
|
||||||
return rocm_aiter_ops.rms_norm
|
return rocm_aiter_ops.rms_norm2d_with_add
|
||||||
|
else:
|
||||||
# fall back to CUDA implementation
|
|
||||||
if with_fused_add:
|
|
||||||
return fused_add_rms_norm
|
return fused_add_rms_norm
|
||||||
return rms_norm
|
|
||||||
|
|
||||||
|
|
||||||
# --8<-- [start:rms_norm]
|
# --8<-- [start:rms_norm]
|
||||||
@@ -158,20 +136,14 @@ class RMSNorm(CustomOp):
|
|||||||
|
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
|
aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
|
||||||
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
|
|
||||||
with_fused_add=False,
|
|
||||||
dtype=weight_dtype,
|
|
||||||
use_aiter=aiter_rmsnorm_enabled,
|
|
||||||
)
|
|
||||||
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
|
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
|
||||||
with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
|
dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optional: enable Oink Blackwell RMSNorm custom-op fast path on
|
# Optional: enable Oink Blackwell RMSNorm custom-op fast path on
|
||||||
# compatible CUDA devices (e.g., SM100) when the external Oink
|
# compatible CUDA devices (e.g., SM100) when the external Oink
|
||||||
# package is available. This is detected once at construction time
|
# package is available. This is detected once at construction time
|
||||||
# to avoid per-call device queries in the hot path.
|
# to avoid per-call device queries in the hot path.
|
||||||
self._use_oink_rmsnorm = False
|
|
||||||
self._use_oink_fused_add_rmsnorm = False
|
self._use_oink_fused_add_rmsnorm = False
|
||||||
if (
|
if (
|
||||||
not current_platform.is_rocm()
|
not current_platform.is_rocm()
|
||||||
@@ -203,7 +175,6 @@ class RMSNorm(CustomOp):
|
|||||||
try:
|
try:
|
||||||
device_index = torch.accelerator.current_device_index()
|
device_index = torch.accelerator.current_device_index()
|
||||||
if _oink_ops.is_oink_available_for_device(device_index):
|
if _oink_ops.is_oink_available_for_device(device_index):
|
||||||
self._use_oink_rmsnorm = True
|
|
||||||
self._use_oink_fused_add_rmsnorm = (
|
self._use_oink_fused_add_rmsnorm = (
|
||||||
_oink_ops.has_fused_add_rms_norm()
|
_oink_ops.has_fused_add_rms_norm()
|
||||||
)
|
)
|
||||||
@@ -215,7 +186,6 @@ class RMSNorm(CustomOp):
|
|||||||
"RMSNorm; falling back to vLLM RMSNorm. Error: %s",
|
"RMSNorm; falling back to vLLM RMSNorm. Error: %s",
|
||||||
e,
|
e,
|
||||||
)
|
)
|
||||||
self._use_oink_rmsnorm = False
|
|
||||||
self._use_oink_fused_add_rmsnorm = False
|
self._use_oink_fused_add_rmsnorm = False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -270,6 +240,10 @@ class RMSNorm(CustomOp):
|
|||||||
residual: torch.Tensor | None = None,
|
residual: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""PyTorch-native implementation equivalent to forward()."""
|
"""PyTorch-native implementation equivalent to forward()."""
|
||||||
|
if residual is None:
|
||||||
|
return ir.ops.rms_norm(
|
||||||
|
x, self.weight.data, self.variance_epsilon, self.variance_size_override
|
||||||
|
)
|
||||||
|
|
||||||
return self.forward_static(
|
return self.forward_static(
|
||||||
x,
|
x,
|
||||||
@@ -286,35 +260,14 @@ class RMSNorm(CustomOp):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
residual: torch.Tensor | None = None,
|
residual: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if residual is None and not envs.VLLM_BATCH_INVARIANT:
|
||||||
|
return ir.ops.rms_norm(
|
||||||
|
x, self.weight.data, self.variance_epsilon, self.variance_size_override
|
||||||
|
)
|
||||||
|
|
||||||
if self.variance_size_override is not None:
|
if self.variance_size_override is not None:
|
||||||
return self.forward_native(x, residual)
|
return self.forward_native(x, residual)
|
||||||
|
|
||||||
# Optional Oink SM100 fast path (no residual). This path is
|
|
||||||
# torch.compile-friendly via torch.ops.oink.rmsnorm and preserves
|
|
||||||
# 2D layouts (including padded rows) when using the Oink
|
|
||||||
# pointer-based kernel.
|
|
||||||
if (
|
|
||||||
residual is None
|
|
||||||
and getattr(self, "_use_oink_rmsnorm", False)
|
|
||||||
and x.is_cuda
|
|
||||||
and x.dim() >= 2
|
|
||||||
and self.has_weight
|
|
||||||
and not envs.VLLM_BATCH_INVARIANT
|
|
||||||
and self.weight.data.dtype == x.dtype
|
|
||||||
and self.weight.data.is_contiguous()
|
|
||||||
):
|
|
||||||
orig_shape = x.shape
|
|
||||||
hidden_size = orig_shape[-1]
|
|
||||||
if _can_view_as_2d(x):
|
|
||||||
x_2d = x.view(-1, hidden_size)
|
|
||||||
if _is_oink_stride_compatible_2d(x_2d):
|
|
||||||
y_2d = _oink_ops.rmsnorm(
|
|
||||||
x_2d,
|
|
||||||
self.weight.data,
|
|
||||||
self.variance_epsilon,
|
|
||||||
)
|
|
||||||
return y_2d.view(orig_shape)
|
|
||||||
|
|
||||||
# Optional Oink SM100 fast path (fused residual-add + RMSNorm, in-place).
|
# Optional Oink SM100 fast path (fused residual-add + RMSNorm, in-place).
|
||||||
# This mirrors vLLM's fused_add_rms_norm semantics by mutating both
|
# This mirrors vLLM's fused_add_rms_norm semantics by mutating both
|
||||||
# `x` (normalized output) and `residual` (residual-out buffer).
|
# `x` (normalized output) and `residual` (residual-out buffer).
|
||||||
@@ -356,29 +309,34 @@ class RMSNorm(CustomOp):
|
|||||||
)
|
)
|
||||||
return x, residual
|
return x, residual
|
||||||
|
|
||||||
add_residual = residual is not None
|
if residual is not None:
|
||||||
if add_residual:
|
|
||||||
return fused_add_rms_norm(
|
return fused_add_rms_norm(
|
||||||
x, residual, self.weight.data, self.variance_epsilon
|
x, residual, self.weight.data, self.variance_epsilon
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return rms_norm(x, self.weight.data, self.variance_epsilon)
|
assert envs.VLLM_BATCH_INVARIANT
|
||||||
|
return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)
|
||||||
|
|
||||||
def forward_hip(
|
def forward_hip(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
residual: torch.Tensor | None = None,
|
residual: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if residual is None and not envs.VLLM_BATCH_INVARIANT:
|
||||||
|
return ir.ops.rms_norm(
|
||||||
|
x, self.weight.data, self.variance_epsilon, self.variance_size_override
|
||||||
|
)
|
||||||
|
|
||||||
if self.variance_size_override is not None:
|
if self.variance_size_override is not None:
|
||||||
return self.forward_native(x, residual)
|
return self.forward_native(x, residual)
|
||||||
|
|
||||||
add_residual = residual is not None
|
if residual is not None:
|
||||||
if add_residual:
|
|
||||||
return self.rocm_norm_func_with_add(
|
return self.rocm_norm_func_with_add(
|
||||||
x, residual, self.weight.data, self.variance_epsilon
|
x, residual, self.weight.data, self.variance_epsilon
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)
|
assert envs.VLLM_BATCH_INVARIANT
|
||||||
|
return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)
|
||||||
|
|
||||||
def forward_xpu(
|
def forward_xpu(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.cache import CacheDType
|
from vllm.config.cache import CacheDType
|
||||||
|
from vllm.config.kernel import IrOpPriorityConfig
|
||||||
from vllm.v1.attention.selector import AttentionSelectorConfig
|
from vllm.v1.attention.selector import AttentionSelectorConfig
|
||||||
else:
|
else:
|
||||||
VllmConfig = None
|
VllmConfig = None
|
||||||
@@ -550,6 +551,26 @@ class CudaPlatformBase(Platform):
|
|||||||
def use_custom_op_collectives(cls) -> bool:
|
def use_custom_op_collectives(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@classmethod
def get_default_ir_op_priority(cls, vllm_config: VllmConfig) -> IrOpPriorityConfig:
    """
    Build the default IR op priority for this platform.

    The default list is ["native"] when the compilation backend is
    "inductor" and the compilation mode is not NONE, else
    ["vllm_c", "native"]. For rms_norm, "oink" is put in front of that list
    when ``envs.VLLM_USE_OINK_OPS`` is set.
    """
    from vllm.config.compilation import CompilationMode
    from vllm.config.kernel import IrOpPriorityConfig

    # Native used by default when compiling,
    # use vllm_c kernels where available when no codegen
    compilation = vllm_config.compilation_config
    using_inductor = (
        compilation.backend == "inductor"
        and compilation.mode != CompilationMode.NONE
    )
    if using_inductor:
        default = ["native"]
    else:
        default = ["vllm_c", "native"]

    # Use oink if enabled for rms_norm
    # TODO(Laurawly/luka): remove this env var,
    # users can just use IR op priority directly
    rms_norm = ["oink"] + default if envs.VLLM_USE_OINK_OPS else default

    return IrOpPriorityConfig.with_default(default, rms_norm=rms_norm)
|
||||||
|
|
||||||
|
|
||||||
# NVML utils
|
# NVML utils
|
||||||
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
|
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ if TYPE_CHECKING:
|
|||||||
from torch.distributed import PrefixStore, ProcessGroup
|
from torch.distributed import PrefixStore, ProcessGroup
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.config.kernel import IrOpPriorityConfig
|
||||||
from vllm.inputs import EngineInput
|
from vllm.inputs import EngineInput
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
@@ -931,6 +932,16 @@ class Platform:
|
|||||||
"num_compute_units is not implemented for the current platform."
|
"num_compute_units is not implemented for the current platform."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
def get_default_ir_op_priority(
    cls, vllm_config: "VllmConfig"
) -> "IrOpPriorityConfig":
    """Get the default IR op priority for the current platform.

    This base implementation ignores ``vllm_config`` and always selects
    the ``native`` implementation. Platforms can override this behavior.
    """
    from vllm.config.kernel import IrOpPriorityConfig

    # Native is the only entry in the platform-agnostic default.
    native_only = ["native"]
    return IrOpPriorityConfig.with_default(native_only)
|
||||||
|
|
||||||
|
|
||||||
class UnspecifiedPlatform(Platform):
|
class UnspecifiedPlatform(Platform):
|
||||||
_enum = PlatformEnum.UNSPECIFIED
|
_enum = PlatformEnum.UNSPECIFIED
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.config.kernel import IrOpPriorityConfig
|
||||||
from vllm.v1.attention.selector import AttentionSelectorConfig
|
from vllm.v1.attention.selector import AttentionSelectorConfig
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@@ -903,3 +904,32 @@ class RocmPlatform(Platform):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def use_custom_op_collectives(cls) -> bool:
|
def use_custom_op_collectives(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@classmethod
def get_default_ir_op_priority(
    cls, vllm_config: "VllmConfig"
) -> "IrOpPriorityConfig":
    """Get the default IR op priority for ROCm.

    The base order is ``["native"]`` when the inductor backend is used
    with a compilation mode other than ``NONE``, and
    ``["vllm_c", "native"]`` otherwise. ``"aiter"`` is put in front of
    that order for ``rms_norm`` when the ``rms_norm`` custom op is enabled
    and both ``VLLM_ROCM_USE_AITER`` and ``VLLM_ROCM_USE_AITER_RMSNORM``
    are set.
    """
    from vllm.config.compilation import CompilationMode
    from vllm.config.kernel import IrOpPriorityConfig

    # Native used by default when compiling,
    # use vllm_c kernels where available when no codegen
    # TODO(luka/TJ) use aiter, vllm_c, native by default on ROCm
    comp_cfg = vllm_config.compilation_config
    if comp_cfg.backend == "inductor" and comp_cfg.mode != CompilationMode.NONE:
        base_order = ["native"]
    else:
        base_order = ["vllm_c", "native"]

    # This (mostly) preserves previous CustomOp behavior
    # Necessary on ROCm because it's common that users
    # enable rms_norm to use the aiter kernel.
    # TODO(luka/TJ) remove env vars completely
    prefer_aiter = (
        comp_cfg.is_custom_op_enabled("rms_norm")
        and envs.VLLM_ROCM_USE_AITER
        and envs.VLLM_ROCM_USE_AITER_RMSNORM
    )
    rms_norm_order = ["aiter"] + base_order if prefer_aiter else base_order

    return IrOpPriorityConfig.with_default(base_order, rms_norm=rms_norm_order)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.config.kernel import IrOpPriorityConfig
|
||||||
from vllm.v1.attention.selector import AttentionSelectorConfig
|
from vllm.v1.attention.selector import AttentionSelectorConfig
|
||||||
else:
|
else:
|
||||||
VllmConfig = None
|
VllmConfig = None
|
||||||
@@ -257,6 +258,21 @@ class XPUPlatform(Platform):
|
|||||||
)
|
)
|
||||||
return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa
|
return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa
|
||||||
|
|
||||||
|
@classmethod
def get_default_ir_op_priority(
    cls, vllm_config: "VllmConfig"
) -> "IrOpPriorityConfig":
    """Get the default IR op priority for XPU.

    Returns a config whose default order is ``["native"]`` when the
    inductor backend is used with a compilation mode other than ``NONE``,
    and ``["xpu_kernels", "native"]`` otherwise.
    """
    from vllm.config.compilation import CompilationMode
    from vllm.config.kernel import IrOpPriorityConfig

    # Native used by default when compiling,
    # use fused kernels where available when no codegen
    comp_cfg = vllm_config.compilation_config
    if comp_cfg.backend == "inductor" and comp_cfg.mode != CompilationMode.NONE:
        order = ["native"]
    else:
        order = ["xpu_kernels", "native"]

    return IrOpPriorityConfig.with_default(order)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def device_count(cls) -> int:
|
def device_count(cls) -> int:
|
||||||
return torch.xpu.device_count()
|
return torch.xpu.device_count()
|
||||||
|
|||||||
Reference in New Issue
Block a user